From acbd882d837d8ac402c2aa2a940de2315ae3f118 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Mon, 28 Dec 2020 11:58:56 -0500 Subject: [PATCH 01/91] Initialize main project files. --- pyro/contrib/mue/__init__.py | 2 ++ pyro/contrib/mue/biosequenceloaders.py | 2 ++ pyro/contrib/mue/statearrangers.py | 2 ++ pyro/contrib/mue/variablelengthhmm.py | 2 ++ 4 files changed, 8 insertions(+) create mode 100644 pyro/contrib/mue/__init__.py create mode 100644 pyro/contrib/mue/biosequenceloaders.py create mode 100644 pyro/contrib/mue/statearrangers.py create mode 100644 pyro/contrib/mue/variablelengthhmm.py diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py new file mode 100644 index 0000000000..d6960608d6 --- /dev/null +++ b/pyro/contrib/mue/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/pyro/contrib/mue/biosequenceloaders.py b/pyro/contrib/mue/biosequenceloaders.py new file mode 100644 index 0000000000..d6960608d6 --- /dev/null +++ b/pyro/contrib/mue/biosequenceloaders.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py new file mode 100644 index 0000000000..d6960608d6 --- /dev/null +++ b/pyro/contrib/mue/statearrangers.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py new file mode 100644 index 0000000000..d6960608d6 --- /dev/null +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 From cbf6102954e3fb8d8988b4d29a529943a34ad163 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Mon, 28 Dec 2020 15:59:57 -0500 Subject: [PATCH 02/91] Module for converting MuE parameters (in particular, profile HMM) to standard HMM parameters. --- pyro/contrib/mue/statearrangers.py | 152 +++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index d6960608d6..a0f530b223 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -1,2 +1,154 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import torch +from pyro.nn import PyroModule + + +def mg2k(m, g): + """Convert from (m, g) indexing to k indexing.""" + return 2*m + 1 - g + + +class profile(PyroModule): + + def __init__(self, M, dtype=torch.float64, epsilon=1e-32): + super().__init__() + self.M = M + self.K = 2*(M+1) + self.dtype = dtype + self.epsilon = epsilon + + self._make_transfer() + + def _make_transfer(self): + """Set up linear transformations (transfer matrices) for converting + from profile HMM parameters to standard HMM parameters.""" + M, K = self.M, self.K + + # Overview: + # r -> insertion parameters + # u -> deletion parameters + # indices: m in {0, ..., M} and j in {0, 1, 2}; final index corresponds + # to simplex dimensions, i.e. 
1 - r and r (in that order) + # null -> locations in the transition matrix equal to 0 + # ...transf_0 -> initial transition vector + # ...transf -> transition matrix + self.register_buffer('r_transf_0', + torch.zeros((M+1, 3, 2, K), dtype=self.dtype)) + self.register_buffer('u_transf_0', + torch.zeros((M+1, 3, 2, K), dtype=self.dtype)) + self.register_buffer('null_transf_0', + torch.zeros((K,), dtype=self.dtype)) + m, g = -1, 0 + for mp in range(M+1): + for gp in range(2): + kp = mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 0, kp] = 1 + + elif m + 1 - g < mp and mp <= M and gp == 0: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 1, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf_0[mpp, 2, 0, kp] = 1 + self.u_transf_0[mpp, 2, 1, kp] = 1 + self.r_transf_0[mp, 2, 0, kp] = 1 + self.u_transf_0[mp, 2, 0, kp] = 1 + + elif m + 1 - g == mp and gp == 1: + self.r_transf_0[m+1-g, g, 1, kp] = 1 + + elif m + 1 - g < mp and mp <= M and gp == 1: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 1, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf_0[mpp, 2, 0, kp] = 1 + self.u_transf_0[mpp, 2, 1, kp] = 1 + self.r_transf_0[mp, 2, 1, kp] = 1 + + else: + self.null_transf_0[kp] = 1 + self.u_transf_0[-1, :, :, :] = 0. + + self.register_buffer('r_transf', + torch.zeros((M+1, 3, 2, K, K), dtype=self.dtype)) + self.register_buffer('u_transf', + torch.zeros((M+1, 3, 2, K, K), dtype=self.dtype)) + self.register_buffer('null_transf', + torch.zeros((K, K), dtype=self.dtype)) + for m in range(M+1): + for g in range(2): + for mp in range(M+1): + for gp in range(2): + k, kp = mg2k(m, g), mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 0, k, kp] = 1 + + elif m + 1 - g < mp and mp <= M and gp == 0: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 1, k, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf[mpp, 2, 0, k, kp] = 1 + self.u_transf[mpp, 2, 1, k, kp] = 1 + self.r_transf[mp, 2, 0, k, kp] = 1 + self.u_transf[mp, 2, 0, k, kp] = 1 + + elif m + 1 - g == mp and gp == 1: + self.r_transf[m+1-g, g, 1, k, kp] = 1 + + elif m + 1 - g < mp and mp <= M and gp == 1: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 1, k, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf[mpp, 2, 0, k, kp] = 1 + self.u_transf[mpp, 2, 1, k, kp] = 1 + self.r_transf[mp, 2, 1, k, kp] = 1 + + elif not (m == M and mp == M and g == 0 and gp == 0): + self.null_transf[k, kp] = 1 + self.u_transf[-1, :, :, :, :] = 0. 
+ + self.register_buffer('vx_transf', + torch.zeros((M+1, K), dtype=self.dtype)) + self.register_buffer('vc_transf', + torch.zeros((M+1, K), dtype=self.dtype)) + for m in range(M+1): + for g in range(2): + k = mg2k(m, g) + if g == 0: + self.vx_transf[m, k] = 1 + elif g == 1: + self.vc_transf[m, k] = 1 + + def forward(self, ancestor_seq_logits, insert_seq_logits, + insert_logits, delete_logits, subsitute_logits=None): + """Assemble the HMM parameters based on the transfer matrices.""" + + initial_logits = ( + torch.einsum('...ijk,...ijkl->...l', + delete_logits, self.u_transf_0) + + torch.einsum('...ijk,...ijkl->...l', + insert_logits, self.r_transf_0) + + (-1/self.epsilon)*self.null_transf_0) + transition_logits = ( + torch.einsum('...ijk,...ijklf->...lf', + delete_logits, self.u_transf) + + torch.einsum('...ijk,...ijklf->...lf', + insert_logits, self.r_transf) + + (-1/self.epsilon)*self.null_transf) + seq_logits = ( + torch.einsum('...ij,...ik->...kj', + ancestor_seq_logits, self.vx_transf) + + torch.einsum('...ij,...ik->...kj', + insert_seq_logits, self.vc_transf)) + + # Option to include the substitution matrix. + if subsitute_logits is not None: + observation_logits = torch.logsumexp( + seq_logits[:, :, None] + subsitute_logits[None, :, :], dim=1) + else: + observation_logits = seq_logits + + return initial_logits, transition_logits, observation_logits From 1fe82b12921520d6409c5e323252e619d173265a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Mon, 28 Dec 2020 18:58:47 -0500 Subject: [PATCH 03/91] Test state arranger. --- pyro/contrib/mue/statearrangers.py | 19 +++-- tests/contrib/mue/test_statearrangers.py | 98 ++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 10 deletions(-) create mode 100644 tests/contrib/mue/test_statearrangers.py diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index a0f530b223..ff512263fe 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -11,11 +11,10 @@ def mg2k(m, g): class profile(PyroModule): - def __init__(self, M, dtype=torch.float64, epsilon=1e-32): + def __init__(self, M, epsilon=1e-32): super().__init__() self.M = M self.K = 2*(M+1) - self.dtype = dtype self.epsilon = epsilon self._make_transfer() @@ -34,11 +33,11 @@ def _make_transfer(self): # ...transf_0 -> initial transition vector # ...transf -> transition matrix self.register_buffer('r_transf_0', - torch.zeros((M+1, 3, 2, K), dtype=self.dtype)) + torch.zeros((M+1, 3, 2, K))) self.register_buffer('u_transf_0', - torch.zeros((M+1, 3, 2, K), dtype=self.dtype)) + torch.zeros((M+1, 3, 2, K))) self.register_buffer('null_transf_0', - torch.zeros((K,), dtype=self.dtype)) + torch.zeros((K,))) m, g = -1, 0 for mp in range(M+1): for gp in range(2): @@ -72,11 +71,11 @@ def _make_transfer(self): self.u_transf_0[-1, :, :, :] = 0. self.register_buffer('r_transf', - torch.zeros((M+1, 3, 2, K, K), dtype=self.dtype)) + torch.zeros((M+1, 3, 2, K, K))) self.register_buffer('u_transf', - torch.zeros((M+1, 3, 2, K, K), dtype=self.dtype)) + torch.zeros((M+1, 3, 2, K, K))) self.register_buffer('null_transf', - torch.zeros((K, K), dtype=self.dtype)) + torch.zeros((K, K))) for m in range(M+1): for g in range(2): for mp in range(M+1): @@ -111,9 +110,9 @@ def _make_transfer(self): self.u_transf[-1, :, :, :, :] = 0. 
self.register_buffer('vx_transf', - torch.zeros((M+1, K), dtype=self.dtype)) + torch.zeros((M+1, K))) self.register_buffer('vc_transf', - torch.zeros((M+1, K), dtype=self.dtype)) + torch.zeros((M+1, K))) for m in range(M+1): for g in range(2): k = mg2k(m, g) diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py new file mode 100644 index 0000000000..0ce85ededa --- /dev/null +++ b/tests/contrib/mue/test_statearrangers.py @@ -0,0 +1,98 @@ +import torch + +from pyro.contrib.mue import profile, mg2k +import pytest + + +@pytest.mark.parameterize('M', [2, 20]) +@pytest.mark.parameterize('batch_size', [None]) +@pytest.mark.parameterize('substitute', [False, True]) +def test_profile(M, batch_size): + torch.set_default_tensor_type('torch.DoubleTensor') + + pf_arranger = profile(M) + + u1 = torch.rand((M+1, 3)) + u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) + r1 = torch.rand((M+1, 3)) + r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) + s = torch.rand((M+1, 4)) + s = s/torch.sum(s, dim=1, keepdim=True) + c = torch.rand((M+1, 4)) + c = c/torch.sum(c, dim=1, keepdim=True) + + if batch_size is not None: + s = torch.ones([batch_size, 1, 1]) * s[None, :, :] + u = torch.ones([batch_size, 1, 1]) * u[None, :, :] + + a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), + torch.log(r), torch.log(u)) + + # - Remake transition matrices. - + u1[-1] = 1e-32 + K = 2*(M+1) + chk_a = torch.zeros((K, K)) + chk_a0 = torch.zeros((K,)) + m, g = -1, 0 + for mp in range(M+1): + for gp in range(2): + kp = mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + chk_a0[kp] = (1 - r1[m+1-g])*(1 - u1[m+1-g]) + elif m + 1 - g < mp and gp == 0: + chk_a0[kp] = ( + (1 - r1[m+1-g]) * u1[m+1-g] * + torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in + range(m+2-g, mp)]) * + (1 - r1[mp]) * (1 - u1[mp])) + elif m + 1 - g == mp and gp == 1: + chk_a0[kp] = r1[m+1-g] + elif m + 1 - g < mp and gp == 1: + chk_a0[kp] = ( + (1 - r1[m+1-g]) * u1[m+1-g] * + torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in + range(m+2-g, mp)]) * r1[mp]) + for m in range(M+1): + for g in range(2): + k = mg2k(m, g) + for mp in range(M+1): + for gp in range(2): + kp = mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + chk_a[k, kp] = (1 - r1[m+1-g])*(1 - u1[m+1-g]) + elif m + 1 - g < mp and gp == 0: + chk_a[k, kp] = ( + (1 - r1[m+1-g]) * u1[m+1-g] * + torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in + range(m+2-g, mp)]) * + (1 - r1[mp]) * (1 - u1[mp])) + elif m + 1 - g == mp and gp == 1: + chk_a[k, kp] = r1[m+1-g] + elif m + 1 - g < mp and gp == 1: + chk_a[k, kp] = ( + (1 - r1[m+1-g]) * u1[m+1-g] * + torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in + range(m+2-g, mp)]) * r1[mp]) + elif m == M and mp == M and g == 0 and gp == 0: + chk_a[k, kp] = 1. + + chk_e = torch.zeros((2*(M+1), 4)) + for m in range(M+1): + for g in range(2): + k = mg2k(m, g) + if g == 0: + chk_e[k, :] = s[m, :].numpy() + else: + chk_e[k, :] = c[m, :].numpy() + # - - + + assert torch.allclose(chk_a0, torch.exp(a0ln)) + assert torch.allclose(chk_a, torch.exp(aln)) + assert torch.allclose(chk_e, torch.exp(eln)) + + # Check normalization. + assert torch.allclose(torch.sum(torch.exp(a0ln)), 1., atol=1e-3, + rtol=1e-3) + assert torch.allclose(torch.sum(torch.exp(aln), axis=1)[:-1], + torch.ones(2*(M+1)-1), atol=1e-3, + rtol=1e-3) From 6e2868b151569a618672492308cb17f9153bd2ce Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 29 Dec 2020 15:42:05 -0500 Subject: [PATCH 04/91] Debug state arranger tests. 
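The previous patch's tests had several bugs that this commit repairs: the
decorator was misspelled (pytest's spelling is parametrize, not
parameterize), test_profile declared a substitute parameter without
accepting it, profile and mg2k were imported from the package root rather
than from pyro.contrib.mue.statearrangers, and torch.prod was called on a
plain Python list, which it does not accept (hence the new simpleprod
helper). The checks are also extended to batched parameters.

As an editor's sketch (not part of this commit) of the state ordering the
tests index into, using only the mg2k definition from statearrangers.py:

    def mg2k(m, g):
        # Convert from (m, g) indexing to k indexing.
        return 2*m + 1 - g

    M = 2
    for m in range(M + 1):
        for g in range(2):
            print('(m={}, g={}) -> k={}'.format(m, g, mg2k(m, g)))
    # Prints k = 1, 0, 3, 2, 5, 4: each insert state (m, 1) sits directly
    # before its match state (m, 0), giving K = 2*(M+1) states in all.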
--- pyro/contrib/mue/__init__.py | 9 ++ pyro/contrib/mue/statearrangers.py | 19 ++- tests/contrib/mue/test_statearrangers.py | 181 ++++++++++++++--------- 3 files changed, 130 insertions(+), 79 deletions(-) diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index d6960608d6..a963573304 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,2 +1,11 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +""" +The :mod:`pyro.contrib.mue` module provides tools for working with mutational +emission (MuE) distributions. +""" +from pyro.contrib.mue.statearrangers import profile + +__all__ = [ + "profile" +] diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index ff512263fe..258487f4cc 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -122,31 +122,30 @@ def _make_transfer(self): self.vc_transf[m, k] = 1 def forward(self, ancestor_seq_logits, insert_seq_logits, - insert_logits, delete_logits, subsitute_logits=None): + insert_logits, delete_logits, substitute_logits=None): """Assemble the HMM parameters based on the transfer matrices.""" - initial_logits = ( - torch.einsum('...ijk,...ijkl->...l', + torch.einsum('...ijk,ijkl->...l', delete_logits, self.u_transf_0) + - torch.einsum('...ijk,...ijkl->...l', + torch.einsum('...ijk,ijkl->...l', insert_logits, self.r_transf_0) + (-1/self.epsilon)*self.null_transf_0) transition_logits = ( - torch.einsum('...ijk,...ijklf->...lf', + torch.einsum('...ijk,ijklf->...lf', delete_logits, self.u_transf) + - torch.einsum('...ijk,...ijklf->...lf', + torch.einsum('...ijk,ijklf->...lf', insert_logits, self.r_transf) + (-1/self.epsilon)*self.null_transf) seq_logits = ( - torch.einsum('...ij,...ik->...kj', + torch.einsum('...ij,ik->...kj', ancestor_seq_logits, self.vx_transf) + - torch.einsum('...ij,...ik->...kj', + torch.einsum('...ij,ik->...kj', insert_seq_logits, self.vc_transf)) # Option to include the substitution matrix. - if subsitute_logits is not None: + if substitute_logits is not None: observation_logits = torch.logsumexp( - seq_logits[:, :, None] + subsitute_logits[None, :, :], dim=1) + seq_logits.unsqueeze(-1) + substitute_logits, dim=-2) else: observation_logits = seq_logits diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 0ce85ededa..4ca3f7a194 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,13 +1,21 @@ import torch -from pyro.contrib.mue import profile, mg2k +from pyro.contrib.mue.statearrangers import profile, mg2k import pytest -@pytest.mark.parameterize('M', [2, 20]) -@pytest.mark.parameterize('batch_size', [None]) -@pytest.mark.parameterize('substitute', [False, True]) -def test_profile(M, batch_size): +def simpleprod(lst): + # Product of list of scalar tensors, as numpy would do it. + if len(lst) == 0: + return torch.tensor(1.) 
+ else: + return torch.prod(torch.cat([elem[None] for elem in lst])) + + +@pytest.mark.parametrize('M', [2, 20]) +@pytest.mark.parametrize('batch_size', [None, 5]) +@pytest.mark.parametrize('substitute', [False, True]) +def test_profile(M, batch_size, substitute): torch.set_default_tensor_type('torch.DoubleTensor') pf_arranger = profile(M) @@ -22,77 +30,112 @@ def test_profile(M, batch_size): c = c/torch.sum(c, dim=1, keepdim=True) if batch_size is not None: - s = torch.ones([batch_size, 1, 1]) * s[None, :, :] - u = torch.ones([batch_size, 1, 1]) * u[None, :, :] + s = torch.rand((batch_size, M+1, 4)) + s = s/torch.sum(s, dim=2, keepdim=True) + u1 = torch.rand((batch_size, M+1, 3)) + u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) - a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), - torch.log(r), torch.log(u)) + if substitute: + ll = torch.rand((4, 5)) + ll = ll/torch.sum(ll, dim=1, keepdim=True) + a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), + torch.log(r), torch.log(u), + torch.log(ll)) + else: + a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), + torch.log(r), torch.log(u)) # - Remake transition matrices. - - u1[-1] = 1e-32 K = 2*(M+1) - chk_a = torch.zeros((K, K)) - chk_a0 = torch.zeros((K,)) - m, g = -1, 0 - for mp in range(M+1): - for gp in range(2): - kp = mg2k(mp, gp) - if m + 1 - g == mp and gp == 0: - chk_a0[kp] = (1 - r1[m+1-g])*(1 - u1[m+1-g]) - elif m + 1 - g < mp and gp == 0: - chk_a0[kp] = ( - (1 - r1[m+1-g]) * u1[m+1-g] * - torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in - range(m+2-g, mp)]) * - (1 - r1[mp]) * (1 - u1[mp])) - elif m + 1 - g == mp and gp == 1: - chk_a0[kp] = r1[m+1-g] - elif m + 1 - g < mp and gp == 1: - chk_a0[kp] = ( - (1 - r1[m+1-g]) * u1[m+1-g] * - torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in - range(m+2-g, mp)]) * r1[mp]) - for m in range(M+1): - for g in range(2): - k = mg2k(m, g) - for mp in range(M+1): - for gp in range(2): - kp = mg2k(mp, gp) - if m + 1 - g == mp and gp == 0: - chk_a[k, kp] = (1 - r1[m+1-g])*(1 - u1[m+1-g]) - elif m + 1 - g < mp and gp == 0: - chk_a[k, kp] = ( - (1 - r1[m+1-g]) * u1[m+1-g] * - torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in - range(m+2-g, mp)]) * - (1 - r1[mp]) * (1 - u1[mp])) - elif m + 1 - g == mp and gp == 1: - chk_a[k, kp] = r1[m+1-g] - elif m + 1 - g < mp and gp == 1: - chk_a[k, kp] = ( - (1 - r1[m+1-g]) * u1[m+1-g] * - torch.prod([(1 - r1[mpp])*u1[mpp] for mpp in - range(m+2-g, mp)]) * r1[mp]) - elif m == M and mp == M and g == 0 and gp == 0: - chk_a[k, kp] = 1. 
+ if batch_size is None: + batch_dim_size = 1 + r1 = r1.unsqueeze(0) + u1 = u1.unsqueeze(0) + s = s.unsqueeze(0) + c = c.unsqueeze(0) + if substitute: + ll = ll.unsqueeze(0) + else: + batch_dim_size = batch_size + r1 = r1[None, :, :] * torch.ones([batch_size, 1, 1]) + c = c[None, :, :] * torch.ones([batch_size, 1, 1]) + if substitute: + ll = ll.unsqueeze(0) + chk_a = torch.zeros((batch_dim_size, K, K)) + chk_a0 = torch.zeros((batch_dim_size, K)) + chk_e = torch.zeros((batch_dim_size, K, 4)) + for b in range(batch_dim_size): + m, g = -1, 0 + u1[b][-1] = 1e-32 + for mp in range(M+1): + for gp in range(2): + kp = mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + chk_a0[b, kp] = (1 - r1[b, m+1-g, g])*(1 - u1[b, m+1-g, g]) + elif m + 1 - g < mp and gp == 0: + chk_a0[b, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)]) * + (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + elif m + 1 - g == mp and gp == 1: + chk_a0[b, kp] = r1[b, m+1-g, g] + elif m + 1 - g < mp and gp == 1: + chk_a0[b, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)]) * r1[b, mp, 2]) + for m in range(M+1): + for g in range(2): + k = mg2k(m, g) + for mp in range(M+1): + for gp in range(2): + kp = mg2k(mp, gp) + if m + 1 - g == mp and gp == 0: + chk_a[b, k, kp] = (1 - r1[b, m+1-g, g] + )*(1 - u1[b, m+1-g, g]) + elif m + 1 - g < mp and gp == 0: + chk_a[b, k, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2]) * + u1[b, mpp, 2] + for mpp in range(m+2-g, mp)]) * + (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + elif m + 1 - g == mp and gp == 1: + chk_a[b, k, kp] = r1[b, m+1-g, g] + elif m + 1 - g < mp and gp == 1: + chk_a[b, k, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2]) * + u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)] + ) * r1[b, mp, 2]) + elif m == M and mp == M and g == 0 and gp == 0: + chk_a[b, k, kp] = 1. - chk_e = torch.zeros((2*(M+1), 4)) - for m in range(M+1): - for g in range(2): - k = mg2k(m, g) - if g == 0: - chk_e[k, :] = s[m, :].numpy() - else: - chk_e[k, :] = c[m, :].numpy() + for m in range(M+1): + for g in range(2): + k = mg2k(m, g) + if g == 0: + chk_e[b, k, :] = s[b, m, :] + else: + chk_e[b, k, :] = c[b, m, :] + if substitute: + chk_e = torch.matmul(chk_e, ll) # - - + if batch_size is None: + chk_a = chk_a.squeeze() + chk_a0 = chk_a0.squeeze() + chk_e = chk_e.squeeze() + assert torch.allclose(torch.sum(torch.exp(a0ln)), torch.tensor(1.), + atol=1e-3, rtol=1e-3) + assert torch.allclose(torch.sum(torch.exp(aln), axis=1)[:-1], + torch.ones(2*(M+1)-1), atol=1e-3, + rtol=1e-3) assert torch.allclose(chk_a0, torch.exp(a0ln)) assert torch.allclose(chk_a, torch.exp(aln)) assert torch.allclose(chk_e, torch.exp(eln)) - - # Check normalization. - assert torch.allclose(torch.sum(torch.exp(a0ln)), 1., atol=1e-3, - rtol=1e-3) - assert torch.allclose(torch.sum(torch.exp(aln), axis=1)[:-1], - torch.ones(2*(M+1)-1), atol=1e-3, - rtol=1e-3) From 7ddecb663f7eda13713fa115dca1de474b54a650 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 29 Dec 2020 19:45:04 -0500 Subject: [PATCH 05/91] Variable length discrete hmm class and log probability function. 
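VariableLengthDiscreteHMM scores one-hot observation sequences with a
log-space forward algorithm, using the parallel-scan
_sequential_logmatmulexp from pyro.distributions.hmm to eliminate the time
dimension. For orientation, here is an editor's sketch of the naive
sequential version (forward_log_prob is an illustrative name, not part of
this commit). Note that as committed here the first emission is paired
with a transition, following pyro's DiscreteHMM convention; patch 08 below
switches to the standard convention used in this sketch:

    import torch

    def forward_log_prob(init_logits, trans_logits, obs_logits, value):
        # value: (num_steps, observation_dim) of one-hot rows; an all-zero
        # row adds zero emission log-probability, which is how shorter,
        # zero-padded sequences are accommodated.
        emit = torch.matmul(value, obs_logits.T)  # (num_steps, state_dim)
        alpha = init_logits + emit[0]
        for t in range(1, value.shape[0]):
            # Sum over the old state (dim 0), then emit the next symbol.
            alpha = torch.logsumexp(alpha[:, None] + trans_logits,
                                    dim=0) + emit[t]
        return torch.logsumexp(alpha, dim=0)

    a0 = torch.tensor([0.9, 0.1]).log()
    a = torch.tensor([[0.5, 0.5], [0.2, 0.8]]).log()
    e = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    print(forward_log_prob(a0, a, e, torch.eye(2)))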
--- pyro/contrib/mue/variablelengthhmm.py | 75 +++++++++++++++++++++ tests/contrib/mue/test_statearrangers.py | 2 +- tests/contrib/mue/test_variablelengthhmm.py | 5 ++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 tests/contrib/mue/test_variablelengthhmm.py diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index d6960608d6..927ec01c12 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -1,2 +1,77 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import torch +from torch.distributions import constraints + +from pyro.distributions.torch_distribution import TorchDistribution +from pyro.distributions.hmm import _sequential_logmatmulexp +from pyro.distributions.util import broadcast_shape + + +class VariableLengthDiscreteHMM(TorchDistribution): + """ + HMM with discrete latent states and discrete observations, allowing for + variable length sequences. + """ + arg_constraints = {"initial_logits": constraints.real, + "transition_logits": constraints.real, + "observation_logits": constraints.real} + + def __init__(self, initial_logits, transition_logits, observation_logits, validate_args=None): + if initial_logits.dim() < 1: + raise ValueError("expected initial_logits to have at least one dim, " + "actual shape = {}".format(initial_logits.shape)) + if transition_logits.dim() < 2: + raise ValueError("expected transition_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) + if observation_logits.dim() < 2: + raise ValueError("expected observation_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) + shape = broadcast_shape(initial_logits.shape[:-1] + (1,), + transition_logits.shape[:-2], + observation_logits.shape[:-2]) + batch_shape = shape[:-1] + self.event_shape = observation_logits.shape[-1:] + self.initial_logits = (initial_logits - + initial_logits.logsumexp(-1, True)) + self.transition_logits = (transition_logits - + transition_logits.logsumexp(-1, True)) + self.observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + super(VariableLengthDiscreteHMM, self).__init__( + batch_shape, self.event_shape, validate_args=validate_args) + + def log_prob(self, value): + + # observation_logits: + # batch_shape (option) x state_dim x observation_dim + # value: + # batch_shape (option) x num_steps x observation_dim + # value_logits + # batch_shape (option) x num_steps x state_dim + # transition_logits: + # batch_shape (option) x state_dim x state_dim + # result 1 + # batch_shape (option) x num_steps x state_dim (old) x state_dim (new) + # result 2 + # batch_shape (option) x state_dim (old) x state_dim (new) + # initial_logits + # batch_shape (option) x state_dim + # result 3 + # batch_shape (option) + + # Combine observation and transition factors. + value = value.unsqueeze(-2) + value_logits = torch.einsum('bso,bno->bns', self.observation_logits, + value) + result = self.transition_logits + value_logits.unsqueeze(-2) + + # Eliminate time dimension. + result = _sequential_logmatmulexp(result) + + # Combine initial factor. + result = self.initial_logits + result.logsumexp(-1) + + # Marginalize out final state. 
+ result = result.logsumexp(-1) + return result diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 4ca3f7a194..881fa80c2a 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,7 +1,7 @@ import torch +import pytest from pyro.contrib.mue.statearrangers import profile, mg2k -import pytest def simpleprod(lst): diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_variablelengthhmm.py new file mode 100644 index 0000000000..2b95465236 --- /dev/null +++ b/tests/contrib/mue/test_variablelengthhmm.py @@ -0,0 +1,5 @@ +import torch +import pytest + +from pyro.contrib.mue.statearrangers import profile +from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM From c2458f40a4c356d061abf5588205100b01b6645e Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 30 Dec 2020 12:58:58 -0500 Subject: [PATCH 06/91] Test for log probability of variable length hmm. --- pyro/contrib/mue/variablelengthhmm.py | 41 +++++++----- tests/contrib/mue/test_variablelengthhmm.py | 69 ++++++++++++++++++++- 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 927ec01c12..124c66c387 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -17,21 +17,25 @@ class VariableLengthDiscreteHMM(TorchDistribution): "transition_logits": constraints.real, "observation_logits": constraints.real} - def __init__(self, initial_logits, transition_logits, observation_logits, validate_args=None): + def __init__(self, initial_logits, transition_logits, observation_logits, + validate_args=None): if initial_logits.dim() < 1: - raise ValueError("expected initial_logits to have at least one dim, " - "actual shape = {}".format(initial_logits.shape)) + raise ValueError( + "expected initial_logits to have at least one dim, " + "actual shape = {}".format(initial_logits.shape)) if transition_logits.dim() < 2: - raise ValueError("expected transition_logits to have at least two dims, " - "actual shape = {}".format(transition_logits.shape)) + raise ValueError( + "expected transition_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) if observation_logits.dim() < 2: - raise ValueError("expected observation_logits to have at least two dims, " - "actual shape = {}".format(transition_logits.shape)) + raise ValueError( + "expected observation_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) shape = broadcast_shape(initial_logits.shape[:-1] + (1,), transition_logits.shape[:-2], observation_logits.shape[:-2]) batch_shape = shape[:-1] - self.event_shape = observation_logits.shape[-1:] + event_shape = observation_logits.shape[-1:] self.initial_logits = (initial_logits - initial_logits.logsumexp(-1, True)) self.transition_logits = (transition_logits - @@ -39,18 +43,23 @@ def __init__(self, initial_logits, transition_logits, observation_logits, valida self.observation_logits = (observation_logits - observation_logits.logsumexp(-1, True)) super(VariableLengthDiscreteHMM, self).__init__( - batch_shape, self.event_shape, validate_args=validate_args) + batch_shape, event_shape, validate_args=validate_args) def log_prob(self, value): - + """Warning: like in pyro's DiscreteHMM, the probability of the first + state is computed as + initial.T @ transition @ emission + rather than the more conventional HMM 
parameterization, + initial.T @ emission + """ # observation_logits: # batch_shape (option) x state_dim x observation_dim # value: # batch_shape (option) x num_steps x observation_dim # value_logits - # batch_shape (option) x num_steps x state_dim + # batch_shape (option) x num_steps x state_dim (new) # transition_logits: - # batch_shape (option) x state_dim x state_dim + # batch_shape (option) x state_dim (old) x state_dim (new) # result 1 # batch_shape (option) x num_steps x state_dim (old) x state_dim (new) # result 2 @@ -61,10 +70,10 @@ def log_prob(self, value): # batch_shape (option) # Combine observation and transition factors. - value = value.unsqueeze(-2) - value_logits = torch.einsum('bso,bno->bns', self.observation_logits, - value) - result = self.transition_logits + value_logits.unsqueeze(-2) + value_logits = torch.matmul( + value, torch.transpose(self.observation_logits, -2, -1)) + result = (self.transition_logits.unsqueeze(-3) + + value_logits.unsqueeze(-2)) # Eliminate time dimension. result = _sequential_logmatmulexp(result) diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_variablelengthhmm.py index 2b95465236..b12492f3bd 100644 --- a/tests/contrib/mue/test_variablelengthhmm.py +++ b/tests/contrib/mue/test_variablelengthhmm.py @@ -1,5 +1,70 @@ import torch -import pytest -from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM + + +def test_hmm_log_prob(): + torch.set_default_tensor_type('torch.DoubleTensor') + + a0 = torch.tensor([0.9, 0.08, 0.02]) + a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]]) + e = torch.tensor([[0.99, 0.01], [0.01, 0.99], [0.5, 0.5]]) + + x = torch.tensor([[0., 1.], + [1., 0.], + [0., 1.], + [0., 1.], + [1., 0.], + [0., 0.]]) + + hmm_distr = VariableLengthDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) + lp = hmm_distr.log_prob(x) + + f = torch.matmul(a0, a) * e[:, 1] + f = torch.matmul(f, a) * e[:, 0] + f = torch.matmul(f, a) * e[:, 1] + f = torch.matmul(f, a) * e[:, 1] + f = torch.matmul(f, a) * e[:, 0] + chk_lp = torch.log(torch.sum(f)) + + assert torch.allclose(lp, chk_lp) + + # Batch values. + x = torch.cat([ + x[None, :, :], + torch.tensor([[1., 0.], + [1., 0.], + [1., 0.], + [0., 0.], + [0., 0.], + [0., 0.]])[None, :, :]], dim=0) + lp = hmm_distr.log_prob(x) + + f = torch.matmul(a0, a) * e[:, 0] + f = torch.matmul(f, a) * e[:, 0] + f = torch.matmul(f, a) * e[:, 0] + chk_lp = torch.cat([chk_lp[None], torch.log(torch.sum(f))[None]]) + + assert torch.allclose(lp, chk_lp) + + # Batch both parameters and values. + a0 = torch.cat([a0[None, :], torch.tensor([0.2, 0.7, 0.1])[None, :]]) + a = torch.cat([ + a[None, :, :], + torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]] + )[None, :, :]], dim=0) + e = torch.cat([ + e[None, :, :], + torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :]], + dim=0) + hmm_distr = VariableLengthDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) + lp = hmm_distr.log_prob(x) + + f = torch.matmul(a0[1, :], a[1, :, :]) * e[1, :, 0] + f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] + f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] + chk_lp = torch.cat([chk_lp[0][None], torch.log(torch.sum(f))[None]]) + + assert torch.allclose(lp, chk_lp) From 1ea5af60f4c3324ea1a22745d1974251f21ce739 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 30 Dec 2020 17:13:18 -0500 Subject: [PATCH 07/91] Simple profile hmm example model. 
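The example encodes each sequence as one-hot rows and pads shorter
sequences with all-zero rows, which VariableLengthDiscreteHMM treats as
unobserved positions. An editor's sketch of a helper in that shape (the
example below builds its tensors by hand; one_hot_pad is an illustrative
name):

    import torch

    def one_hot_pad(seq, alphabet, max_len):
        # One-hot encode a string over the given alphabet; rows past the
        # end of the sequence stay all-zero, marking the sequence end.
        x = torch.zeros(max_len, len(alphabet))
        for i, c in enumerate(seq):
            x[i, alphabet.index(c)] = 1.
        return x

    data = torch.stack([one_hot_pad('BABBA', 'AB', 6),
                        one_hot_pad('BAAB', 'AB', 6)])
    print(data.shape)  # torch.Size([2, 6, 2])

With alphabet 'AB', 'BABBA' reproduces the first sequence in the dataset
below ([0., 1.] rows for B, [1., 0.] rows for A, one zero row of padding)
and 'BAAB' the second.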
--- examples/contrib/mue/phmm.py | 169 ++++++++++++++++++++++++++ pyro/contrib/mue/statearrangers.py | 4 +- pyro/contrib/mue/variablelengthhmm.py | 4 +- 3 files changed, 173 insertions(+), 4 deletions(-) create mode 100644 examples/contrib/mue/phmm.py diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py new file mode 100644 index 0000000000..10b7a8a455 --- /dev/null +++ b/examples/contrib/mue/phmm.py @@ -0,0 +1,169 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +A standard profile HMM model example, using the MuE package. +""" + +import torch +import torch.nn as nn +from torch.nn.functional import softplus + +import pyro +import pyro.distributions as dist +from pyro.optim import Adam +from pyro.infer import SVI, Trace_ELBO +import pyro.poutine as poutine + +from pyro.contrib.mue.statearrangers import profile +from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM + +import datetime +import matplotlib.pyplot as plt + + +class ProfileHMM(nn.Module): + + def __init__(self, latent_seq_length, alphabet_length, + prior_scale=1., indel_prior_strength=100.): + super().__init__() + + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + self.seq_shape = (latent_seq_length+1, alphabet_length) + self.indel_shape = (latent_seq_length+1, 3, 2) + + assert isinstance(prior_scale, float) + self.prior_scale = prior_scale + assert isinstance(indel_prior_strength, float) + self.indel_prior = torch.tensor([indel_prior_strength, 0.]) + + # Initialize state arranger. + self.statearrange = profile(latent_seq_length) + + def model(self, data): + + # Latent sequence. + ancestor_seq = pyro.sample("ancestor_seq", dist.Normal( + torch.zeros(self.seq_shape), + self.prior_scale * torch.ones(self.seq_shape)).to_event(2)) + ancestor_seq_logits = ancestor_seq - ancestor_seq.logsumexp(-1, True) + insert_seq = pyro.sample("insert_seq", dist.Normal( + torch.zeros(self.seq_shape), + self.prior_scale * torch.ones(self.seq_shape)).to_event(2)) + insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True) + + # Indel probabilities. + insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Construct HMM parameters. + initial_logits, transition_logits, observation_logits = ( + self.statearrange(ancestor_seq_logits, insert_seq_logits, + insert_logits, delete_logits)) + # Draw samples. + for i in pyro.plate("data", data.shape[0]): + pyro.sample("obs_{}".format(i), + VariableLengthDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=data[i]) + + def guide(self, data): + # Sequence. 
+ ancestor_seq_q_mn = pyro.param("ancestor_seq_q_mn", + torch.zeros(self.seq_shape)) + ancestor_seq_q_sd = pyro.param("ancestor_seq_q_sd", + torch.zeros(self.seq_shape)) + pyro.sample("ancestor_seq", dist.Normal( + ancestor_seq_q_mn, softplus(ancestor_seq_q_sd)).to_event(2)) + insert_seq_q_mn = pyro.param("insert_seq_q_mn", + torch.zeros(self.seq_shape)) + insert_seq_q_sd = pyro.param("insert_seq_q_sd", + torch.zeros(self.seq_shape)) + pyro.sample("insert_seq", dist.Normal( + insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2)) + + # Indels. + insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + + +def main(): + small_test = False + + if small_test: + mult_dat = 1 + mult_step = 1 + else: + mult_dat = 10 + mult_step = 10 + + data = torch.cat([torch.tensor([[0., 1.], + [1., 0.], + [0., 1.], + [0., 1.], + [1., 0.], + [0., 0.]])[None, :, :] + for j in range(6*mult_dat)] + + [torch.tensor([[0., 1.], + [1., 0.], + [1., 0.], + [0., 1.], + [0., 0.], + [0., 0.]])[None, :, :] + for j in range(4*mult_dat)], dim=0) + print('data shape', data.shape) + # Set up inference. + latent_seq_length, alphabet_length = 6, 2 + adam_params = {"lr": 0.05, "betas": (0.90, 0.999)} + optimizer = Adam(adam_params) + model = ProfileHMM(latent_seq_length, alphabet_length) + + trace = poutine.trace(model.model).get_trace(data) + trace.compute_log_prob() # optional, but allows printing of log_prob shapes + print('format shapes', trace.format_shapes()) + + svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO()) + n_steps = 10*mult_step + + # Run inference. + losses = [] + t0 = datetime.datetime.now() + for step in range(n_steps): + loss = svi.step(data) + losses.append(loss) + if step % 10 == 0: + print(loss, ' ', datetime.datetime.now() - t0) + + # Plots. + time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + plt.savefig('phmm/loss_{}.pdf'.format(time_stamp)) + + +if __name__ == '__main__': + main() diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 258487f4cc..953e1fbb3d 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 import torch -from pyro.nn import PyroModule +import torch.nn as nn def mg2k(m, g): @@ -9,7 +9,7 @@ def mg2k(m, g): return 2*m + 1 - g -class profile(PyroModule): +class profile(nn.Module): def __init__(self, M, epsilon=1e-32): super().__init__() diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 124c66c387..3259f73ae9 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -34,8 +34,8 @@ def __init__(self, initial_logits, transition_logits, observation_logits, shape = broadcast_shape(initial_logits.shape[:-1] + (1,), transition_logits.shape[:-2], observation_logits.shape[:-2]) - batch_shape = shape[:-1] - event_shape = observation_logits.shape[-1:] + batch_shape = shape + event_shape = (1, observation_logits.shape[-1]) self.initial_logits = (initial_logits - initial_logits.logsumexp(-1, True)) self.transition_logits = (transition_logits - From 2f72384377c2cdf2b25f9ea3d283a6537c4e8ff9 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 30 Dec 2020 17:44:05 -0500 Subject: [PATCH 08/91] Switch to standard hmm log probability convention. --- examples/contrib/mue/phmm.py | 31 +++++++++++++++++---- pyro/contrib/mue/variablelengthhmm.py | 13 +++++---- tests/contrib/mue/test_variablelengthhmm.py | 6 ++-- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py index 10b7a8a455..d60875290e 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/phmm.py @@ -25,7 +25,7 @@ class ProfileHMM(nn.Module): def __init__(self, latent_seq_length, alphabet_length, - prior_scale=1., indel_prior_strength=100.): + prior_scale=1., indel_prior_strength=10.): super().__init__() assert isinstance(latent_seq_length, int) and latent_seq_length > 0 @@ -133,17 +133,12 @@ def main(): [0., 0.], [0., 0.]])[None, :, :] for j in range(4*mult_dat)], dim=0) - print('data shape', data.shape) # Set up inference. 
latent_seq_length, alphabet_length = 6, 2 adam_params = {"lr": 0.05, "betas": (0.90, 0.999)} optimizer = Adam(adam_params) model = ProfileHMM(latent_seq_length, alphabet_length) - trace = poutine.trace(model.model).get_trace(data) - trace.compute_log_prob() # optional, but allows printing of log_prob shapes - print('format shapes', trace.format_shapes()) - svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO()) n_steps = 10*mult_step @@ -164,6 +159,30 @@ def main(): plt.ylabel('loss') plt.savefig('phmm/loss_{}.pdf'.format(time_stamp)) + plt.figure(figsize=(6, 6)) + ancestor_seq = pyro.param("ancestor_seq_q_mn").detach() + ancestor_seq_expect = torch.exp(ancestor_seq - + ancestor_seq.logsumexp(-1, True)) + plt.plot(ancestor_seq_expect[:, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of character 1') + plt.savefig('phmm/ancestor_seq_prob_{}.pdf'.format(time_stamp)) + + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + plt.savefig('phmm/insert_prob_{}.pdf'.format(time_stamp)) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + plt.savefig('phmm/delete_prob_{}.pdf'.format(time_stamp)) + if __name__ == '__main__': main() diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 3259f73ae9..4a434e5fbf 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -46,10 +46,10 @@ def __init__(self, initial_logits, transition_logits, observation_logits, batch_shape, event_shape, validate_args=validate_args) def log_prob(self, value): - """Warning: like in pyro's DiscreteHMM, the probability of the first - state is computed as + """Warning: unlike in pyro's DiscreteHMM, which computes the + probability of the first state as initial.T @ transition @ emission - rather than the more conventional HMM parameterization, + this distribution uses the standard HMM convention, initial.T @ emission """ # observation_logits: @@ -61,7 +61,7 @@ def log_prob(self, value): # transition_logits: # batch_shape (option) x state_dim (old) x state_dim (new) # result 1 - # batch_shape (option) x num_steps x state_dim (old) x state_dim (new) + # batch_shape (option) x num_steps-1 x state_dim (old) x state_dim (new) # result 2 # batch_shape (option) x state_dim (old) x state_dim (new) # initial_logits @@ -73,13 +73,14 @@ def log_prob(self, value): value_logits = torch.matmul( value, torch.transpose(self.observation_logits, -2, -1)) result = (self.transition_logits.unsqueeze(-3) + - value_logits.unsqueeze(-2)) + value_logits[..., 1:, None, :]) # Eliminate time dimension. result = _sequential_logmatmulexp(result) # Combine initial factor. - result = self.initial_logits + result.logsumexp(-1) + result = (self.initial_logits + value_logits[..., 0, :] + + result.logsumexp(-1)) # Marginalize out final state. 
result = result.logsumexp(-1) diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_variablelengthhmm.py index b12492f3bd..751071009a 100644 --- a/tests/contrib/mue/test_variablelengthhmm.py +++ b/tests/contrib/mue/test_variablelengthhmm.py @@ -21,7 +21,7 @@ def test_hmm_log_prob(): torch.log(e)) lp = hmm_distr.log_prob(x) - f = torch.matmul(a0, a) * e[:, 1] + f = a0 * e[:, 1] f = torch.matmul(f, a) * e[:, 0] f = torch.matmul(f, a) * e[:, 1] f = torch.matmul(f, a) * e[:, 1] @@ -41,7 +41,7 @@ def test_hmm_log_prob(): [0., 0.]])[None, :, :]], dim=0) lp = hmm_distr.log_prob(x) - f = torch.matmul(a0, a) * e[:, 0] + f = a0 * e[:, 0] f = torch.matmul(f, a) * e[:, 0] f = torch.matmul(f, a) * e[:, 0] chk_lp = torch.cat([chk_lp[None], torch.log(torch.sum(f))[None]]) @@ -62,7 +62,7 @@ def test_hmm_log_prob(): torch.log(e)) lp = hmm_distr.log_prob(x) - f = torch.matmul(a0[1, :], a[1, :, :]) * e[1, :, 0] + f = a0[1, :] * e[1, :, 0] f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] chk_lp = torch.cat([chk_lp[0][None], torch.log(torch.sum(f))[None]]) From 1725392316fb467349090ca59b69a49793f7fe83 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 30 Dec 2020 19:56:39 -0500 Subject: [PATCH 09/91] FactorMuE forward pass. --- examples/contrib/mue/FactorMuE.py | 132 ++++++++++++++++++++++++++++++ examples/contrib/mue/phmm.py | 2 +- 2 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 examples/contrib/mue/FactorMuE.py diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py new file mode 100644 index 0000000000..e5322bc513 --- /dev/null +++ b/examples/contrib/mue/FactorMuE.py @@ -0,0 +1,132 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +A PCA model with a MuE emission (FactorMuE). Uses the MuE package. +""" + +import torch +import torch.nn as nn +from torch.nn.functional import softplus + +import pyro +import pyro.distributions as dist +import pyro.poutine as poutine +from pyro.optim import Adam +from pyro.infer import SVI, Trace_ELBO + +from pyro.contrib.mue.statearrangers import profile +from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM + +import datetime +import matplotlib.pyplot as plt + + +class Encoder(nn.Module): + def __init__(self, obs_seq_length, alphabet_length, z_dim): + super().__init__() + + self.input_size = obs_seq_length * alphabet_length + self.f1_mn = nn.Linear(self.input_size, z_dim) + self.f1_sd = nn.Linear(self.input_size, z_dim) + + def forward(self, data): + + data = data.reshape(-1, self.input_size) + z_loc = self.f1_mn(data) + z_scale = softplus(self.f1_sd(data)) + + return z_loc, z_scale + + +class Decoder(nn.Module): + def __init__(self, obs_seq_length, alphabet_length, z_dim): + super().__init__() + + self.obs_seq_length = obs_seq_length + self.alphabet_length = alphabet_length + self.output_size = 2 * obs_seq_length * alphabet_length + self.f = nn.Linear(z_dim, self.output_size) + + def forward(self, z): + + seq = self.f(z) + seq = seq.reshape([-1, 2, self.obs_seq_length, + self.alphabet_length]) + return seq + + +class FactorMuE(nn.Module): + + def __init__(self, obs_seq_length, alphabet_length, z_dim, + latent_seq_length=None, prior_scale=1., + indel_prior_strength=10.): + super().__init__() + + # Constants. 
+ assert isinstance(obs_seq_length, int) and obs_seq_length > 0 + self.obs_seq_length = obs_seq_length + if latent_seq_length is None: + latent_seq_length = obs_seq_length + else: + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + assert isinstance(z_dim, int) and z_dim > 0 + self.z_dim = z_dim + + # Parameter shapes. + self.seq_shape = (latent_seq_length+1, alphabet_length) + self.indel_shape = (latent_seq_length+1, 3, 2) + + assert isinstance(prior_scale, float) + self.prior_scale = prior_scale + assert isinstance(indel_prior_strength, float) + self.indel_prior = torch.tensor([indel_prior_strength, 0.]) + + # Initialize layers. + self.encoder = Encoder(obs_seq_length, alphabet_length, z_dim) + self.decoder = Decoder(obs_seq_length, alphabet_length, z_dim) + self.statearrange = profile(latent_seq_length) + + def model(self, data): + + pyro.module("decoder", self.decoder) + + # Indel probabilities. + insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Temperature. + # pyro.sample("inverse_temp", dist.Normal()) + + with pyro.plate("batch", data.shape[0]), poutine.scale( + scale=self.scale_factor): + # Sample latent variable from prior. + z = pyro.sample("latent", dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim)).to_event(1)) + # Decode latent sequence. + latent_seq = self.decoder.forward(z) + # Construct ancestral and insertion sequences. + ancestor_seq_logits = latent_seq[..., 0, :, :] + ancestor_seq_logits = (ancestor_seq_logits - + ancestor_seq_logits.logsumexp(-1, True)) + insert_seq_logits = latent_seq[..., 1, :, :] + insert_seq_logits = (insert_seq_logits - + insert_seq_logits.logsumexp(-1, True)) + # Construct HMM parameters. + initial_logits, transition_logits, observation_logits = ( + self.statearrange(ancestor_seq_logits, insert_seq_logits, + insert_logits, delete_logits)) + # Draw samples. + pyro.sample("obs", + VariableLengthDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=data) diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py index d60875290e..6ae59666b5 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/phmm.py @@ -13,7 +13,6 @@ import pyro.distributions as dist from pyro.optim import Adam from pyro.infer import SVI, Trace_ELBO -import pyro.poutine as poutine from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM @@ -32,6 +31,7 @@ def __init__(self, latent_seq_length, alphabet_length, self.latent_seq_length = latent_seq_length assert isinstance(alphabet_length, int) and alphabet_length > 0 self.alphabet_length = alphabet_length + self.seq_shape = (latent_seq_length+1, alphabet_length) self.indel_shape = (latent_seq_length+1, 3, 2) From bfba6ae96897225b815fbd3c7651c0d5c35abfa8 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 31 Dec 2020 12:52:57 -0500 Subject: [PATCH 10/91] FactorMuE plots, debug. 
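Besides the plots, this commit repairs two problems left in the previous
one: the Decoder is now sized by latent_seq_length+1 instead of
obs_seq_length, and the scale_factor that the model's poutine.scale
handler already referenced is actually defined in __init__. It also adds
the variational guide, a reconstruction helper, and a learned inverse
temperature on the decoded logits. As an editor's illustration of what
poutine.scale does (not part of the patch): scaling a sample site
multiplies its log-probability, so setting scale_factor to dataset size
over minibatch size keeps a minibatch ELBO an unbiased estimate of the
full-data objective.

    import torch
    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine

    def model():
        with poutine.scale(scale=2.0):
            pyro.sample("x", dist.Normal(0., 1.), obs=torch.tensor(0.))

    tr = poutine.trace(model).get_trace()
    tr.compute_log_prob()
    # Doubled relative to dist.Normal(0., 1.).log_prob(torch.tensor(0.)).
    print(tr.nodes["x"]["log_prob"])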
--- examples/contrib/mue/FactorMuE.py | 202 ++++++++++++++++++++++++++++-- 1 file changed, 190 insertions(+), 12 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index e5322bc513..24053da1b7 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -21,6 +21,8 @@ import datetime import matplotlib.pyplot as plt +import pdb + class Encoder(nn.Module): def __init__(self, obs_seq_length, alphabet_length, z_dim): @@ -40,18 +42,18 @@ def forward(self, data): class Decoder(nn.Module): - def __init__(self, obs_seq_length, alphabet_length, z_dim): + def __init__(self, latent_seq_length, alphabet_length, z_dim): super().__init__() - self.obs_seq_length = obs_seq_length + self.latent_seq_length = latent_seq_length self.alphabet_length = alphabet_length - self.output_size = 2 * obs_seq_length * alphabet_length + self.output_size = 2 * (latent_seq_length+1) * alphabet_length self.f = nn.Linear(z_dim, self.output_size) def forward(self, z): seq = self.f(z) - seq = seq.reshape([-1, 2, self.obs_seq_length, + seq = seq.reshape([-1, 2, self.latent_seq_length+1, self.alphabet_length]) return seq @@ -59,8 +61,9 @@ def forward(self, z): class FactorMuE(nn.Module): def __init__(self, obs_seq_length, alphabet_length, z_dim, + scale_factor=1., latent_seq_length=None, prior_scale=1., - indel_prior_strength=10.): + indel_prior_strength=10., inverse_temp_prior=100.): super().__init__() # Constants. @@ -70,6 +73,7 @@ def __init__(self, obs_seq_length, alphabet_length, z_dim, latent_seq_length = obs_seq_length else: assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length assert isinstance(alphabet_length, int) and alphabet_length > 0 self.alphabet_length = alphabet_length assert isinstance(z_dim, int) and z_dim > 0 @@ -79,14 +83,20 @@ def __init__(self, obs_seq_length, alphabet_length, z_dim, self.seq_shape = (latent_seq_length+1, alphabet_length) self.indel_shape = (latent_seq_length+1, 3, 2) + # Priors. assert isinstance(prior_scale, float) - self.prior_scale = prior_scale + self.prior_scale = torch.tensor(prior_scale) assert isinstance(indel_prior_strength, float) self.indel_prior = torch.tensor([indel_prior_strength, 0.]) + assert isinstance(inverse_temp_prior, float) + self.inverse_temp_prior = torch.tensor(inverse_temp_prior) + + # Batch control. + self.scale_factor = scale_factor # Initialize layers. self.encoder = Encoder(obs_seq_length, alphabet_length, z_dim) - self.decoder = Decoder(obs_seq_length, alphabet_length, z_dim) + self.decoder = Decoder(latent_seq_length, alphabet_length, z_dim) self.statearrange = profile(latent_seq_length) def model(self, data): @@ -103,8 +113,9 @@ def model(self, data): self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) delete_logits = delete - delete.logsumexp(-1, True) - # Temperature. - # pyro.sample("inverse_temp", dist.Normal()) + # Inverse temperature. + inverse_temp = pyro.sample("inverse_temp", dist.Normal( + self.inverse_temp_prior, torch.tensor(1.))) with pyro.plate("batch", data.shape[0]), poutine.scale( scale=self.scale_factor): @@ -112,21 +123,188 @@ def model(self, data): z = pyro.sample("latent", dist.Normal( torch.zeros(self.z_dim), torch.ones(self.z_dim)).to_event(1)) # Decode latent sequence. - latent_seq = self.decoder.forward(z) + latent_seq = self.decoder(z) # Construct ancestral and insertion sequences. 
- ancestor_seq_logits = latent_seq[..., 0, :, :] + ancestor_seq_logits = (latent_seq[..., 0, :, :] * + softplus(inverse_temp)) ancestor_seq_logits = (ancestor_seq_logits - ancestor_seq_logits.logsumexp(-1, True)) - insert_seq_logits = latent_seq[..., 1, :, :] + insert_seq_logits = (latent_seq[..., 1, :, :] * + softplus(inverse_temp)) insert_seq_logits = (insert_seq_logits - insert_seq_logits.logsumexp(-1, True)) # Construct HMM parameters. initial_logits, transition_logits, observation_logits = ( self.statearrange(ancestor_seq_logits, insert_seq_logits, insert_logits, delete_logits)) + print('initial_logits', initial_logits) + print('transition_logits', transition_logits) # Draw samples. pyro.sample("obs", VariableLengthDiscreteHMM(initial_logits, transition_logits, observation_logits), obs=data) + print(VariableLengthDiscreteHMM(initial_logits, + transition_logits, + observation_logits).log_prob(data)) + + def guide(self, data): + # Register encoder with pyro. + pyro.module("encoder", self.encoder) + + # Indel probabilities. + insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + print('insert_q_sd', insert_q_mn) + print('softplus insert', softplus(insert_q_sd)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + # Inverse temperature. + inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.)) + inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.)) + pyro.sample("inverse_temp", dist.Normal( + inverse_temp_q_mn, softplus(inverse_temp_q_sd))) + + # Per data latent variables. + with pyro.plate("batch", data.shape[0]), poutine.scale( + scale=self.scale_factor): + # Encode seq. + z_loc, z_scale = self.encoder(data) + # Sample. + pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + + def reconstruct_ancestor_seq(self, data, inverse_temp=1.): + # Encode seq. + z_loc = self.encoder(data)[0] + # Reconstruct + latent_seq = self.decoder(z_loc) + # Construct ancestral. + ancestor_seq_logits = latent_seq[..., 0, :, :] * softplus(inverse_temp) + ancestor_seq_logits = (ancestor_seq_logits - + ancestor_seq_logits.logsumexp(-1, True)) + return torch.exp(ancestor_seq_logits) + + +def main(): + + torch.manual_seed(9) + torch.set_default_tensor_type('torch.DoubleTensor') + + small_test = False + + if small_test: + mult_dat = 1 + mult_step = 1 + else: + mult_dat = 10 + mult_step = 100 + + xs = [torch.tensor([[0., 1.], + [1., 0.], + [0., 1.], + [0., 1.], + [1., 0.], + [0., 0.]]), + torch.tensor([[0., 1.], + [1., 0.], + [1., 0.], + [0., 1.], + [0., 0.], + [0., 0.]]), + torch.tensor([[0., 1.], + [1., 0.], + [0., 1.], + [0., 1.], + [0., 1.], + [0., 0.]])] + data = torch.cat([xs[0][None, :, :] for j in range(6*mult_dat)] + + [xs[1][None, :, :] for j in range(4*mult_dat)] + + [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0) + # Set up inference. + obs_seq_length, alphabet_length, z_dim = 6, 2, 2 + adam_params = {"lr": 0.1, "betas": (0.90, 0.999)} + optimizer = Adam(adam_params) + model = FactorMuE(obs_seq_length, alphabet_length, z_dim) + + svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO()) + n_steps = 10*mult_step + + # Run inference. 
+ losses = [] + t0 = datetime.datetime.now() + for step in range(n_steps): + + #trace = poutine.trace(poutine.enum(model.model, first_available_dim=-3)).get_trace(data) + #trace.compute_log_prob() + #print(trace.nodes) + + loss = svi.step(data) + losses.append(loss) + if step % 10 == 0: + print(loss, ' ', datetime.datetime.now() - t0) + + # Plots. + time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + plt.savefig('FactorMuE/loss_{}.pdf'.format(time_stamp)) + + plt.figure(figsize=(6, 6)) + latent = model.encoder(data)[0].detach() + plt.scatter(latent[:, 0], latent[:, 1]) + plt.xlabel('z_1') + plt.ylabel('z_2') + plt.savefig('FactorMuE/latent_{}.pdf'.format(time_stamp)) + + plt.figure(figsize=(6, 6)) + decoder_bias = pyro.param('decoder$$$f.bias').detach() + decoder_bias = decoder_bias.reshape( + [-1, 2, model.latent_seq_length+1, model.alphabet_length]) + plt.plot(decoder_bias[0, 0, :, 1]) + plt.xlabel('position') + plt.ylabel('bias for character 1') + plt.savefig('FactorMuE/decoder_bias_{}.pdf'.format(time_stamp)) + + for xi, x in enumerate(xs): + reconstruct_x = model.reconstruct_ancestor_seq( + x, pyro.param("inverse_temp_q_mn")).detach() + plt.figure(figsize=(6, 6)) + plt.plot(reconstruct_x[0, :, 1], label="reconstruct") + plt.plot(x[:, 1], label="data") + plt.xlabel('position') + plt.ylabel('probability of character 1') + plt.legend() + plt.savefig('FactorMuE/reconstruction_{}_{}.pdf'.format( + xi, time_stamp)) + + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + plt.savefig('FactorMuE/insert_prob_{}.pdf'.format(time_stamp)) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + plt.savefig('FactorMuE/delete_prob_{}.pdf'.format(time_stamp)) + + +if __name__ == '__main__': + main() From c017cf54d548d088421f8497510b3a59cd9b6d08 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 31 Dec 2020 13:12:33 -0500 Subject: [PATCH 11/91] Multistage training. --- examples/contrib/mue/FactorMuE.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 24053da1b7..cf54724944 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -8,11 +8,12 @@ import torch import torch.nn as nn from torch.nn.functional import softplus +from torch.optim import Adam import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.optim import Adam +from pyro.optim import MultiStepLR from pyro.infer import SVI, Trace_ELBO from pyro.contrib.mue.statearrangers import profile @@ -137,17 +138,12 @@ def model(self, data): initial_logits, transition_logits, observation_logits = ( self.statearrange(ancestor_seq_logits, insert_seq_logits, insert_logits, delete_logits)) - print('initial_logits', initial_logits) - print('transition_logits', transition_logits) # Draw samples. 
pyro.sample("obs", VariableLengthDiscreteHMM(initial_logits, transition_logits, observation_logits), obs=data) - print(VariableLengthDiscreteHMM(initial_logits, - transition_logits, - observation_logits).log_prob(data)) def guide(self, data): # Register encoder with pyro. @@ -159,8 +155,6 @@ def guide(self, data): * self.indel_prior) insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape)) - print('insert_q_sd', insert_q_mn) - print('softplus insert', softplus(insert_q_sd)) pyro.sample("insert", dist.Normal( insert_q_mn, softplus(insert_q_sd)).to_event(3)) delete_q_mn = pyro.param("delete_q_mn", @@ -208,7 +202,7 @@ def main(): mult_step = 1 else: mult_dat = 10 - mult_step = 100 + mult_step = 400 xs = [torch.tensor([[0., 1.], [1., 0.], @@ -233,11 +227,15 @@ def main(): [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0) # Set up inference. obs_seq_length, alphabet_length, z_dim = 6, 2, 2 - adam_params = {"lr": 0.1, "betas": (0.90, 0.999)} - optimizer = Adam(adam_params) + # adam_params = {"lr": 0.1, "betas": (0.90, 0.999)} + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.1}, + 'milestones': [20, 100, 1000, 2000], + 'gamma': 0.5}) + # optimizer = Adam(adam_params) model = FactorMuE(obs_seq_length, alphabet_length, z_dim) - svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO()) + svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO()) n_steps = 10*mult_step # Run inference. @@ -251,8 +249,9 @@ def main(): loss = svi.step(data) losses.append(loss) + scheduler.step() if step % 10 == 0: - print(loss, ' ', datetime.datetime.now() - t0) + print(step, loss, ' ', datetime.datetime.now() - t0) # Plots. time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") From 0ae60704eca27d46f4c2cac1e4ae7778da1c17b8 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 31 Dec 2020 14:16:48 -0500 Subject: [PATCH 12/91] Cleanup. --- examples/contrib/mue/FactorMuE.py | 20 +++++++------------- pyro/contrib/mue/__init__.py | 5 +++++ pyro/contrib/mue/variablelengthhmm.py | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index cf54724944..efcf34c117 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -5,6 +5,9 @@ A PCA model with a MuE emission (FactorMuE). Uses the MuE package. """ +import datetime + +import matplotlib.pyplot as plt import torch import torch.nn as nn from torch.nn.functional import softplus @@ -13,16 +16,10 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.optim import MultiStepLR -from pyro.infer import SVI, Trace_ELBO - from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM - -import datetime -import matplotlib.pyplot as plt - -import pdb +from pyro.infer import SVI, Trace_ELBO +from pyro.optim import MultiStepLR class Encoder(nn.Module): @@ -183,7 +180,7 @@ def reconstruct_ancestor_seq(self, data, inverse_temp=1.): z_loc = self.encoder(data)[0] # Reconstruct latent_seq = self.decoder(z_loc) - # Construct ancestral. + # Construct ancestral sequence. ancestor_seq_logits = latent_seq[..., 0, :, :] * softplus(inverse_temp) ancestor_seq_logits = (ancestor_seq_logits - ancestor_seq_logits.logsumexp(-1, True)) @@ -204,6 +201,7 @@ def main(): mult_dat = 10 mult_step = 400 + # Construct example dataset. 
xs = [torch.tensor([[0., 1.], [1., 0.], [0., 1.], @@ -243,10 +241,6 @@ def main(): t0 = datetime.datetime.now() for step in range(n_steps): - #trace = poutine.trace(poutine.enum(model.model, first_available_dim=-3)).get_trace(data) - #trace.compute_log_prob() - #print(trace.nodes) - loss = svi.step(data) losses.append(loss) scheduler.step() diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index a963573304..4126c6dd60 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -3,9 +3,14 @@ """ The :mod:`pyro.contrib.mue` module provides tools for working with mutational emission (MuE) distributions. +See Weinstein and Marks (2020), +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. +Primary developer is Eli N. Weinstein (https://eweinstein.github.io/). """ from pyro.contrib.mue.statearrangers import profile +from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM __all__ = [ "profile" + "VariableLengthDiscreteHMM" ] diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 4a434e5fbf..147884e255 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -3,8 +3,8 @@ import torch from torch.distributions import constraints -from pyro.distributions.torch_distribution import TorchDistribution from pyro.distributions.hmm import _sequential_logmatmulexp +from pyro.distributions.torch_distribution import TorchDistribution from pyro.distributions.util import broadcast_shape From 112e86aa53d60f66e5f557b1a65ae47d24ba7e17 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 31 Dec 2020 14:26:52 -0500 Subject: [PATCH 13/91] Add parser, adjust plot saving to avoid creating new folders. --- examples/contrib/mue/FactorMuE.py | 34 ++++++++++++++++++------------- examples/contrib/mue/phmm.py | 24 +++++++++++++++------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index efcf34c117..f1cca749c3 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -5,9 +5,6 @@ A PCA model with a MuE emission (FactorMuE). Uses the MuE package. 
""" -import datetime - -import matplotlib.pyplot as plt import torch import torch.nn as nn from torch.nn.functional import softplus @@ -16,10 +13,15 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine +from pyro.optim import MultiStepLR +from pyro.infer import SVI, Trace_ELBO + from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM -from pyro.infer import SVI, Trace_ELBO -from pyro.optim import MultiStepLR + +import argparse +import datetime +import matplotlib.pyplot as plt class Encoder(nn.Module): @@ -187,12 +189,12 @@ def reconstruct_ancestor_seq(self, data, inverse_temp=1.): return torch.exp(ancestor_seq_logits) -def main(): +def main(args): torch.manual_seed(9) torch.set_default_tensor_type('torch.DoubleTensor') - small_test = False + small_test = args.test if small_test: mult_dat = 1 @@ -253,14 +255,14 @@ def main(): plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - plt.savefig('FactorMuE/loss_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) latent = model.encoder(data)[0].detach() plt.scatter(latent[:, 0], latent[:, 1]) plt.xlabel('z_1') plt.ylabel('z_2') - plt.savefig('FactorMuE/latent_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE.latent_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) decoder_bias = pyro.param('decoder$$$f.bias').detach() @@ -269,7 +271,7 @@ def main(): plt.plot(decoder_bias[0, 0, :, 1]) plt.xlabel('position') plt.ylabel('bias for character 1') - plt.savefig('FactorMuE/decoder_bias_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE.decoder_bias_{}.pdf'.format(time_stamp)) for xi, x in enumerate(xs): reconstruct_x = model.reconstruct_ancestor_seq( @@ -280,7 +282,7 @@ def main(): plt.xlabel('position') plt.ylabel('probability of character 1') plt.legend() - plt.savefig('FactorMuE/reconstruction_{}_{}.pdf'.format( + plt.savefig('FactorMuE.reconstruction_{}_{}.pdf'.format( xi, time_stamp)) plt.figure(figsize=(6, 6)) @@ -289,15 +291,19 @@ def main(): plt.plot(insert_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of insert') - plt.savefig('FactorMuE/insert_prob_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE.insert_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of delete') - plt.savefig('FactorMuE/delete_prob_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE.delete_prob_{}.pdf'.format(time_stamp)) if __name__ == '__main__': - main() + parser = argparse.ArgumentParser(description="Basic Factor MuE model.") + parser.add_argument('-t', '--test', action='store_true', default=False, + help='small dataset, a few steps') + args = parser.parse_args() + main(args) diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py index 6ae59666b5..b114eedc21 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/phmm.py @@ -17,6 +17,7 @@ from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +import argparse import datetime import matplotlib.pyplot as plt @@ -109,8 +110,12 @@ def guide(self, data): delete_q_mn, softplus(delete_q_sd)).to_event(3)) -def main(): - small_test = False +def main(args): + + torch.manual_seed(0) + torch.set_default_tensor_type('torch.DoubleTensor') 
+ + small_test = args.test if small_test: mult_dat = 1 @@ -157,7 +162,7 @@ def main(): plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - plt.savefig('phmm/loss_{}.pdf'.format(time_stamp)) + plt.savefig('phmm.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) ancestor_seq = pyro.param("ancestor_seq_q_mn").detach() @@ -166,7 +171,7 @@ def main(): plt.plot(ancestor_seq_expect[:, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of character 1') - plt.savefig('phmm/ancestor_seq_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm.ancestor_seq_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() @@ -174,15 +179,20 @@ def main(): plt.plot(insert_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of insert') - plt.savefig('phmm/insert_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm.insert_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of delete') - plt.savefig('phmm/delete_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm.delete_prob_{}.pdf'.format(time_stamp)) if __name__ == '__main__': - main() + parser = argparse.ArgumentParser( + description="Basic profile HMM model (constant + MuE).") + parser.add_argument('-t', '--test', action='store_true', default=False, + help='small dataset, a few steps') + args = parser.parse_args() + main(args) From f80100ff39b69888dc2279abfcb7ede4e150ddf2 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 31 Dec 2020 14:47:38 -0500 Subject: [PATCH 14/91] Cleanup imports. --- examples/contrib/mue/FactorMuE.py | 25 ++++++++++++------------ examples/contrib/mue/phmm.py | 19 +++++++++--------- tests/contrib/mue/test_statearrangers.py | 4 ++-- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index f1cca749c3..13b164a2c2 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -5,6 +5,10 @@ A PCA model with a MuE emission (FactorMuE). Uses the MuE package. 
""" +import argparse +import datetime +import matplotlib.pyplot as plt + import torch import torch.nn as nn from torch.nn.functional import softplus @@ -12,16 +16,13 @@ import pyro import pyro.distributions as dist -import pyro.poutine as poutine -from pyro.optim import MultiStepLR -from pyro.infer import SVI, Trace_ELBO from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM -import argparse -import datetime -import matplotlib.pyplot as plt +from pyro.infer import SVI, Trace_ELBO +from pyro.optim import MultiStepLR +import pyro.poutine as poutine class Encoder(nn.Module): @@ -255,14 +256,14 @@ def main(args): plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - plt.savefig('FactorMuE.loss_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE_plot.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) latent = model.encoder(data)[0].detach() plt.scatter(latent[:, 0], latent[:, 1]) plt.xlabel('z_1') plt.ylabel('z_2') - plt.savefig('FactorMuE.latent_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE_plot.latent_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) decoder_bias = pyro.param('decoder$$$f.bias').detach() @@ -271,7 +272,7 @@ def main(args): plt.plot(decoder_bias[0, 0, :, 1]) plt.xlabel('position') plt.ylabel('bias for character 1') - plt.savefig('FactorMuE.decoder_bias_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE_plot.decoder_bias_{}.pdf'.format(time_stamp)) for xi, x in enumerate(xs): reconstruct_x = model.reconstruct_ancestor_seq( @@ -282,7 +283,7 @@ def main(args): plt.xlabel('position') plt.ylabel('probability of character 1') plt.legend() - plt.savefig('FactorMuE.reconstruction_{}_{}.pdf'.format( + plt.savefig('FactorMuE_plot.reconstruction_{}_{}.pdf'.format( xi, time_stamp)) plt.figure(figsize=(6, 6)) @@ -291,14 +292,14 @@ def main(args): plt.plot(insert_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of insert') - plt.savefig('FactorMuE.insert_prob_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of delete') - plt.savefig('FactorMuE.delete_prob_{}.pdf'.format(time_stamp)) + plt.savefig('FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp)) if __name__ == '__main__': diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py index b114eedc21..c1d45810b7 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/phmm.py @@ -5,21 +5,22 @@ A standard profile HMM model example, using the MuE package. 
""" +import argparse +import datetime +import matplotlib.pyplot as plt + import torch import torch.nn as nn from torch.nn.functional import softplus import pyro import pyro.distributions as dist -from pyro.optim import Adam -from pyro.infer import SVI, Trace_ELBO from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM -import argparse -import datetime -import matplotlib.pyplot as plt +from pyro.infer import SVI, Trace_ELBO +from pyro.optim import Adam class ProfileHMM(nn.Module): @@ -162,7 +163,7 @@ def main(args): plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - plt.savefig('phmm.loss_{}.pdf'.format(time_stamp)) + plt.savefig('phmm_plot.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) ancestor_seq = pyro.param("ancestor_seq_q_mn").detach() @@ -171,7 +172,7 @@ def main(args): plt.plot(ancestor_seq_expect[:, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of character 1') - plt.savefig('phmm.ancestor_seq_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm_plot.ancestor_seq_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() @@ -179,14 +180,14 @@ def main(args): plt.plot(insert_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of insert') - plt.savefig('phmm.insert_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm_plot.insert_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of delete') - plt.savefig('phmm.delete_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm_plot.delete_prob_{}.pdf'.format(time_stamp)) if __name__ == '__main__': diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 881fa80c2a..d8d6ad4fa3 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,7 +1,7 @@ -import torch import pytest +import torch -from pyro.contrib.mue.statearrangers import profile, mg2k +from pyro.contrib.mue.statearrangers import mg2k, profile def simpleprod(lst): From a11aef008d16bd7adda7b43bf4e9f5b3e239e9b5 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 14 Jan 2021 18:17:08 -0500 Subject: [PATCH 15/91] More extensive testing for VariableLengthHMM. --- tests/contrib/mue/test_statearrangers.py | 69 ++++++++++++- tests/contrib/mue/test_variablelengthhmm.py | 103 ++++++++++++++++++++ 2 files changed, 170 insertions(+), 2 deletions(-) diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index d8d6ad4fa3..e1eeed68bf 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -18,6 +18,7 @@ def simpleprod(lst): def test_profile(M, batch_size, substitute): torch.set_default_tensor_type('torch.DoubleTensor') + # --- Setup random model. --- pf_arranger = profile(M) u1 = torch.rand((M+1, 3)) @@ -35,6 +36,7 @@ def test_profile(M, batch_size, substitute): u1 = torch.rand((batch_size, M+1, 3)) u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) + # Compute forward pass of state arranger to get HMM parameters. 
if substitute: ll = torch.rand((4, 5)) ll = ll/torch.sum(ll, dim=1, keepdim=True) @@ -45,7 +47,11 @@ def test_profile(M, batch_size, substitute): a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), torch.log(r), torch.log(u)) - # - Remake transition matrices. - + # - Remake HMM parameters to check. - + # Here we implement Equation S40 from the MuE paper + # (https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf) + # more directly, iterating over all the indices of the transition matrix + # and initial transition vector. K = 2*(M+1) if batch_size is None: batch_dim_size = 1 @@ -125,7 +131,8 @@ def test_profile(M, batch_size, substitute): chk_e[b, k, :] = c[b, m, :] if substitute: chk_e = torch.matmul(chk_e, ll) - # - - + + # --- Check --- if batch_size is None: chk_a = chk_a.squeeze() chk_a0 = chk_a0.squeeze() @@ -139,3 +146,61 @@ def test_profile(M, batch_size, substitute): assert torch.allclose(chk_a0, torch.exp(a0ln)) assert torch.allclose(chk_a, torch.exp(aln)) assert torch.allclose(chk_e, torch.exp(eln)) + + +@pytest.mark.parametrize('batch_ancestor_seq', [False, True]) +@pytest.mark.parametrize('batch_insert_seq', [False, True]) +@pytest.mark.parametrize('batch_insert', [False, True]) +@pytest.mark.parametrize('batch_delete', [False, True]) +@pytest.mark.parametrize('batch_substitute', [False, True]) +def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, + batch_delete, batch_substitute): + + M = 5 + pf_arranger = profile(M) + u1 = torch.rand((M+1, 3)) + u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) + r1 = torch.rand((M+1, 3)) + r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) + s = torch.rand((M+1, 4)) + s = s/torch.sum(s, dim=1, keepdim=True) + c = torch.rand((M+1, 4)) + c = c/torch.sum(c, dim=1, keepdim=True) + ll = torch.rand((4, 5)) + ll = ll/torch.sum(ll, dim=1, keepdim=True) + a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), + torch.log(r), torch.log(u), + torch.log(ll)) + + +@pytest.mark.parametrize('M', [2, 20]) +@pytest.mark.parametrize('batch_size', [None, 5]) +def test_profile_trivial_case(M, batch_size): + + torch.set_default_tensor_type('torch.DoubleTensor') + + # --- Setup random model. --- + pf_arranger = profile(M) + + u1 = torch.rand((M+1, 3)) + u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) + r1 = torch.rand((M+1, 3)) + r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) + s = torch.rand((M+1, 4)) + s = s/torch.sum(s, dim=1, keepdim=True) + c = torch.rand((M+1, 4)) + c = c/torch.sum(c, dim=1, keepdim=True) + + if batch_size is not None: + s = torch.rand((batch_size, M+1, 4)) + s = s/torch.sum(s, dim=2, keepdim=True) + u1 = torch.rand((batch_size, M+1, 3)) + u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) + + # Compute forward pass of state arranger to get HMM parameters. 
+ if substitute: + ll = torch.rand((4, 5)) + ll = ll/torch.sum(ll, dim=1, keepdim=True) + a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), + torch.log(r), torch.log(u), + torch.log(ll)) diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_variablelengthhmm.py index 751071009a..b99e225625 100644 --- a/tests/contrib/mue/test_variablelengthhmm.py +++ b/tests/contrib/mue/test_variablelengthhmm.py @@ -1,3 +1,6 @@ +from pyro.distributions import DiscreteHMM, Categorical +from pyro.distributions.util import broadcast_shape +import pytest import torch from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM @@ -68,3 +71,103 @@ def test_hmm_log_prob(): chk_lp = torch.cat([chk_lp[0][None], torch.log(torch.sum(f))[None]]) assert torch.allclose(lp, chk_lp) + + +@pytest.mark.parametrize('batch_initial', [False, True]) +@pytest.mark.parametrize('batch_transition', [False, True]) +@pytest.mark.parametrize('batch_observation', [False, True]) +@pytest.mark.parametrize('batch_data', [False, True]) +def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): + torch.set_default_tensor_type('torch.DoubleTensor') + + # Dimensions. + batch_size = 3 + state_dim, observation_dim, num_steps = 4, 5, 6 + + # Model initialization. + initial_logits = torch.randn([batch_size]*batch_initial + [state_dim]) + initial_logits = (initial_logits - + initial_logits.logsumexp(-1, True)) + transition_logits = torch.randn([batch_size]*batch_transition + + [state_dim, state_dim]) + transition_logits = (transition_logits - + transition_logits.logsumexp(-1, True)) + observation_logits = torch.randn([batch_size]*batch_observation + + [state_dim, observation_dim]) + observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + + hmm = VariableLengthDiscreteHMM(initial_logits, transition_logits, + observation_logits) + + # Random observations. + value = (torch.randint(observation_dim, + [batch_size]*batch_data + [num_steps]).unsqueeze(-1) + == torch.arange(observation_dim)).double() + + # Log probability. + lp = hmm.log_prob(value) + + # Check shapes: + if all([not batch_initial, not batch_transition, not batch_observation, + not batch_data]): + assert lp.shape == () + else: + assert lp.shape == (batch_size,) + + +@pytest.mark.parametrize('batch_initial', [False, True]) +@pytest.mark.parametrize('batch_transition', [False, True]) +@pytest.mark.parametrize('batch_observation', [False, True]) +@pytest.mark.parametrize('batch_data', [False, True]) +def test_DiscreteHMM_comparison(batch_initial, batch_transition, + batch_observation, batch_data): + # Dimensions. + batch_size = 3 + state_dim, observation_dim, num_steps = 4, 5, 6 + + # -- Model setup --. + transition_logits_vldhmm = torch.randn([batch_size]*batch_transition + + [state_dim, state_dim]) + transition_logits_vldhmm = (transition_logits_vldhmm - + transition_logits_vldhmm.logsumexp(-1, True)) + # Adjust for DiscreteHMM broadcasting convention. + transition_logits_dhmm = transition_logits_vldhmm.unsqueeze(-3) + # Convert between discrete HMM convention for initial state and variable + # length HMM convention. 
+ initial_logits_dhmm = torch.randn([batch_size]*batch_initial + [state_dim]) + initial_logits_dhmm = (initial_logits_dhmm - + initial_logits_dhmm.logsumexp(-1, True)) + initial_logits_vldhmm = (initial_logits_dhmm.unsqueeze(-1) + + transition_logits_vldhmm).logsumexp(-2) + observation_logits = torch.randn([batch_size]*batch_observation + + [state_dim, observation_dim]) + observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + # Create distribution object for DiscreteHMM + observation_dist = Categorical(logits=observation_logits.unsqueeze(-3)) + + vldhmm = VariableLengthDiscreteHMM(initial_logits_vldhmm, + transition_logits_vldhmm, + observation_logits) + dhmm = DiscreteHMM(initial_logits_dhmm, transition_logits_dhmm, + observation_dist) + + # Random observations. + value = torch.randint(observation_dim, + [batch_size]*batch_data + [num_steps]) + value_oh = (value.unsqueeze(-1) + == torch.arange(observation_dim)).double() + + # -- Check. -- + # Log probability. + lp_vldhmm = vldhmm.log_prob(value_oh) + lp_dhmm = dhmm.log_prob(value) + # Shapes. + if all([not batch_initial, not batch_transition, not batch_observation, + not batch_data]): + assert lp_vldhmm.shape == () + else: + assert lp_vldhmm.shape == (batch_size,) + # Values. + assert torch.allclose(lp_vldhmm, lp_dhmm) From 81e0d49815165a411beba39cae56b6ff213c5c97 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 14 Jan 2021 20:11:50 -0500 Subject: [PATCH 16/91] Docs for mue. --- docs/source/contrib.mue.rst | 29 ++++++++++++ docs/source/index.rst | 4 +- pyro/contrib/mue/__init__.py | 7 --- pyro/contrib/mue/statearrangers.py | 14 +++--- tests/contrib/mue/test_statearrangers.py | 58 ++++++++---------------- 5 files changed, 59 insertions(+), 53 deletions(-) create mode 100644 docs/source/contrib.mue.rst diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst new file mode 100644 index 0000000000..335bcd8965 --- /dev/null +++ b/docs/source/contrib.mue.rst @@ -0,0 +1,29 @@ +MuE +=== +.. automodule:: pyro.contrib.mue + +.. warning:: Code in ``pyro.contrib.mue`` is under development. + This code makes no guarantee about maintaining backwards compatibility. + +``pyro.contrib.mue`` provides modeling tools for working with biological +sequence data. In particular it implements MuE distributions, which are used as +a fully probabilistic alternative to multiple sequence alignment-based +preprocessing. + +Reference: +MuE models were described in Weinstein and Marks (2020), +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. + +State Arrangers for Parameterizing MuEs +--------------------------------------- +.. automodule:: pyro.contrib.mue.statearrangers + :members: + :show-inheritance: + :member-order: bysource + +Variable Length/Missing Data HMM +-------------------------------- +.. automodule:: pyro.contrib.mue.variablelengthhmm + :members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/index.rst b/docs/source/index.rst index 01a49f154e..107263d847 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,7 +14,7 @@ Pyro Documentation :caption: Pyro Core: getting_started - primitives + primitives inference distributions parameters @@ -38,6 +38,7 @@ Pyro Documentation contrib.funsor contrib.gp contrib.minipyro + contrib.mue contrib.oed contrib.randomvariable contrib.timeseries @@ -51,4 +52,3 @@ Indices and tables * :ref:`search` .. 
* :ref:`modindex` - diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index 4126c6dd60..2aa86413ab 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,12 +1,5 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -""" -The :mod:`pyro.contrib.mue` module provides tools for working with mutational -emission (MuE) distributions. -See Weinstein and Marks (2020), -https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. -Primary developer is Eli N. Weinstein (https://eweinstein.github.io/). -""" from pyro.contrib.mue.statearrangers import profile from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 953e1fbb3d..e5963e7c45 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -4,13 +4,10 @@ import torch.nn as nn -def mg2k(m, g): - """Convert from (m, g) indexing to k indexing.""" - return 2*m + 1 - g - - class profile(nn.Module): - + """ + Profile HMM state arrangement. + """ def __init__(self, M, epsilon=1e-32): super().__init__() self.M = M @@ -150,3 +147,8 @@ def forward(self, ancestor_seq_logits, insert_seq_logits, observation_logits = seq_logits return initial_logits, transition_logits, observation_logits + + +def mg2k(m, g): + """Convert from (m, g) indexing to k indexing.""" + return 2*m + 1 - g diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index e1eeed68bf..4e1c4ac714 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -156,51 +156,33 @@ def test_profile(M, batch_size, substitute): def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, batch_delete, batch_substitute): - M = 5 + M, D, B = 5, 2, 6 + K = 2*(M+1) + batch_size = 6 pf_arranger = profile(M) - u1 = torch.rand((M+1, 3)) - u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) - r1 = torch.rand((M+1, 3)) - r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) - s = torch.rand((M+1, 4)) - s = s/torch.sum(s, dim=1, keepdim=True) - c = torch.rand((M+1, 4)) - c = c/torch.sum(c, dim=1, keepdim=True) - ll = torch.rand((4, 5)) - ll = ll/torch.sum(ll, dim=1, keepdim=True) - a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), - torch.log(r), torch.log(u), - torch.log(ll)) + sln = torch.randn([batch_size]*batch_ancestor_seq + [M+1, 4]) + cln = torch.randn([batch_size]*batch_insert_seq + [M+1, 4]) + rln = torch.randn([batch_size]*batch_insert + [M+1, 3, 2]) + uln = torch.randn([batch_size]*batch_delete + [M+1, 3, 2]) + lln = torch.randn([batch_size]*batch_substitute + [D, B]) + a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) + + if all([not batch_ancestor_seq, not batch_insert_seq, not batch_insert, + not batch_delete, not batch_substitute]): + assert a0ln.shape == (K,) + assert aln.shape == (K, K) + assert eln.shape == (K, B) + else: + assert a0ln.shape == (batch_size, K) + assert aln.shape == (batch_size, K, K) + assert eln.shape == (batch_size, K, B) @pytest.mark.parametrize('M', [2, 20]) @pytest.mark.parametrize('batch_size', [None, 5]) -def test_profile_trivial_case(M, batch_size): +def test_profile_trivial_cases(M, batch_size): torch.set_default_tensor_type('torch.DoubleTensor') # --- Setup random model. 
--- pf_arranger = profile(M) - - u1 = torch.rand((M+1, 3)) - u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) - r1 = torch.rand((M+1, 3)) - r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) - s = torch.rand((M+1, 4)) - s = s/torch.sum(s, dim=1, keepdim=True) - c = torch.rand((M+1, 4)) - c = c/torch.sum(c, dim=1, keepdim=True) - - if batch_size is not None: - s = torch.rand((batch_size, M+1, 4)) - s = s/torch.sum(s, dim=2, keepdim=True) - u1 = torch.rand((batch_size, M+1, 3)) - u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) - - # Compute forward pass of state arranger to get HMM parameters. - if substitute: - ll = torch.rand((4, 5)) - ll = ll/torch.sum(ll, dim=1, keepdim=True) - a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), - torch.log(r), torch.log(u), - torch.log(ll)) From e18b03334e7b010e3e280aa98a666dd41996c65d Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 15 Jan 2021 17:12:26 -0500 Subject: [PATCH 17/91] Shape tests and trivial case tests for profile statearranger. --- pyro/contrib/mue/statearrangers.py | 3 +- tests/contrib/mue/test_statearrangers.py | 51 +++++++++++++++++++----- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index e5963e7c45..14ca84e561 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -142,7 +142,8 @@ def forward(self, ancestor_seq_logits, insert_seq_logits, # Option to include the substitution matrix. if substitute_logits is not None: observation_logits = torch.logsumexp( - seq_logits.unsqueeze(-1) + substitute_logits, dim=-2) + seq_logits.unsqueeze(-1) + substitute_logits.unsqueeze(-3), + dim=-2) else: observation_logits = seq_logits diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 4e1c4ac714..9ffcedccf2 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -2,6 +2,7 @@ import torch from pyro.contrib.mue.statearrangers import mg2k, profile +from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM def simpleprod(lst): @@ -156,33 +157,61 @@ def test_profile(M, batch_size, substitute): def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, batch_delete, batch_substitute): - M, D, B = 5, 2, 6 + M, D, B = 5, 2, 3 K = 2*(M+1) batch_size = 6 pf_arranger = profile(M) - sln = torch.randn([batch_size]*batch_ancestor_seq + [M+1, 4]) - cln = torch.randn([batch_size]*batch_insert_seq + [M+1, 4]) + sln = torch.randn([batch_size]*batch_ancestor_seq + [M+1, D]) + cln = torch.randn([batch_size]*batch_insert_seq + [M+1, D]) rln = torch.randn([batch_size]*batch_insert + [M+1, 3, 2]) uln = torch.randn([batch_size]*batch_delete + [M+1, 3, 2]) lln = torch.randn([batch_size]*batch_substitute + [D, B]) a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) - if all([not batch_ancestor_seq, not batch_insert_seq, not batch_insert, - not batch_delete, not batch_substitute]): + if all([not batch_ancestor_seq, not batch_insert_seq, + not batch_substitute]): + assert eln.shape == (K, B) + else: + assert eln.shape == (batch_size, K, B) + + if all([not batch_insert, not batch_delete]): assert a0ln.shape == (K,) assert aln.shape == (K, K) - assert eln.shape == (K, B) else: assert a0ln.shape == (batch_size, K) assert aln.shape == (batch_size, K, K) - assert eln.shape == (batch_size, K, B) @pytest.mark.parametrize('M', [2, 20]) 
-@pytest.mark.parametrize('batch_size', [None, 5]) -def test_profile_trivial_cases(M, batch_size): +def test_profile_trivial_cases(M): + # Trivial case: indel probabability of zero. Expected value of + # HMM should match ancestral sequence times substitution matrix. + # --- Setup model. --- torch.set_default_tensor_type('torch.DoubleTensor') - - # --- Setup random model. --- + D, B = 2, 2 + batch_size = 5 pf_arranger = profile(M) + sln = torch.randn([batch_size, M+1, D]) + sln = sln - sln.logsumexp(-1, True) + cln = torch.randn([batch_size, M+1, D]) + cln = cln - cln.logsumexp(-1, True) + rln = torch.cat([torch.zeros([M+1, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M+1, 3, 1])], axis=-1) + uln = torch.cat([torch.zeros([M+1, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M+1, 3, 1])], axis=-1) + lln = torch.randn([D, B]) + lln = lln - lln.logsumexp(-1, True) + + a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) + + # --- Compute expected value per step. --- + # TODO: replace with VariableLengthDiscreteHMM function once implemented. + Eyln = torch.zeros([batch_size, M, B]) + ai = a0ln + for j in range(M): + Eyln[:, j, :] = torch.logsumexp(ai.unsqueeze(-1) + eln, axis=-2) + ai = torch.logsumexp(ai.unsqueeze(-1) + aln, axis=-2) + + no_indel = torch.logsumexp(sln.unsqueeze(-1) + lln.unsqueeze(-3), axis=-2) + assert torch.allclose(Eyln, no_indel[:, :-1, :]) From a6ad40529bd82ffeb5d4c1c8742b7f0b429dd237 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 15 Jan 2021 20:00:26 -0500 Subject: [PATCH 18/91] Cleaning up naming conventions and doc string conventions. --- pyro/contrib/mue/__init__.py | 8 +++---- .../{biosequenceloaders.py => dataloaders.py} | 0 pyro/contrib/mue/statearrangers.py | 2 +- pyro/contrib/mue/variablelengthhmm.py | 14 +++++------ tests/contrib/mue/test_statearrangers.py | 11 ++++----- tests/contrib/mue/test_variablelengthhmm.py | 23 ++++++++----------- 6 files changed, 26 insertions(+), 32 deletions(-) rename pyro/contrib/mue/{biosequenceloaders.py => dataloaders.py} (100%) diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index 2aa86413ab..601318853b 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,9 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.contrib.mue.statearrangers import profile -from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +from pyro.contrib.mue.statearrangers import Profile +from pyro.contrib.mue.variablelengthhmm import MissingDataDiscreteHMM __all__ = [ - "profile" - "VariableLengthDiscreteHMM" + "Profile" + "MissingDataDiscreteHMM" ] diff --git a/pyro/contrib/mue/biosequenceloaders.py b/pyro/contrib/mue/dataloaders.py similarity index 100% rename from pyro/contrib/mue/biosequenceloaders.py rename to pyro/contrib/mue/dataloaders.py diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 14ca84e561..6339c38772 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -4,7 +4,7 @@ import torch.nn as nn -class profile(nn.Module): +class Profile(nn.Module): """ Profile HMM state arrangement. 
""" diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 147884e255..2667d47a1e 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -8,10 +8,16 @@ from pyro.distributions.util import broadcast_shape -class VariableLengthDiscreteHMM(TorchDistribution): +class MissingDataDiscreteHMM(TorchDistribution): """ HMM with discrete latent states and discrete observations, allowing for variable length sequences. + + .. warning:: Unlike in pyro's DiscreteHMM, which computes the + probability of the first state as + initial.T @ transition @ emission + this distribution uses the standard HMM convention, + initial.T @ emission """ arg_constraints = {"initial_logits": constraints.real, "transition_logits": constraints.real, @@ -46,12 +52,6 @@ def __init__(self, initial_logits, transition_logits, observation_logits, batch_shape, event_shape, validate_args=validate_args) def log_prob(self, value): - """Warning: unlike in pyro's DiscreteHMM, which computes the - probability of the first state as - initial.T @ transition @ emission - this distribution uses the standard HMM convention, - initial.T @ emission - """ # observation_logits: # batch_shape (option) x state_dim x observation_dim # value: diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 9ffcedccf2..2d72a95d48 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,8 +1,7 @@ import pytest import torch -from pyro.contrib.mue.statearrangers import mg2k, profile -from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +from pyro.contrib.mue.statearrangers import mg2k, Profile def simpleprod(lst): @@ -17,10 +16,9 @@ def simpleprod(lst): @pytest.mark.parametrize('batch_size', [None, 5]) @pytest.mark.parametrize('substitute', [False, True]) def test_profile(M, batch_size, substitute): - torch.set_default_tensor_type('torch.DoubleTensor') # --- Setup random model. --- - pf_arranger = profile(M) + pf_arranger = Profile(M) u1 = torch.rand((M+1, 3)) u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) @@ -160,7 +158,7 @@ def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, M, D, B = 5, 2, 3 K = 2*(M+1) batch_size = 6 - pf_arranger = profile(M) + pf_arranger = Profile(M) sln = torch.randn([batch_size]*batch_ancestor_seq + [M+1, D]) cln = torch.randn([batch_size]*batch_insert_seq + [M+1, D]) rln = torch.randn([batch_size]*batch_insert + [M+1, 3, 2]) @@ -188,10 +186,9 @@ def test_profile_trivial_cases(M): # HMM should match ancestral sequence times substitution matrix. # --- Setup model. 
--- - torch.set_default_tensor_type('torch.DoubleTensor') D, B = 2, 2 batch_size = 5 - pf_arranger = profile(M) + pf_arranger = Profile(M) sln = torch.randn([batch_size, M+1, D]) sln = sln - sln.logsumexp(-1, True) cln = torch.randn([batch_size, M+1, D]) diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_variablelengthhmm.py index b99e225625..96ef2855ee 100644 --- a/tests/contrib/mue/test_variablelengthhmm.py +++ b/tests/contrib/mue/test_variablelengthhmm.py @@ -1,13 +1,11 @@ from pyro.distributions import DiscreteHMM, Categorical -from pyro.distributions.util import broadcast_shape import pytest import torch -from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +from pyro.contrib.mue.variablelengthhmm import MissingDataDiscreteHMM def test_hmm_log_prob(): - torch.set_default_tensor_type('torch.DoubleTensor') a0 = torch.tensor([0.9, 0.08, 0.02]) a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]]) @@ -20,8 +18,8 @@ def test_hmm_log_prob(): [1., 0.], [0., 0.]]) - hmm_distr = VariableLengthDiscreteHMM(torch.log(a0), torch.log(a), - torch.log(e)) + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) lp = hmm_distr.log_prob(x) f = a0 * e[:, 1] @@ -61,8 +59,8 @@ def test_hmm_log_prob(): e[None, :, :], torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :]], dim=0) - hmm_distr = VariableLengthDiscreteHMM(torch.log(a0), torch.log(a), - torch.log(e)) + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) lp = hmm_distr.log_prob(x) f = a0[1, :] * e[1, :, 0] @@ -78,7 +76,6 @@ def test_hmm_log_prob(): @pytest.mark.parametrize('batch_observation', [False, True]) @pytest.mark.parametrize('batch_data', [False, True]) def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): - torch.set_default_tensor_type('torch.DoubleTensor') # Dimensions. batch_size = 3 @@ -97,8 +94,8 @@ def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): observation_logits = (observation_logits - observation_logits.logsumexp(-1, True)) - hmm = VariableLengthDiscreteHMM(initial_logits, transition_logits, - observation_logits) + hmm = MissingDataDiscreteHMM(initial_logits, transition_logits, + observation_logits) # Random observations. value = (torch.randint(observation_dim, @@ -147,9 +144,9 @@ def test_DiscreteHMM_comparison(batch_initial, batch_transition, # Create distribution object for DiscreteHMM observation_dist = Categorical(logits=observation_logits.unsqueeze(-3)) - vldhmm = VariableLengthDiscreteHMM(initial_logits_vldhmm, - transition_logits_vldhmm, - observation_logits) + vldhmm = MissingDataDiscreteHMM(initial_logits_vldhmm, + transition_logits_vldhmm, + observation_logits) dhmm = DiscreteHMM(initial_logits_dhmm, transition_logits_dhmm, observation_dist) From d74a6f9bfe91e85c0e5606e066b2b02f919309a3 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 24 Jan 2021 20:10:22 -0500 Subject: [PATCH 19/91] Docstrings with parameter details. --- pyro/contrib/mue/statearrangers.py | 54 +++++++++++++++++++++++++-- pyro/contrib/mue/variablelengthhmm.py | 29 +++++++++++++- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 6339c38772..bfee76ebed 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -6,7 +6,24 @@ class Profile(nn.Module): """ - Profile HMM state arrangement. + Profile HMM state arrangement. 
Parameterizes an HMM according to + Equation S40 in [1]. For further background on profile HMMs see [2]. + + **References** + + [1] E. N. Weinstein, D. S. Marks (2020) + "Generative probabilistic biological sequence models that account for + mutational variability" + https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf + [2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) + "Biological sequence analysis: probabilistic models of proteins and nucleic + acids" + Cambridge university press + + :param M: Length of precursor (ancestral) sequence. + :type M: int + :param epsilon: Small value for approximate zeros in log space. + :type epsilon: float """ def __init__(self, M, epsilon=1e-32): super().__init__() @@ -118,9 +135,38 @@ def _make_transfer(self): elif g == 1: self.vc_transf[m, k] = 1 - def forward(self, ancestor_seq_logits, insert_seq_logits, + def forward(self, precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits, substitute_logits=None): - """Assemble the HMM parameters based on the transfer matrices.""" + """ + Assemble HMM parameters given profile parameters. + + :param ~torch.Tensor precursor_seq_logits: Initial (relaxed) sequence + *log(x)*. Should have rightmost dimension ``(M+1, D)`` and be + broadcastable to ``(batch_size, M+1, D)``, where + D is the latent alphabet size. Should be normalized to one along the + final axis, i.e. ``precursor_seq_logits.logsumexp(-1) = zeros``. + :param ~torch.Tensor insert_seq_logits: Insertion sequence *log(c)*. + Should have rightmost dimension ``(M+1, D)`` and be broadcastable + to ``(batch_size, M+1, D)``. Should be normalized + along the final axis. + :param ~torch.Tensor insert_logits: Insertion probabilities *log(r)*. + Should have rightmost dimension ``(M+1, 3, 2)`` and be broadcastable + to ``(batch_size, M+1, 3, 2)``. Should be normalized along the + final axis. + :param ~torch.Tensor delete_logits: Deletion probabilities *log(u)*. + Should have rightmost dimension ``(M+1, 3, 2)`` and be broadcastable + to ``(batch_size, M+1, 3, 2)``. Should be normalized along the + final axis. + :param ~torch.Tensor substitute_logits: Substiution probabilities + *log(l)*. Should have rightmost dimension ``(D, B)``, where + B is the alphabet size of the data, and broadcastable to + ``(batch_size, D, B)``. Must be normalized along the + final axis. + :return: *initial_logits*, *transition_logits*, and + *observation_logits*. These parameters can be used to directly + initialize the MissingDataDiscreteHMM distribution. + :rtype: ~torch.Tensor, ~torch.Tensor, ~torch.Tensor + """ initial_logits = ( torch.einsum('...ijk,ijkl->...l', delete_logits, self.u_transf_0) + @@ -135,7 +181,7 @@ def forward(self, ancestor_seq_logits, insert_seq_logits, (-1/self.epsilon)*self.null_transf) seq_logits = ( torch.einsum('...ij,ik->...kj', - ancestor_seq_logits, self.vx_transf) + + precursor_seq_logits, self.vx_transf) + torch.einsum('...ij,ik->...kj', insert_seq_logits, self.vc_transf)) diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/variablelengthhmm.py index 2667d47a1e..bf778d2237 100644 --- a/pyro/contrib/mue/variablelengthhmm.py +++ b/pyro/contrib/mue/variablelengthhmm.py @@ -11,13 +11,28 @@ class MissingDataDiscreteHMM(TorchDistribution): """ HMM with discrete latent states and discrete observations, allowing for - variable length sequences. + missing data or variable length sequences. Observations are assumed + to be one hot encoded; rows with all zeros indicate missing data. .. 
warning:: Unlike in pyro's DiscreteHMM, which computes the probability of the first state as initial.T @ transition @ emission this distribution uses the standard HMM convention, initial.T @ emission + + :param ~torch.Tensor initial_logits: A logits tensor for an initial + categorical distribution over latent states. Should have rightmost + size ``state_dim`` and be broadcastable to + ``(batch_size, state_dim)``. + :param ~torch.Tensor transition_logits: A logits tensor for transition + conditional distributions between latent states. Should have rightmost + shape ``(state_dim, state_dim)`` (old, new), and be broadcastable + to ``(batch_size, state_dim, state_dim)``. + :param ~torch.Tensor observation_logits: A logits tensor for observation + distributions from latent states. Should have rightmost shape + ``(state_dim, categorical_size)``, where ``categorical_size`` is the + dimension of the categorical output, and be broadcastable + to ``(batch_size, state_dim, categorical_size)``. """ arg_constraints = {"initial_logits": constraints.real, "transition_logits": constraints.real, @@ -48,10 +63,20 @@ def __init__(self, initial_logits, transition_logits, observation_logits, transition_logits.logsumexp(-1, True)) self.observation_logits = (observation_logits - observation_logits.logsumexp(-1, True)) - super(VariableLengthDiscreteHMM, self).__init__( + super(MissingDataDiscreteHMM, self).__init__( batch_shape, event_shape, validate_args=validate_args) def log_prob(self, value): + """ + :param ~torch.Tensor value: One-hot encoded observation. Must be + real-valued (float) and broadcastable to + ``(batch_size, num_steps, categorical_size)`` where + ``categorical_size`` is the dimension of the categorical output. + Missing data is represented by zeros, i.e. + ``value[batch, step, :] == tensor([0, ..., 0])``. + Variable length observation sequences can be handled by padding + the sequence with zeros at the end. + """ # observation_logits: # batch_shape (option) x state_dim x observation_dim # value: From 827721aaf1cc2df46bf7841d7b3a61e167850b2e Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Mon, 8 Feb 2021 18:28:18 -0500 Subject: [PATCH 20/91] Improve indexing conventions, add unit tests. --- pyro/contrib/mue/statearrangers.py | 95 +++++++++++++----------- tests/contrib/mue/test_statearrangers.py | 88 +++++++++++++--------- 2 files changed, 106 insertions(+), 77 deletions(-) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index bfee76ebed..2f1b799ff7 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import torch import torch.nn as nn +import pdb class Profile(nn.Module): """ Profile HMM state arrangement. Parameterizes an HMM according to - Equation S40 in [1]. For further background on profile HMMs see [2]. + Equation S40 in [1] (with r_{M+1,j} = 1 and u_{M+1,j} = 0 + for j in {0, 1, 2}). For further background on profile HMMs see [2]. 
**References** @@ -28,7 +30,7 @@ class Profile(nn.Module): def __init__(self, M, epsilon=1e-32): super().__init__() self.M = M - self.K = 2*(M+1) + self.K = 2*M+1 self.epsilon = epsilon self._make_transfer() @@ -46,21 +48,22 @@ def _make_transfer(self): # null -> locations in the transition matrix equal to 0 # ...transf_0 -> initial transition vector # ...transf -> transition matrix + # We fix r_{M+1,j} = 1 for j in {0, 1, 2} self.register_buffer('r_transf_0', - torch.zeros((M+1, 3, 2, K))) + torch.zeros((M, 3, 2, K))) self.register_buffer('u_transf_0', - torch.zeros((M+1, 3, 2, K))) + torch.zeros((M, 3, 2, K))) self.register_buffer('null_transf_0', torch.zeros((K,))) m, g = -1, 0 - for mp in range(M+1): - for gp in range(2): - kp = mg2k(mp, gp) + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: self.r_transf_0[m+1-g, g, 0, kp] = 1 self.u_transf_0[m+1-g, g, 0, kp] = 1 - elif m + 1 - g < mp and mp <= M and gp == 0: + elif m + 1 - g < mp and gp == 0: self.r_transf_0[m+1-g, g, 0, kp] = 1 self.u_transf_0[m+1-g, g, 1, kp] = 1 for mpp in range(m+2-g, mp): @@ -70,36 +73,37 @@ def _make_transfer(self): self.u_transf_0[mp, 2, 0, kp] = 1 elif m + 1 - g == mp and gp == 1: - self.r_transf_0[m+1-g, g, 1, kp] = 1 + if mp < M: + self.r_transf_0[m+1-g, g, 1, kp] = 1 - elif m + 1 - g < mp and mp <= M and gp == 1: + elif m + 1 - g < mp and gp == 1: self.r_transf_0[m+1-g, g, 0, kp] = 1 self.u_transf_0[m+1-g, g, 1, kp] = 1 for mpp in range(m+2-g, mp): self.r_transf_0[mpp, 2, 0, kp] = 1 self.u_transf_0[mpp, 2, 1, kp] = 1 - self.r_transf_0[mp, 2, 1, kp] = 1 + if mp < M: + self.r_transf_0[mp, 2, 1, kp] = 1 else: self.null_transf_0[kp] = 1 - self.u_transf_0[-1, :, :, :] = 0. self.register_buffer('r_transf', - torch.zeros((M+1, 3, 2, K, K))) + torch.zeros((M, 3, 2, K, K))) self.register_buffer('u_transf', - torch.zeros((M+1, 3, 2, K, K))) + torch.zeros((M, 3, 2, K, K))) self.register_buffer('null_transf', torch.zeros((K, K))) - for m in range(M+1): - for g in range(2): - for mp in range(M+1): - for gp in range(2): - k, kp = mg2k(m, g), mg2k(mp, gp) + for g in range(2): + for m in range(M+g): + for gp in range(2): + for mp in range(M+gp): + k, kp = mg2k(m, g, M), mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: self.r_transf[m+1-g, g, 0, k, kp] = 1 self.u_transf[m+1-g, g, 0, k, kp] = 1 - elif m + 1 - g < mp and mp <= M and gp == 0: + elif m + 1 - g < mp and gp == 0: self.r_transf[m+1-g, g, 0, k, kp] = 1 self.u_transf[m+1-g, g, 1, k, kp] = 1 for mpp in range(m+2-g, mp): @@ -109,27 +113,28 @@ def _make_transfer(self): self.u_transf[mp, 2, 0, k, kp] = 1 elif m + 1 - g == mp and gp == 1: - self.r_transf[m+1-g, g, 1, k, kp] = 1 + if mp < M: + self.r_transf[m+1-g, g, 1, k, kp] = 1 - elif m + 1 - g < mp and mp <= M and gp == 1: + elif m + 1 - g < mp and gp == 1: self.r_transf[m+1-g, g, 0, k, kp] = 1 self.u_transf[m+1-g, g, 1, k, kp] = 1 for mpp in range(m+2-g, mp): self.r_transf[mpp, 2, 0, k, kp] = 1 self.u_transf[mpp, 2, 1, k, kp] = 1 - self.r_transf[mp, 2, 1, k, kp] = 1 + if mp < M: + self.r_transf[mp, 2, 1, k, kp] = 1 - elif not (m == M and mp == M and g == 0 and gp == 0): + else: self.null_transf[k, kp] = 1 - self.u_transf[-1, :, :, :, :] = 0. 
self.register_buffer('vx_transf', - torch.zeros((M+1, K))) + torch.zeros((M, K))) self.register_buffer('vc_transf', torch.zeros((M+1, K))) - for m in range(M+1): - for g in range(2): - k = mg2k(m, g) + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) if g == 0: self.vx_transf[m, k] = 1 elif g == 1: @@ -141,8 +146,8 @@ def forward(self, precursor_seq_logits, insert_seq_logits, Assemble HMM parameters given profile parameters. :param ~torch.Tensor precursor_seq_logits: Initial (relaxed) sequence - *log(x)*. Should have rightmost dimension ``(M+1, D)`` and be - broadcastable to ``(batch_size, M+1, D)``, where + *log(x)*. Should have rightmost dimension ``(M, D)`` and be + broadcastable to ``(batch_size, M, D)``, where D is the latent alphabet size. Should be normalized to one along the final axis, i.e. ``precursor_seq_logits.logsumexp(-1) = zeros``. :param ~torch.Tensor insert_seq_logits: Insertion sequence *log(c)*. @@ -150,14 +155,14 @@ def forward(self, precursor_seq_logits, insert_seq_logits, to ``(batch_size, M+1, D)``. Should be normalized along the final axis. :param ~torch.Tensor insert_logits: Insertion probabilities *log(r)*. - Should have rightmost dimension ``(M+1, 3, 2)`` and be broadcastable - to ``(batch_size, M+1, 3, 2)``. Should be normalized along the + Should have rightmost dimension ``(M, 3, 2)`` and be broadcastable + to ``(batch_size, M, 3, 2)``. Should be normalized along the final axis. :param ~torch.Tensor delete_logits: Deletion probabilities *log(u)*. - Should have rightmost dimension ``(M+1, 3, 2)`` and be broadcastable - to ``(batch_size, M+1, 3, 2)``. Should be normalized along the + Should have rightmost dimension ``(M, 3, 2)`` and be broadcastable + to ``(batch_size, M, 3, 2)``. Should be normalized along the final axis. - :param ~torch.Tensor substitute_logits: Substiution probabilities + :param ~torch.Tensor substitute_logits: Substitution probabilities *log(l)*. Should have rightmost dimension ``(D, B)``, where B is the alphabet size of the data, and broadcastable to ``(batch_size, D, B)``. Must be normalized along the @@ -179,11 +184,15 @@ def forward(self, precursor_seq_logits, insert_seq_logits, torch.einsum('...ijk,ijklf->...lf', insert_logits, self.r_transf) + (-1/self.epsilon)*self.null_transf) - seq_logits = ( - torch.einsum('...ij,ik->...kj', - precursor_seq_logits, self.vx_transf) + - torch.einsum('...ij,ik->...kj', - insert_seq_logits, self.vc_transf)) + # Broadcasting for concatenation. + if len(precursor_seq_logits.size()) > len(insert_seq_logits.size()): + insert_seq_logits = insert_seq_logits.unsqueeze(0).expand( + [precursor_seq_logits.size()[0], -1, -1]) + elif len(insert_seq_logits.size()) > len(precursor_seq_logits.size()): + precursor_seq_logits = precursor_seq_logits.unsqueeze(0).expand( + [insert_seq_logits.size()[0], -1, -1]) + seq_logits = torch.cat([precursor_seq_logits, insert_seq_logits], + dim=-2) # Option to include the substitution matrix. 
if substitute_logits is not None: @@ -196,6 +205,6 @@ def forward(self, precursor_seq_logits, insert_seq_logits, return initial_logits, transition_logits, observation_logits -def mg2k(m, g): +def mg2k(m, g, M): """Convert from (m, g) indexing to k indexing.""" - return 2*m + 1 - g + return m + M*g diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 2d72a95d48..60953b1cd4 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,6 +1,8 @@ import pytest import torch +import pdb + from pyro.contrib.mue.statearrangers import mg2k, Profile @@ -15,43 +17,49 @@ def simpleprod(lst): @pytest.mark.parametrize('M', [2, 20]) @pytest.mark.parametrize('batch_size', [None, 5]) @pytest.mark.parametrize('substitute', [False, True]) -def test_profile(M, batch_size, substitute): +def test_profile_alternate_imp(M, batch_size, substitute): # --- Setup random model. --- pf_arranger = Profile(M) u1 = torch.rand((M+1, 3)) + u1[M, :] = 0 # Assume u_{M+1, j} = 0 for j in {0, 1, 2} in Eqn. S40. u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) r1 = torch.rand((M+1, 3)) + r1[M, :] = 1 # Assume r_{M+1, j} = 1 for j in {0, 1, 2} in Eqn. S40. r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) - s = torch.rand((M+1, 4)) + s = torch.rand((M, 4)) s = s/torch.sum(s, dim=1, keepdim=True) c = torch.rand((M+1, 4)) c = c/torch.sum(c, dim=1, keepdim=True) if batch_size is not None: - s = torch.rand((batch_size, M+1, 4)) + s = torch.rand((batch_size, M, 4)) s = s/torch.sum(s, dim=2, keepdim=True) u1 = torch.rand((batch_size, M+1, 3)) + u1[:, M, :] = 0 u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) # Compute forward pass of state arranger to get HMM parameters. + # Don't use dimension M, assumed fixed by statearranger. if substitute: ll = torch.rand((4, 5)) ll = ll/torch.sum(ll, dim=1, keepdim=True) - a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), - torch.log(r), torch.log(u), - torch.log(ll)) + a0ln, aln, eln = pf_arranger.forward( + torch.log(s), torch.log(c), + torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :]), + torch.log(ll)) else: - a0ln, aln, eln = pf_arranger.forward(torch.log(s), torch.log(c), - torch.log(r), torch.log(u)) + a0ln, aln, eln = pf_arranger.forward( + torch.log(s), torch.log(c), + torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :])) # - Remake HMM parameters to check. - # Here we implement Equation S40 from the MuE paper # (https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf) # more directly, iterating over all the indices of the transition matrix # and initial transition vector. 
- K = 2*(M+1) + K = 2*M + 1 if batch_size is None: batch_dim_size = 1 r1 = r1.unsqueeze(0) @@ -72,9 +80,9 @@ def test_profile(M, batch_size, substitute): for b in range(batch_dim_size): m, g = -1, 0 u1[b][-1] = 1e-32 - for mp in range(M+1): - for gp in range(2): - kp = mg2k(mp, gp) + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: chk_a0[b, kp] = (1 - r1[b, m+1-g, g])*(1 - u1[b, m+1-g, g]) elif m + 1 - g < mp and gp == 0: @@ -92,12 +100,12 @@ def test_profile(M, batch_size, substitute): simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] for mpp in range(m+2-g, mp)]) * r1[b, mp, 2]) - for m in range(M+1): - for g in range(2): - k = mg2k(m, g) - for mp in range(M+1): - for gp in range(2): - kp = mg2k(mp, gp) + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: chk_a[b, k, kp] = (1 - r1[b, m+1-g, g] )*(1 - u1[b, m+1-g, g]) @@ -121,9 +129,9 @@ def test_profile(M, batch_size, substitute): elif m == M and mp == M and g == 0 and gp == 0: chk_a[b, k, kp] = 1. - for m in range(M+1): - for g in range(2): - k = mg2k(m, g) + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) if g == 0: chk_e[b, k, :] = s[b, m, :] else: @@ -139,8 +147,8 @@ def test_profile(M, batch_size, substitute): assert torch.allclose(torch.sum(torch.exp(a0ln)), torch.tensor(1.), atol=1e-3, rtol=1e-3) - assert torch.allclose(torch.sum(torch.exp(aln), axis=1)[:-1], - torch.ones(2*(M+1)-1), atol=1e-3, + assert torch.allclose(torch.sum(torch.exp(aln), axis=1), + torch.ones(2*M+1), atol=1e-3, rtol=1e-3) assert torch.allclose(chk_a0, torch.exp(a0ln)) assert torch.allclose(chk_a, torch.exp(aln)) @@ -156,31 +164,42 @@ def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, batch_delete, batch_substitute): M, D, B = 5, 2, 3 - K = 2*(M+1) + K = 2*M + 1 batch_size = 6 pf_arranger = Profile(M) - sln = torch.randn([batch_size]*batch_ancestor_seq + [M+1, D]) + sln = torch.randn([batch_size]*batch_ancestor_seq + [M, D]) + sln = sln - sln.logsumexp(-1, True) cln = torch.randn([batch_size]*batch_insert_seq + [M+1, D]) - rln = torch.randn([batch_size]*batch_insert + [M+1, 3, 2]) - uln = torch.randn([batch_size]*batch_delete + [M+1, 3, 2]) + cln = cln - cln.logsumexp(-1, True) + rln = torch.randn([batch_size]*batch_insert + [M, 3, 2]) + rln = rln - rln.logsumexp(-1, True) + uln = torch.randn([batch_size]*batch_delete + [M, 3, 2]) + uln = uln - uln.logsumexp(-1, True) lln = torch.randn([batch_size]*batch_substitute + [D, B]) + lln = lln - lln.logsumexp(-1, True) a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) if all([not batch_ancestor_seq, not batch_insert_seq, not batch_substitute]): assert eln.shape == (K, B) + assert torch.allclose(eln.logsumexp(-1), torch.zeros(K)) else: assert eln.shape == (batch_size, K, B) + assert torch.allclose(eln.logsumexp(-1), torch.zeros(batch_size, K)) if all([not batch_insert, not batch_delete]): assert a0ln.shape == (K,) + assert torch.allclose(a0ln.logsumexp(-1), torch.zeros(1)) assert aln.shape == (K, K) + assert torch.allclose(aln.logsumexp(-1), torch.zeros(K)) else: assert a0ln.shape == (batch_size, K) + assert torch.allclose(a0ln.logsumexp(-1), torch.zeros(batch_size)) assert aln.shape == (batch_size, K, K) + assert torch.allclose(aln.logsumexp(-1), torch.zeros((batch_size, K))) -@pytest.mark.parametrize('M', [2, 20]) +@pytest.mark.parametrize('M', [2, 20]) # , 20 def test_profile_trivial_cases(M): # Trivial 
case: indel probabability of zero. Expected value of # HMM should match ancestral sequence times substitution matrix. @@ -189,14 +208,14 @@ def test_profile_trivial_cases(M): D, B = 2, 2 batch_size = 5 pf_arranger = Profile(M) - sln = torch.randn([batch_size, M+1, D]) + sln = torch.randn([batch_size, M, D]) sln = sln - sln.logsumexp(-1, True) cln = torch.randn([batch_size, M+1, D]) cln = cln - cln.logsumexp(-1, True) - rln = torch.cat([torch.zeros([M+1, 3, 1]), - -1/pf_arranger.epsilon*torch.ones([M+1, 3, 1])], axis=-1) - uln = torch.cat([torch.zeros([M+1, 3, 1]), - -1/pf_arranger.epsilon*torch.ones([M+1, 3, 1])], axis=-1) + rln = torch.cat([torch.zeros([M, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) + uln = torch.cat([torch.zeros([M, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) lln = torch.randn([D, B]) lln = lln - lln.logsumexp(-1, True) @@ -210,5 +229,6 @@ def test_profile_trivial_cases(M): Eyln[:, j, :] = torch.logsumexp(ai.unsqueeze(-1) + eln, axis=-2) ai = torch.logsumexp(ai.unsqueeze(-1) + aln, axis=-2) + print(aln.exp()) no_indel = torch.logsumexp(sln.unsqueeze(-1) + lln.unsqueeze(-3), axis=-2) - assert torch.allclose(Eyln, no_indel[:, :-1, :]) + assert torch.allclose(Eyln, no_indel) From e973c231b6a0771321c5b09f54b4d46af3a0103c Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 9 Feb 2021 19:44:14 -0500 Subject: [PATCH 21/91] Rename files and build complete set of options for FactorMuE --- examples/contrib/mue/FactorMuE.py | 195 +------- examples/contrib/mue/phmm.py | 105 +---- pyro/contrib/mue/__init__.py | 2 +- ...variablelengthhmm.py => missingdatahmm.py} | 0 pyro/contrib/mue/models.py | 417 ++++++++++++++++++ ...blelengthhmm.py => test_missingdatahmm.py} | 2 +- tests/contrib/mue/test_statearrangers.py | 2 - 7 files changed, 437 insertions(+), 286 deletions(-) rename pyro/contrib/mue/{variablelengthhmm.py => missingdatahmm.py} (100%) create mode 100644 pyro/contrib/mue/models.py rename tests/contrib/mue/{test_variablelengthhmm.py => test_missingdatahmm.py} (98%) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 13b164a2c2..b6fd2a9660 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -10,184 +10,13 @@ import matplotlib.pyplot as plt import torch -import torch.nn as nn -from torch.nn.functional import softplus from torch.optim import Adam - import pyro -import pyro.distributions as dist -from pyro.contrib.mue.statearrangers import profile -from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +from pyro.contrib.mue.models import FactorMuE from pyro.infer import SVI, Trace_ELBO from pyro.optim import MultiStepLR -import pyro.poutine as poutine - - -class Encoder(nn.Module): - def __init__(self, obs_seq_length, alphabet_length, z_dim): - super().__init__() - - self.input_size = obs_seq_length * alphabet_length - self.f1_mn = nn.Linear(self.input_size, z_dim) - self.f1_sd = nn.Linear(self.input_size, z_dim) - - def forward(self, data): - - data = data.reshape(-1, self.input_size) - z_loc = self.f1_mn(data) - z_scale = softplus(self.f1_sd(data)) - - return z_loc, z_scale - - -class Decoder(nn.Module): - def __init__(self, latent_seq_length, alphabet_length, z_dim): - super().__init__() - - self.latent_seq_length = latent_seq_length - self.alphabet_length = alphabet_length - self.output_size = 2 * (latent_seq_length+1) * alphabet_length - self.f = nn.Linear(z_dim, self.output_size) - - def forward(self, z): - - seq = self.f(z) - seq = 
seq.reshape([-1, 2, self.latent_seq_length+1, - self.alphabet_length]) - return seq - - -class FactorMuE(nn.Module): - - def __init__(self, obs_seq_length, alphabet_length, z_dim, - scale_factor=1., - latent_seq_length=None, prior_scale=1., - indel_prior_strength=10., inverse_temp_prior=100.): - super().__init__() - - # Constants. - assert isinstance(obs_seq_length, int) and obs_seq_length > 0 - self.obs_seq_length = obs_seq_length - if latent_seq_length is None: - latent_seq_length = obs_seq_length - else: - assert isinstance(latent_seq_length, int) and latent_seq_length > 0 - self.latent_seq_length = latent_seq_length - assert isinstance(alphabet_length, int) and alphabet_length > 0 - self.alphabet_length = alphabet_length - assert isinstance(z_dim, int) and z_dim > 0 - self.z_dim = z_dim - - # Parameter shapes. - self.seq_shape = (latent_seq_length+1, alphabet_length) - self.indel_shape = (latent_seq_length+1, 3, 2) - - # Priors. - assert isinstance(prior_scale, float) - self.prior_scale = torch.tensor(prior_scale) - assert isinstance(indel_prior_strength, float) - self.indel_prior = torch.tensor([indel_prior_strength, 0.]) - assert isinstance(inverse_temp_prior, float) - self.inverse_temp_prior = torch.tensor(inverse_temp_prior) - - # Batch control. - self.scale_factor = scale_factor - - # Initialize layers. - self.encoder = Encoder(obs_seq_length, alphabet_length, z_dim) - self.decoder = Decoder(latent_seq_length, alphabet_length, z_dim) - self.statearrange = profile(latent_seq_length) - - def model(self, data): - - pyro.module("decoder", self.decoder) - - # Indel probabilities. - insert = pyro.sample("insert", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) - insert_logits = insert - insert.logsumexp(-1, True) - delete = pyro.sample("delete", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) - delete_logits = delete - delete.logsumexp(-1, True) - - # Inverse temperature. - inverse_temp = pyro.sample("inverse_temp", dist.Normal( - self.inverse_temp_prior, torch.tensor(1.))) - - with pyro.plate("batch", data.shape[0]), poutine.scale( - scale=self.scale_factor): - # Sample latent variable from prior. - z = pyro.sample("latent", dist.Normal( - torch.zeros(self.z_dim), torch.ones(self.z_dim)).to_event(1)) - # Decode latent sequence. - latent_seq = self.decoder(z) - # Construct ancestral and insertion sequences. - ancestor_seq_logits = (latent_seq[..., 0, :, :] * - softplus(inverse_temp)) - ancestor_seq_logits = (ancestor_seq_logits - - ancestor_seq_logits.logsumexp(-1, True)) - insert_seq_logits = (latent_seq[..., 1, :, :] * - softplus(inverse_temp)) - insert_seq_logits = (insert_seq_logits - - insert_seq_logits.logsumexp(-1, True)) - # Construct HMM parameters. - initial_logits, transition_logits, observation_logits = ( - self.statearrange(ancestor_seq_logits, insert_seq_logits, - insert_logits, delete_logits)) - # Draw samples. - pyro.sample("obs", - VariableLengthDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=data) - - def guide(self, data): - # Register encoder with pyro. - pyro.module("encoder", self.encoder) - - # Indel probabilities. 
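Note: every q_mn/q_sd pair in these guides, here and in the models module they are moving to, follows one mean-field recipe: an unconstrained location plus an unconstrained scale passed through softplus to keep it positive. A minimal standalone sketch:

    import torch
    from torch.nn.functional import softplus
    import pyro
    import pyro.distributions as dist

    shape = (4, 3, 2)  # stand-in for indel_shape
    q_mn = pyro.param("q_mn", torch.zeros(shape))
    q_sd = pyro.param("q_sd", torch.zeros(shape))  # unconstrained
    # softplus(0.) ~ 0.69, so a zero init gives a moderate starting scale
    x = pyro.sample("x", dist.Normal(q_mn, softplus(q_sd)).to_event(3))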
- insert_q_mn = pyro.param("insert_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - insert_q_sd = pyro.param("insert_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("insert", dist.Normal( - insert_q_mn, softplus(insert_q_sd)).to_event(3)) - delete_q_mn = pyro.param("delete_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - delete_q_sd = pyro.param("delete_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("delete", dist.Normal( - delete_q_mn, softplus(delete_q_sd)).to_event(3)) - # Inverse temperature. - inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.)) - inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.)) - pyro.sample("inverse_temp", dist.Normal( - inverse_temp_q_mn, softplus(inverse_temp_q_sd))) - - # Per data latent variables. - with pyro.plate("batch", data.shape[0]), poutine.scale( - scale=self.scale_factor): - # Encode seq. - z_loc, z_scale = self.encoder(data) - # Sample. - pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) - - def reconstruct_ancestor_seq(self, data, inverse_temp=1.): - # Encode seq. - z_loc = self.encoder(data)[0] - # Reconstruct - latent_seq = self.decoder(z_loc) - # Construct ancestral sequence. - ancestor_seq_logits = latent_seq[..., 0, :, :] * softplus(inverse_temp) - ancestor_seq_logits = (ancestor_seq_logits - - ancestor_seq_logits.logsumexp(-1, True)) - return torch.exp(ancestor_seq_logits) def main(args): @@ -234,7 +63,8 @@ def main(args): 'milestones': [20, 100, 1000, 2000], 'gamma': 0.5}) # optimizer = Adam(adam_params) - model = FactorMuE(obs_seq_length, alphabet_length, z_dim) + model = FactorMuE(obs_seq_length, alphabet_length, z_dim, + substitution_matrix=False) svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO()) n_steps = 10*mult_step @@ -265,18 +95,17 @@ def main(args): plt.ylabel('z_2') plt.savefig('FactorMuE_plot.latent_{}.pdf'.format(time_stamp)) - plt.figure(figsize=(6, 6)) - decoder_bias = pyro.param('decoder$$$f.bias').detach() - decoder_bias = decoder_bias.reshape( - [-1, 2, model.latent_seq_length+1, model.alphabet_length]) - plt.plot(decoder_bias[0, 0, :, 1]) - plt.xlabel('position') - plt.ylabel('bias for character 1') - plt.savefig('FactorMuE_plot.decoder_bias_{}.pdf'.format(time_stamp)) + # plt.figure(figsize=(6, 6)) + # decoder_bias = pyro.param('decoder$$$f.bias').detach() + # decoder_bias = decoder_bias.reshape( + # [-1, 2, model.latent_seq_length+1, model.alphabet_length]) + # plt.plot(decoder_bias[0, 0, :, 1]) + # plt.xlabel('position') + # plt.ylabel('bias for character 1') + # plt.savefig('FactorMuE_plot.decoder_bias_{}.pdf'.format(time_stamp)) for xi, x in enumerate(xs): - reconstruct_x = model.reconstruct_ancestor_seq( - x, pyro.param("inverse_temp_q_mn")).detach() + reconstruct_x = model.reconstruct_precursor_seq(x, pyro.param) plt.figure(figsize=(6, 6)) plt.plot(reconstruct_x[0, :, 1], label="reconstruct") plt.plot(x[:, 1], label="data") diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/phmm.py index c1d45810b7..ef45d13f13 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/phmm.py @@ -10,107 +10,14 @@ import matplotlib.pyplot as plt import torch -import torch.nn as nn -from torch.nn.functional import softplus - import pyro -import pyro.distributions as dist -from pyro.contrib.mue.statearrangers import profile -from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM +from pyro.contrib.mue.models import ProfileHMM from pyro.infer import SVI, Trace_ELBO from pyro.optim import Adam 
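With ProfileHMM now provided by pyro.contrib.mue.models, the inline class definition removed below goes away and the example body reduces to roughly the following (a sketch with illustrative sizes, using the imports above):

    latent_seq_length, alphabet_length = 6, 2
    data = torch.nn.functional.one_hot(
        torch.randint(alphabet_length, (10, 8)),
        alphabet_length).float()      # toy batch of one-hot sequences
    model = ProfileHMM(latent_seq_length, alphabet_length)
    svi = SVI(model.model, model.guide, Adam({'lr': 0.05}),
              loss=Trace_ELBO())
    for step in range(100):
        loss = svi.step(data)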
-class ProfileHMM(nn.Module): - - def __init__(self, latent_seq_length, alphabet_length, - prior_scale=1., indel_prior_strength=10.): - super().__init__() - - assert isinstance(latent_seq_length, int) and latent_seq_length > 0 - self.latent_seq_length = latent_seq_length - assert isinstance(alphabet_length, int) and alphabet_length > 0 - self.alphabet_length = alphabet_length - - self.seq_shape = (latent_seq_length+1, alphabet_length) - self.indel_shape = (latent_seq_length+1, 3, 2) - - assert isinstance(prior_scale, float) - self.prior_scale = prior_scale - assert isinstance(indel_prior_strength, float) - self.indel_prior = torch.tensor([indel_prior_strength, 0.]) - - # Initialize state arranger. - self.statearrange = profile(latent_seq_length) - - def model(self, data): - - # Latent sequence. - ancestor_seq = pyro.sample("ancestor_seq", dist.Normal( - torch.zeros(self.seq_shape), - self.prior_scale * torch.ones(self.seq_shape)).to_event(2)) - ancestor_seq_logits = ancestor_seq - ancestor_seq.logsumexp(-1, True) - insert_seq = pyro.sample("insert_seq", dist.Normal( - torch.zeros(self.seq_shape), - self.prior_scale * torch.ones(self.seq_shape)).to_event(2)) - insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True) - - # Indel probabilities. - insert = pyro.sample("insert", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) - insert_logits = insert - insert.logsumexp(-1, True) - delete = pyro.sample("delete", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) - delete_logits = delete - delete.logsumexp(-1, True) - - # Construct HMM parameters. - initial_logits, transition_logits, observation_logits = ( - self.statearrange(ancestor_seq_logits, insert_seq_logits, - insert_logits, delete_logits)) - # Draw samples. - for i in pyro.plate("data", data.shape[0]): - pyro.sample("obs_{}".format(i), - VariableLengthDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=data[i]) - - def guide(self, data): - # Sequence. - ancestor_seq_q_mn = pyro.param("ancestor_seq_q_mn", - torch.zeros(self.seq_shape)) - ancestor_seq_q_sd = pyro.param("ancestor_seq_q_sd", - torch.zeros(self.seq_shape)) - pyro.sample("ancestor_seq", dist.Normal( - ancestor_seq_q_mn, softplus(ancestor_seq_q_sd)).to_event(2)) - insert_seq_q_mn = pyro.param("insert_seq_q_mn", - torch.zeros(self.seq_shape)) - insert_seq_q_sd = pyro.param("insert_seq_q_sd", - torch.zeros(self.seq_shape)) - pyro.sample("insert_seq", dist.Normal( - insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2)) - - # Indels. 
- insert_q_mn = pyro.param("insert_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - insert_q_sd = pyro.param("insert_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("insert", dist.Normal( - insert_q_mn, softplus(insert_q_sd)).to_event(3)) - delete_q_mn = pyro.param("delete_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - delete_q_sd = pyro.param("delete_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("delete", dist.Normal( - delete_q_mn, softplus(delete_q_sd)).to_event(3)) - - def main(args): torch.manual_seed(0) @@ -166,13 +73,13 @@ def main(args): plt.savefig('phmm_plot.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) - ancestor_seq = pyro.param("ancestor_seq_q_mn").detach() - ancestor_seq_expect = torch.exp(ancestor_seq - - ancestor_seq.logsumexp(-1, True)) - plt.plot(ancestor_seq_expect[:, 1].numpy()) + precursor_seq = pyro.param("precursor_seq_q_mn").detach() + precursor_seq_expect = torch.exp(precursor_seq - + precursor_seq.logsumexp(-1, True)) + plt.plot(precursor_seq_expect[:, 1].numpy()) plt.xlabel('position') plt.ylabel('probability of character 1') - plt.savefig('phmm_plot.ancestor_seq_prob_{}.pdf'.format(time_stamp)) + plt.savefig('phmm_plot.precursor_seq_prob_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index 601318853b..b01f59866b 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 from pyro.contrib.mue.statearrangers import Profile -from pyro.contrib.mue.variablelengthhmm import MissingDataDiscreteHMM +from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM __all__ = [ "Profile" diff --git a/pyro/contrib/mue/variablelengthhmm.py b/pyro/contrib/mue/missingdatahmm.py similarity index 100% rename from pyro/contrib/mue/variablelengthhmm.py rename to pyro/contrib/mue/missingdatahmm.py diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py new file mode 100644 index 0000000000..271fa4ce43 --- /dev/null +++ b/pyro/contrib/mue/models.py @@ -0,0 +1,417 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example MuE observation models. +""" + +import torch +import torch.nn as nn +from torch.nn.functional import softplus + +import pyro +import pyro.distributions as dist + +from pyro.contrib.mue.statearrangers import Profile +from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM + +import pyro.poutine as poutine + +import pdb + + +class ProfileHMM(nn.Module): + """Model: Constant + MuE. """ + def __init__(self, latent_seq_length, alphabet_length, + prior_scale=1., indel_prior_strength=10.): + super().__init__() + + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + + self.precursor_seq_shape = (latent_seq_length, alphabet_length) + self.insert_seq_shape = (latent_seq_length+1, alphabet_length) + self.indel_shape = (latent_seq_length, 3, 2) + + assert isinstance(prior_scale, float) + self.prior_scale = prior_scale + assert isinstance(indel_prior_strength, float) + self.indel_prior = torch.tensor([indel_prior_strength, 0.]) + + # Initialize state arranger. 
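Note: Profile converts the four blocks of MuE logits into dense HMM parameters (initial, transition, emission). Its contract, mirroring the updated shape tests, is roughly:

    import torch
    from pyro.contrib.mue.statearrangers import Profile

    M, D = 4, 3  # illustrative profile length and latent alphabet size
    arranger = Profile(M)
    precursor_seq_logits = torch.randn(M, D).log_softmax(-1)
    insert_seq_logits = torch.randn(M + 1, D).log_softmax(-1)
    insert_logits = torch.randn(M, 3, 2).log_softmax(-1)
    delete_logits = torch.randn(M, 3, 2).log_softmax(-1)
    a0, a, e = arranger(precursor_seq_logits, insert_seq_logits,
                        insert_logits, delete_logits)
    # a0: (2M+1,), a: (2M+1, 2M+1), e: (2M+1, D), normalized in log space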
+ self.statearrange = Profile(latent_seq_length) + + def model(self, data): + + # Latent sequence. + precursor_seq = pyro.sample("precursor_seq", dist.Normal( + torch.zeros(self.precursor_seq_shape), + self.prior_scale * + torch.ones(self.precursor_seq_shape)).to_event(2)) + precursor_seq_logits = precursor_seq - precursor_seq.logsumexp(-1, True) + insert_seq = pyro.sample("insert_seq", dist.Normal( + torch.zeros(self.insert_seq_shape), + self.prior_scale * + torch.ones(self.insert_seq_shape)).to_event(2)) + insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True) + + # Indel probabilities. + insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Construct HMM parameters. + initial_logits, transition_logits, observation_logits = ( + self.statearrange(precursor_seq_logits, insert_seq_logits, + insert_logits, delete_logits)) + # Draw samples. + for i in pyro.plate("data", data.shape[0]): + pyro.sample("obs_{}".format(i), + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=data[i]) + + def guide(self, data): + # Sequence. + precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", + torch.zeros(self.precursor_seq_shape)) + precursor_seq_q_sd = pyro.param("precursor_seq_q_sd", + torch.zeros(self.precursor_seq_shape)) + pyro.sample("precursor_seq", dist.Normal( + precursor_seq_q_mn, softplus(precursor_seq_q_sd)).to_event(2)) + insert_seq_q_mn = pyro.param("insert_seq_q_mn", + torch.zeros(self.insert_seq_shape)) + insert_seq_q_sd = pyro.param("insert_seq_q_sd", + torch.zeros(self.insert_seq_shape)) + pyro.sample("insert_seq", dist.Normal( + insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2)) + + # Indels. 
+ insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + + +class Encoder(nn.Module): + def __init__(self, data_length, alphabet_length, z_dim): + super().__init__() + + self.input_size = data_length * alphabet_length + self.f1_mn = nn.Linear(self.input_size, z_dim) + self.f1_sd = nn.Linear(self.input_size, z_dim) + + def forward(self, data): + + data = data.reshape(-1, self.input_size) + z_loc = self.f1_mn(data) + z_scale = softplus(self.f1_sd(data)) + + return z_loc, z_scale + + +"""class Decoder(nn.Module): + def __init__(self, latent_seq_length, alphabet_length, z_dim): + super().__init__() + + self.latent_seq_length = latent_seq_length + self.alphabet_length = alphabet_length + self.output_size = (2 * latent_seq_length+1) * alphabet_length + self.f = nn.Linear(z_dim, self.output_size) + + def forward(self, z): + + seq = self.f(z) + seq = seq.reshape([-1, 2, self.latent_seq_length+1, + self.alphabet_length]) + return seq""" + + +class FactorMuE(nn.Module): + """Model: pPCA + MuE.""" + def __init__(self, data_length, alphabet_length, z_dim, + batch_scale_factor=1., + latent_seq_length=None, + indel_factor_dependence=False, + indel_prior_scale=1., + indel_prior_bias=10., + inverse_temp_prior=100., + weights_prior_scale=1., + offset_prior_scale=1., + z_prior_distribution='Normal', + ARD_prior=False, + substitution_matrix=True, + substitution_prior_scale=10., + latent_alphabet_length=None, + length_model=False, + epsilon=1e-32): + super().__init__() + + # Constants. + assert isinstance(data_length, int) and data_length > 0 + self.data_length = data_length + if latent_seq_length is None: + latent_seq_length = data_length + else: + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + assert isinstance(z_dim, int) and z_dim > 0 + self.z_dim = z_dim + + # Parameter shapes. + if latent_alphabet_length is None: + latent_alphabet_length = alphabet_length + self.latent_alphabet_length = latent_alphabet_length + # self.seq_shape = (latent_seq_length, latent_alphabet_length) + self.indel_shape = (latent_seq_length, 3, 2) + self.total_factor_size = ( + (2*latent_seq_length+1)*latent_alphabet_length + + 2*indel_factor_dependence*latent_seq_length*3*2 + + length_model) + + # Architecture. + self.indel_factor_dependence = indel_factor_dependence + self.ARD_prior = ARD_prior + self.substitution_matrix = substitution_matrix + self.length_model = length_model + + # Priors. 
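Note: indel_prior packs the bias into a two-vector [indel_prior_bias, 0.] over the (no-indel, indel) simplex dimensions, so after log-normalization the default bias of 10 makes indels rare a priori:

    import torch
    prior = torch.tensor([10., 0.])              # [no-indel, indel] logits
    logits = prior - prior.logsumexp(-1, True)   # normalize in log space
    print(logits.exp())                          # ~[1.0000, 4.5e-05]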
+ assert isinstance(indel_prior_scale, float) + self.indel_prior_scale = torch.tensor(indel_prior_scale) + assert isinstance(indel_prior_bias, float) + self.indel_prior = torch.tensor([indel_prior_bias, 0.]) + assert isinstance(inverse_temp_prior, float) + self.inverse_temp_prior = torch.tensor(inverse_temp_prior) + assert isinstance(weights_prior_scale, float) + self.weights_prior_scale = torch.tensor(weights_prior_scale) + assert isinstance(offset_prior_scale, float) + self.offset_prior_scale = torch.tensor(offset_prior_scale) + assert isinstance(epsilon, float) + self.epsilon = torch.tensor(epsilon) + assert isinstance(substitution_prior_scale, float) + self.substitution_prior_scale = torch.tensor(substitution_prior_scale) + self.z_prior_distribution = z_prior_distribution + + # Batch control. + self.batch_scale_factor = batch_scale_factor + + # Initialize layers. + self.encoder = Encoder(data_length, alphabet_length, z_dim) + # self.decoder = Decoder(latent_seq_length, alphabet_length, z_dim) + self.statearrange = Profile(latent_seq_length) + + def decoder(self, z, W, B, inverse_temp): + + # Project. + v = torch.mm(z, W) + B + + out = dict() + if self.length_model: + # Extract expected length. + v, L_v = v.split([self.total_factor_size-1, 1], dim=1) + out['L_mean'] = softplus(L_v) + if self.indel_factor_dependence: + # Extract insertion and deletion parameters. + v, insert_v, delete_v = v.split([ + (2*self.latent_seq_length+1)*self.latent_alphabet_length, + self.latent_seq_length*3*2, self.latent_seq_length*3*2], dim=1) + insert_v = (insert_v.reshape([-1, self.latent_seq_length, 3, 2]) + + self.indel_prior) + out['insert_logits'] = insert_v - insert_v.logsumexp(-1, True) + delete_v = (delete_v.reshape([-1, self.latent_seq_length, 3, 2]) + + self.indel_prior) + out['delete_logits'] = delete_v - delete_v.logsumexp(-1, True) + # Extraction precursor and insertion sequences. + precursor_seq_v, insert_seq_v = (v*softplus(inverse_temp)).split([ + self.latent_seq_length*self.latent_alphabet_length, + (self.latent_seq_length+1)*self.latent_alphabet_length], dim=1) + precursor_seq_v = precursor_seq_v.reshape([ + -1, self.latent_seq_length, self.latent_alphabet_length]) + out['precursor_seq_logits'] = ( + precursor_seq_v - precursor_seq_v.logsumexp(-1, True)) + insert_seq_v = insert_seq_v.reshape([ + -1, self.latent_seq_length+1, self.latent_alphabet_length]) + out['insert_seq_logits'] = ( + insert_seq_v - insert_seq_v.logsumexp(-1, True)) + + return out + + def model(self, data): + + # pyro.module("decoder", self.decoder) + + # ARD prior. + if self.ARD_prior: + # Relevance factors + alpha = pyro.sample("alpha", dist.Gamma( + torch.ones(self.z_dim), torch.ones(self.z_dim)).to_event(1)) + else: + alpha = torch.ones(self.z_dim) + + # Factor and offset. + W = pyro.sample("W", dist.Normal( + torch.zeros([self.z_dim, self.total_factor_size]), + torch.ones([self.z_dim, self.total_factor_size]) * + self.weights_prior_scale / (alpha[:, None] + self.epsilon) + ).to_event(2)) + B = pyro.sample("B", dist.Normal( + torch.zeros(self.total_factor_size), + torch.ones(self.total_factor_size) * self.offset_prior_scale + ).to_event(1)) + + # Indel probabilities. 
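Note: the global insert/delete variables below are only sampled when indel_factor_dependence is False; with factor dependence on, the equivalent logits come out of the decoder split above instead. That split, standalone (sizes illustrative; the committed decoder also adds indel_prior before normalizing):

    import torch
    M, D_latent, batch = 6, 4, 2
    total = (2*M + 1)*D_latent + 2*M*3*2  # total_factor_size, no length term
    v = torch.randn(batch, total)
    seq_v, insert_v, delete_v = v.split(
        [(2*M + 1)*D_latent, M*3*2, M*3*2], dim=1)
    insert_logits = insert_v.reshape(batch, M, 3, 2).log_softmax(-1)
    delete_logits = delete_v.reshape(batch, M, 3, 2).log_softmax(-1)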
+ if not self.indel_factor_dependence: + insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape) + ).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape) + ).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Inverse temperature. + inverse_temp = pyro.sample("inverse_temp", dist.Normal( + self.inverse_temp_prior, torch.tensor(1.))) + + # Substitution matrix. + if self.substitution_matrix: + substitute = pyro.sample("substitute", dist.Normal( + torch.zeros([ + self.latent_alphabet_length, self.alphabet_length]), + self.substitution_prior_scale * torch.ones([ + self.latent_alphabet_length, self.alphabet_length]) + ).to_event(2)) + + with pyro.plate("batch", data.shape[0]), poutine.scale( + scale=self.batch_scale_factor): + # Sample latent variable from prior. + if self.z_prior_distribution == 'Normal': + z = pyro.sample("latent", dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + z = pyro.sample("latent", dist.Laplace( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + + # Decode latent sequence. + decoded = self.decoder(z, W, B, inverse_temp) + if self.indel_factor_dependence: + insert_logits = decoded['insert_logits'] + delete_logits = decoded['delete_logits'] + + # Construct HMM parameters. + if self.substitution_matrix: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits, + substitute)) + else: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits)) + # Draw samples. + if self.length_model: + data, L = data + pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), obs=L) + pyro.sample("obs", + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=data) + + def guide(self, data): + # Register encoder with pyro. + pyro.module("encoder", self.encoder) + + # ARD weightings. + if self.ARD_prior: + alpha_conc = pyro.param("alpha_conc", torch.randn(self.z_dim)) + alpha_rate = pyro.param("alpha_rate", torch.randn(self.z_dim)) + pyro.sample("alpha", dist.Gamma(softplus(alpha_conc), + softplus(alpha_rate)).to_event(1)) + # Factors. + W_q_mn = pyro.param("W_q_mn", torch.randn([ + self.z_dim, self.total_factor_size])) + W_q_sd = pyro.param("W_q_sd", torch.randn([ + self.z_dim, self.total_factor_size])) + pyro.sample("W", dist.Normal(W_q_mn, softplus(W_q_sd)).to_event(2)) + B_q_mn = pyro.param("B_q_mn", torch.randn(self.total_factor_size)) + B_q_sd = pyro.param("B_q_sd", torch.randn(self.total_factor_size)) + pyro.sample("B", dist.Normal(B_q_mn, softplus(B_q_sd)).to_event(1)) + + # Indel probabilities. 
+ if not self.indel_factor_dependence: + insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + + # Inverse temperature. + inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.)) + inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.)) + pyro.sample("inverse_temp", dist.Normal( + inverse_temp_q_mn, softplus(inverse_temp_q_sd))) + + # Substitution matrix. + if self.substitution_matrix: + substitute_q_mn = pyro.param("substitute_q_mn", torch.zeros( + [self.latent_alphabet_length, self.alphabet_length])) + substitute_q_sd = pyro.param("substitute_q_sd", torch.zeros( + [self.latent_alphabet_length, self.alphabet_length])) + pyro.sample("substitute", dist.Normal( + substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) + + # Per data latent variables. + with pyro.plate("batch", data.shape[0]), poutine.scale( + scale=self.batch_scale_factor): + # Encode seq. + z_loc, z_scale = self.encoder(data) + # Sample. + if self.z_prior_distribution == 'Normal': + pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) + + def reconstruct_precursor_seq(self, data, param): + # Encode seq. + z_loc = self.encoder(data)[0] + # Reconstruct + decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), + param("inverse_temp_q_mn")) + return torch.exp(decoded['precursor_seq_logits']).detach() diff --git a/tests/contrib/mue/test_variablelengthhmm.py b/tests/contrib/mue/test_missingdatahmm.py similarity index 98% rename from tests/contrib/mue/test_variablelengthhmm.py rename to tests/contrib/mue/test_missingdatahmm.py index 96ef2855ee..03f305b9c2 100644 --- a/tests/contrib/mue/test_variablelengthhmm.py +++ b/tests/contrib/mue/test_missingdatahmm.py @@ -2,7 +2,7 @@ import pytest import torch -from pyro.contrib.mue.variablelengthhmm import MissingDataDiscreteHMM +from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM def test_hmm_log_prob(): diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 60953b1cd4..97fddeed93 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,8 +1,6 @@ import pytest import torch -import pdb - from pyro.contrib.mue.statearrangers import mg2k, Profile From c98ece0aaad479d6ea8edee4c70c90bf59c43b02 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 9 Feb 2021 19:59:18 -0500 Subject: [PATCH 22/91] Cleanup. --- pyro/contrib/mue/models.py | 22 ---------------------- tests/contrib/mue/test_models.py | 0 2 files changed, 22 deletions(-) create mode 100644 tests/contrib/mue/test_models.py diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 271fa4ce43..b3be50c701 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -17,8 +17,6 @@ import pyro.poutine as poutine -import pdb - class ProfileHMM(nn.Module): """Model: Constant + MuE. 
""" @@ -128,23 +126,6 @@ def forward(self, data): return z_loc, z_scale -"""class Decoder(nn.Module): - def __init__(self, latent_seq_length, alphabet_length, z_dim): - super().__init__() - - self.latent_seq_length = latent_seq_length - self.alphabet_length = alphabet_length - self.output_size = (2 * latent_seq_length+1) * alphabet_length - self.f = nn.Linear(z_dim, self.output_size) - - def forward(self, z): - - seq = self.f(z) - seq = seq.reshape([-1, 2, self.latent_seq_length+1, - self.alphabet_length]) - return seq""" - - class FactorMuE(nn.Module): """Model: pPCA + MuE.""" def __init__(self, data_length, alphabet_length, z_dim, @@ -217,7 +198,6 @@ def __init__(self, data_length, alphabet_length, z_dim, # Initialize layers. self.encoder = Encoder(data_length, alphabet_length, z_dim) - # self.decoder = Decoder(latent_seq_length, alphabet_length, z_dim) self.statearrange = Profile(latent_seq_length) def decoder(self, z, W, B, inverse_temp): @@ -258,8 +238,6 @@ def decoder(self, z, W, B, inverse_temp): def model(self, data): - # pyro.module("decoder", self.decoder) - # ARD prior. if self.ARD_prior: # Relevance factors diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py new file mode 100644 index 0000000000..e69de29bb2 From fa90df4e93147cae6635e0af3677f477897b22f3 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 11 Feb 2021 15:17:05 -0500 Subject: [PATCH 23/91] Basic dataloaders, and start rearranging training to be part of model. --- pyro/contrib/mue/dataloaders.py | 76 ++++++++++++++++++++++++ pyro/contrib/mue/models.py | 28 +++++---- tests/contrib/mue/test_dataloaders.py | 62 +++++++++++++++++++ tests/contrib/mue/test_missingdatahmm.py | 3 + tests/contrib/mue/test_models.py | 4 ++ tests/contrib/mue/test_seqs.fasta | 7 +++ tests/contrib/mue/test_statearrangers.py | 3 + 7 files changed, 171 insertions(+), 12 deletions(-) create mode 100644 tests/contrib/mue/test_dataloaders.py create mode 100644 tests/contrib/mue/test_seqs.fasta diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index d6960608d6..ba5835e879 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -1,2 +1,78 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import torch +from torch.utils.data import Dataset + + +alphabets = {'amino-acid': np.array( + ['R', 'H', 'K', 'D', 'E', + 'S', 'T', 'N', 'Q', 'C', + 'G', 'P', 'A', 'V', 'I', + 'L', 'M', 'F', 'Y', 'W']), + 'dna': np.array(['A', 'C', 'G', 'T'])} + + +class BiosequenceDataset(Dataset): + """Load biological sequence data.""" + + def __init__(self, source, source_type='list', alphabet='amino-acid'): + + # Get sequences. + if source_type == 'list': + seqs = source + elif source_type == 'fasta': + seqs = self._load_fasta(source) + + # Get lengths. + self.L_data = torch.tensor([len(seq) for seq in seqs]) + self.max_length = torch.max(self.L_data) + self.data_size = len(self.L_data) + + # Get alphabet. + if type(alphabet) is list: + alphabet = np.array(alphabet) + elif alphabet in alphabets: + alphabet = alphabets[alphabet] + else: + assert 'Alphabet unavailable, please provide a list of letters.' + + # Build dataset. 
+ self.seq_data = torch.cat([self._one_hot( + seq, alphabet, self.max_length).unsqueeze(0) for seq in seqs]) + + def _load_fasta(self, source): + """A basic multiline fasta parser.""" + seqs = [] + seq = '' + with open(source, 'r') as fr: + for line in fr: + if line[0] == '>': + if seq != '': + seqs.append(seq) + seq = '' + else: + seq += line.strip('\n') + if seq != '': + seqs.append(seq) + return seqs + + def _one_hot(self, seq, alphabet, length): + """One hot encode and pad with zeros to max length.""" + # One hot encode. + oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] + ).astype(np.float64)) + # Pad. + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) + + return x + + def __len__(self): + + return self.data_size + + def __getitem__(self, ind): + + return (self.seq_data[ind], self.L_data[ind]) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index b3be50c701..5a4a26b0fa 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -15,8 +15,6 @@ from pyro.contrib.mue.statearrangers import Profile from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM -import pyro.poutine as poutine - class ProfileHMM(nn.Module): """Model: Constant + MuE. """ @@ -129,7 +127,7 @@ def forward(self, data): class FactorMuE(nn.Module): """Model: pPCA + MuE.""" def __init__(self, data_length, alphabet_length, z_dim, - batch_scale_factor=1., + batch_size=10, latent_seq_length=None, indel_factor_dependence=False, indel_prior_scale=1., @@ -194,7 +192,8 @@ def __init__(self, data_length, alphabet_length, z_dim, self.z_prior_distribution = z_prior_distribution # Batch control. - self.batch_scale_factor = batch_scale_factor + assert isinstance(batch_size, int) + self.batch_size = batch_size # Initialize layers. self.encoder = Encoder(data_length, alphabet_length, z_dim) @@ -283,8 +282,8 @@ def model(self, data): self.latent_alphabet_length, self.alphabet_length]) ).to_event(2)) - with pyro.plate("batch", data.shape[0]), poutine.scale( - scale=self.batch_scale_factor): + with pyro.plate("batch", len(data), + subsample_size=self.batch_size) as ind: # Sample latent variable from prior. if self.z_prior_distribution == 'Normal': z = pyro.sample("latent", dist.Normal( @@ -314,14 +313,15 @@ def model(self, data): decoded['insert_seq_logits'], insert_logits, delete_logits)) # Draw samples. + L_data_ind, seq_data_ind = data[ind] if self.length_model: - data, L = data - pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), obs=L) - pyro.sample("obs", + pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), + obs=L_data_ind) + pyro.sample("obs_seq", MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits), - obs=data) + obs=seq_data_ind) def guide(self, data): # Register encoder with pyro. @@ -376,8 +376,7 @@ def guide(self, data): substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) # Per data latent variables. - with pyro.plate("batch", data.shape[0]), poutine.scale( - scale=self.batch_scale_factor): + with pyro.plate("batch", data.shape[0], subsample_size=self.batch_size): # Encode seq. z_loc, z_scale = self.encoder(data) # Sample. @@ -386,6 +385,11 @@ def guide(self, data): elif self.z_prior_distribution == 'Laplace': pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) + def fit_svi(self, dataloader, epochs=1, scheduler=None): + """Infer model parameters with stochastic variational inference.""" + + + def reconstruct_precursor_seq(self, data, param): # Encode seq. 
z_loc = self.encoder(data)[0] diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py new file mode 100644 index 0000000000..2320a0accc --- /dev/null +++ b/tests/contrib/mue/test_dataloaders.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pyro.contrib.mue.dataloaders import alphabets, BiosequenceDataset + + +@pytest.mark.parametrize('source_type', ['list', 'fasta']) +@pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', ['A', 'T', 'C']]) +def test_biosequencedataset(source_type, alphabet): + + # Define dataset. + seqs = ['AATC', 'CA', 'T'] + + # Encode dataset, alternate approach. + if type(alphabet) is list: + alphabet_list = alphabet + elif alphabet in alphabets: + alphabet_list = list(alphabets[alphabet]) + L_data_check = [len(seq) for seq in seqs] + max_length_check = max(L_data_check) + data_size_check = len(seqs) + seq_data_check = torch.zeros([len(seqs), max_length_check, + len(alphabet_list)]) + for i in range(len(seqs)): + for j, s in enumerate(seqs[i]): + seq_data_check[i, j, list(alphabet_list).index(s)] = 1 + + # Setup data source. + if source_type == 'fasta': + # Save as external file. + source = 'test_seqs.fasta' + with open(source, 'w') as fw: + text = """>one +AAT +C +>two +CA +>three +T +""" + fw.write(text) + elif source_type == 'list': + source = seqs + + # Load dataset. + dataset = BiosequenceDataset(source, source_type, alphabet) + + # Check. + assert torch.allclose(dataset.L_data, torch.tensor(L_data_check)) + assert dataset.max_length == max_length_check + assert len(dataset) == data_size_check + assert dataset.data_size == data_size_check + assert torch.allclose(dataset.seq_data, seq_data_check) + ind = torch.tensor([0, 2]) + assert torch.allclose(dataset[ind][0], + torch.cat([seq_data_check[0, None, :, :], + seq_data_check[2, None, :, :]])) + assert torch.allclose(dataset[ind][1], torch.tensor([4, 1])) diff --git a/tests/contrib/mue/test_missingdatahmm.py b/tests/contrib/mue/test_missingdatahmm.py index 03f305b9c2..d012e4fd5c 100644 --- a/tests/contrib/mue/test_missingdatahmm.py +++ b/tests/contrib/mue/test_missingdatahmm.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + from pyro.distributions import DiscreteHMM, Categorical import pytest import torch diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index e69de29bb2..3e0be5c556 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -0,0 +1,4 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch diff --git a/tests/contrib/mue/test_seqs.fasta b/tests/contrib/mue/test_seqs.fasta new file mode 100644 index 0000000000..2a90359e45 --- /dev/null +++ b/tests/contrib/mue/test_seqs.fasta @@ -0,0 +1,7 @@ +>one +AAT +C +>two +CA +>three +T diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 97fddeed93..a8c5c3a563 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + import pytest import torch From 0d7ea6c966ef7f04a329ca778666ac1611fdc494 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 11 Feb 2021 16:22:12 -0500 Subject: [PATCH 24/91] subsampling inference provided with model --- examples/contrib/mue/FactorMuE.py | 53 ++++++--------------------- pyro/contrib/mue/dataloaders.py | 3 +- pyro/contrib/mue/models.py | 44 ++++++++++++++++++---- tests/contrib/mue/test_dataloaders.py | 1 + 4 files changed, 52 insertions(+), 49 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index b6fd2a9660..111555b8b7 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -13,9 +13,9 @@ from torch.optim import Adam import pyro +from pyro.contrib.mue.dataloaders import BiosequenceDataset from pyro.contrib.mue.models import FactorMuE -from pyro.infer import SVI, Trace_ELBO from pyro.optim import MultiStepLR @@ -34,7 +34,9 @@ def main(args): mult_step = 400 # Construct example dataset. - xs = [torch.tensor([[0., 1.], + seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat + dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) + """xs = [torch.tensor([[0., 1.], [1., 0.], [0., 1.], [0., 1.], @@ -54,31 +56,20 @@ def main(args): [0., 0.]])] data = torch.cat([xs[0][None, :, :] for j in range(6*mult_dat)] + [xs[1][None, :, :] for j in range(4*mult_dat)] + - [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0) + [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0)""" # Set up inference. - obs_seq_length, alphabet_length, z_dim = 6, 2, 2 - # adam_params = {"lr": 0.1, "betas": (0.90, 0.999)} + z_dim = 2 scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.1}, 'milestones': [20, 100, 1000, 2000], 'gamma': 0.5}) - # optimizer = Adam(adam_params) - model = FactorMuE(obs_seq_length, alphabet_length, z_dim, + model = FactorMuE(dataset.max_length, dataset.alphabet_length, z_dim, substitution_matrix=False) + n_epochs = 10*mult_step + batch_size = len(dataset) - svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO()) - n_steps = 10*mult_step - - # Run inference. - losses = [] - t0 = datetime.datetime.now() - for step in range(n_steps): - - loss = svi.step(data) - losses.append(loss) - scheduler.step() - if step % 10 == 0: - print(step, loss, ' ', datetime.datetime.now() - t0) + # Infer. + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) # Plots. 
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -89,32 +80,12 @@ def main(args): plt.savefig('FactorMuE_plot.loss_{}.pdf'.format(time_stamp)) plt.figure(figsize=(6, 6)) - latent = model.encoder(data)[0].detach() + latent = model.encoder(dataset.seq_data)[0].detach() plt.scatter(latent[:, 0], latent[:, 1]) plt.xlabel('z_1') plt.ylabel('z_2') plt.savefig('FactorMuE_plot.latent_{}.pdf'.format(time_stamp)) - # plt.figure(figsize=(6, 6)) - # decoder_bias = pyro.param('decoder$$$f.bias').detach() - # decoder_bias = decoder_bias.reshape( - # [-1, 2, model.latent_seq_length+1, model.alphabet_length]) - # plt.plot(decoder_bias[0, 0, :, 1]) - # plt.xlabel('position') - # plt.ylabel('bias for character 1') - # plt.savefig('FactorMuE_plot.decoder_bias_{}.pdf'.format(time_stamp)) - - for xi, x in enumerate(xs): - reconstruct_x = model.reconstruct_precursor_seq(x, pyro.param) - plt.figure(figsize=(6, 6)) - plt.plot(reconstruct_x[0, :, 1], label="reconstruct") - plt.plot(x[:, 1], label="data") - plt.xlabel('position') - plt.ylabel('probability of character 1') - plt.legend() - plt.savefig('FactorMuE_plot.reconstruction_{}_{}.pdf'.format( - xi, time_stamp)) - plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index ba5835e879..064fd634fe 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -28,7 +28,7 @@ def __init__(self, source, source_type='list', alphabet='amino-acid'): # Get lengths. self.L_data = torch.tensor([len(seq) for seq in seqs]) - self.max_length = torch.max(self.L_data) + self.max_length = int(torch.max(self.L_data)) self.data_size = len(self.L_data) # Get alphabet. @@ -38,6 +38,7 @@ def __init__(self, source, source_type='list', alphabet='amino-acid'): alphabet = alphabets[alphabet] else: assert 'Alphabet unavailable, please provide a list of letters.' + self.alphabet_length = len(alphabet) # Build dataset. self.seq_data = torch.cat([self._one_hot( diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 5a4a26b0fa..ae57889a10 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -8,6 +8,10 @@ import torch import torch.nn as nn from torch.nn.functional import softplus +from torch.optim import Adam + +import datetime +import numpy as np import pyro import pyro.distributions as dist @@ -15,6 +19,9 @@ from pyro.contrib.mue.statearrangers import Profile from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM +from pyro.infer import SVI, Trace_ELBO +from pyro.optim import MultiStepLR + class ProfileHMM(nn.Module): """Model: Constant + MuE. """ @@ -68,7 +75,7 @@ def model(self, data): self.statearrange(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits)) # Draw samples. - for i in pyro.plate("data", data.shape[0]): + for i in pyro.plate("data", len(data)): pyro.sample("obs_{}".format(i), MissingDataDiscreteHMM(initial_logits, transition_logits, @@ -313,7 +320,7 @@ def model(self, data): decoded['insert_seq_logits'], insert_logits, delete_logits)) # Draw samples. - L_data_ind, seq_data_ind = data[ind] + seq_data_ind, L_data_ind = data[ind] if self.length_model: pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), obs=L_data_ind) @@ -376,19 +383,42 @@ def guide(self, data): substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) # Per data latent variables. 
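Note: pyro.plate with subsample_size both subsamples and rescales: every log-density inside the plate is multiplied by len(data)/subsample_size, which is why the explicit poutine.scale rescaling from the earlier revision could be dropped. The indices it yields are what data[ind] consumes:

    import pyro
    with pyro.plate("batch", 100, subsample_size=10) as ind:
        print(ind.shape)  # torch.Size([10]): random indices into the data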
- with pyro.plate("batch", data.shape[0], subsample_size=self.batch_size): - # Encode seq. - z_loc, z_scale = self.encoder(data) + with pyro.plate("batch", len(data), + subsample_size=self.batch_size) as ind: + # Encode sequences. + z_loc, z_scale = self.encoder(data[ind][0]) # Sample. if self.z_prior_distribution == 'Normal': pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) elif self.z_prior_distribution == 'Laplace': pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) - def fit_svi(self, dataloader, epochs=1, scheduler=None): + def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): """Infer model parameters with stochastic variational inference.""" - + # Setup. + if batch_size is not None: + self.batch_size = batch_size + if scheduler is None: + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.01}, + 'milestones': [], + 'gamma': 0.5}) + n_steps = int(np.ceil(torch.tensor(len(dataset)/self.batch_size)) + )*epochs + svi = SVI(self.model, self.guide, scheduler, loss=Trace_ELBO()) + + # Run inference. + losses = [] + t0 = datetime.datetime.now() + for step in range(n_steps): + loss = svi.step(dataset) + losses.append(loss) + scheduler.step() + if (step + 1) % (n_steps/epochs) == 0: + print(int(epochs*(step+1)/n_steps), loss, ' ', + datetime.datetime.now() - t0) + return losses def reconstruct_precursor_seq(self, data, param): # Encode seq. diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 2320a0accc..9a75e08740 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -54,6 +54,7 @@ def test_biosequencedataset(source_type, alphabet): assert dataset.max_length == max_length_check assert len(dataset) == data_size_check assert dataset.data_size == data_size_check + assert dataset.alphabet_length == len(alphabet_list) assert torch.allclose(dataset.seq_data, seq_data_check) ind = torch.tensor([0, 2]) assert torch.allclose(dataset[ind][0], From 8cb2642503c7e63bd5f8f3b4849eee5f9f983a90 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 11 Feb 2021 20:15:38 -0500 Subject: [PATCH 25/91] FactorMuE test. --- examples/contrib/mue/FactorMuE.py | 22 +------------ pyro/contrib/mue/dataloaders.py | 2 +- pyro/contrib/mue/missingdatahmm.py | 4 ++- pyro/contrib/mue/models.py | 13 ++++---- tests/contrib/mue/test_dataloaders.py | 6 ++-- tests/contrib/mue/test_models.py | 46 +++++++++++++++++++++++++++ 6 files changed, 61 insertions(+), 32 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 111555b8b7..9cba423bbd 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -36,27 +36,7 @@ def main(args): # Construct example dataset. seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) - """xs = [torch.tensor([[0., 1.], - [1., 0.], - [0., 1.], - [0., 1.], - [1., 0.], - [0., 0.]]), - torch.tensor([[0., 1.], - [1., 0.], - [1., 0.], - [0., 1.], - [0., 0.], - [0., 0.]]), - torch.tensor([[0., 1.], - [1., 0.], - [0., 1.], - [0., 1.], - [0., 1.], - [0., 0.]])] - data = torch.cat([xs[0][None, :, :] for j in range(6*mult_dat)] + - [xs[1][None, :, :] for j in range(4*mult_dat)] + - [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0)""" + # Set up inference. 
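Note: pyro.optim.MultiStepLR wraps torch.optim.lr_scheduler.MultiStepLR; the dict names the wrapped optimizer class and its constructor args. Since fit_svi calls scheduler.step() after every SVI step, the milestones count gradient steps, not epochs. The resulting schedule, in miniature:

    lr, gamma, milestones = 0.1, 0.5, [20, 100, 1000, 2000]
    def lr_at(step):
        return lr * gamma**sum(step >= m for m in milestones)
    print(lr_at(0), lr_at(50), lr_at(150))  # 0.1 0.05 0.025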
z_dim = 2 scheduler = MultiStepLR({'optimizer': Adam, diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 064fd634fe..f5169a5d4b 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -27,7 +27,7 @@ def __init__(self, source, source_type='list', alphabet='amino-acid'): seqs = self._load_fasta(source) # Get lengths. - self.L_data = torch.tensor([len(seq) for seq in seqs]) + self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) self.max_length = int(torch.max(self.L_data)) self.data_size = len(self.L_data) diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index bf778d2237..201811bcd1 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -52,9 +52,11 @@ def __init__(self, initial_logits, transition_logits, observation_logits, raise ValueError( "expected observation_logits to have at least two dims, " "actual shape = {}".format(transition_logits.shape)) - shape = broadcast_shape(initial_logits.shape[:-1] + (1,), + shape = broadcast_shape(initial_logits.shape[:-1], transition_logits.shape[:-2], observation_logits.shape[:-2]) + if len(shape) == 0: + shape = torch.Size([1]) batch_shape = shape event_shape = (1, observation_logits.shape[-1]) self.initial_logits = (initial_logits - diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index ae57889a10..ee882a5879 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -22,6 +22,8 @@ from pyro.infer import SVI, Trace_ELBO from pyro.optim import MultiStepLR +import pdb + class ProfileHMM(nn.Module): """Model: Constant + MuE. """ @@ -165,10 +167,9 @@ def __init__(self, data_length, alphabet_length, z_dim, self.z_dim = z_dim # Parameter shapes. - if latent_alphabet_length is None: + if (not substitution_matrix) or (latent_alphabet_length is None): latent_alphabet_length = alphabet_length self.latent_alphabet_length = latent_alphabet_length - # self.seq_shape = (latent_seq_length, latent_alphabet_length) self.indel_shape = (latent_seq_length, 3, 2) self.total_factor_size = ( (2*latent_seq_length+1)*latent_alphabet_length + @@ -215,7 +216,7 @@ def decoder(self, z, W, B, inverse_temp): if self.length_model: # Extract expected length. v, L_v = v.split([self.total_factor_size-1, 1], dim=1) - out['L_mean'] = softplus(L_v) + out['L_mean'] = softplus(L_v).squeeze(1) if self.indel_factor_dependence: # Extract insertion and deletion parameters. v, insert_v, delete_v = v.split([ @@ -227,7 +228,7 @@ def decoder(self, z, W, B, inverse_temp): delete_v = (delete_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior) out['delete_logits'] = delete_v - delete_v.logsumexp(-1, True) - # Extraction precursor and insertion sequences. + # Extract precursor and insertion sequences. precursor_seq_v, insert_seq_v = (v*softplus(inverse_temp)).split([ self.latent_seq_length*self.latent_alphabet_length, (self.latent_seq_length+1)*self.latent_alphabet_length], dim=1) @@ -420,9 +421,9 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): datetime.datetime.now() - t0) return losses - def reconstruct_precursor_seq(self, data, param): + def reconstruct_precursor_seq(self, data, ind, param): # Encode seq. 
- z_loc = self.encoder(data)[0] + z_loc = self.encoder(data[ind][0])[0] # Reconstruct decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), param("inverse_temp_q_mn")) diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 9a75e08740..067598caac 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import numpy as np import pytest import torch @@ -50,7 +49,8 @@ def test_biosequencedataset(source_type, alphabet): dataset = BiosequenceDataset(source, source_type, alphabet) # Check. - assert torch.allclose(dataset.L_data, torch.tensor(L_data_check)) + assert torch.allclose(dataset.L_data, + torch.tensor(L_data_check, dtype=torch.float64)) assert dataset.max_length == max_length_check assert len(dataset) == data_size_check assert dataset.data_size == data_size_check @@ -60,4 +60,4 @@ def test_biosequencedataset(source_type, alphabet): assert torch.allclose(dataset[ind][0], torch.cat([seq_data_check[0, None, :, :], seq_data_check[2, None, :, :]])) - assert torch.allclose(dataset[ind][1], torch.tensor([4, 1])) + assert torch.allclose(dataset[ind][1], torch.tensor([4., 1.])) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 3e0be5c556..152423425f 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -1,4 +1,50 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import pytest + +import numpy as np import torch +from torch.optim import Adam + +import pyro + +from pyro.contrib.mue.dataloaders import BiosequenceDataset +from pyro.contrib.mue.models import ProfileHMM, FactorMuE + +from pyro.optim import MultiStepLR + + +@pytest.mark.parametrize('indel_factor_dependence', [False, True]) +@pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) +@pytest.mark.parametrize('ARD_prior', [False, True]) +@pytest.mark.parametrize('substitution_matrix', [False, True]) +@pytest.mark.parametrize('length_model', [False, True]) +def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, + ARD_prior, substitution_matrix, length_model): + # Setup dataset. + seqs = ['BABBA', 'BAAB', 'BABBB'] + alph = ['A', 'B'] + dataset = BiosequenceDataset(seqs, 'list', alph) + + # Infer. + z_dim = 2 + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.1}, + 'milestones': [20, 100, 1000, 2000], + 'gamma': 0.5}) + model = FactorMuE(dataset.max_length, dataset.alphabet_length, z_dim, + indel_factor_dependence=indel_factor_dependence, + z_prior_distribution=z_prior_distribution, + ARD_prior=ARD_prior, + substitution_matrix=substitution_matrix, + length_model=length_model) + n_epochs = 5 + batch_size = 2 + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) + + # Reconstruct. + recon = model.reconstruct_precursor_seq(dataset, 1, pyro.param) + + assert not np.isnan(losses[-1]) + assert recon.shape == (1, max([len(seq) for seq in seqs]), len(alph)) From 7977022f425180d026880092d7bdeb8df8ff8431 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 10:50:13 -0500 Subject: [PATCH 26/91] Inference, length model, tests for profile hmm. 
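
The new ProfileHMM.fit_svi method mirrors FactorMuE.fit_svi: when no
scheduler is given it falls back to a MultiStepLR-wrapped Adam optimizer,
builds an SVI object with a Trace_ELBO loss, and runs the training loop
with minibatch subsampling inside the model's "batch" plate. A minimal
usage sketch, mirroring the smoke test added in this patch (the toy
dataset and the learning-rate schedule are illustrative values taken from
the test, not recommended settings):

    import torch
    from torch.optim import Adam

    from pyro.contrib.mue.dataloaders import BiosequenceDataset
    from pyro.contrib.mue.models import ProfileHMM
    from pyro.optim import MultiStepLR

    # Double precision, as in the example scripts.
    torch.set_default_dtype(torch.float64)

    # Toy two-letter dataset, as in the smoke tests.
    dataset = BiosequenceDataset(['BABBA', 'BAAB', 'BABBB'], 'list',
                                 ['A', 'B'])

    # Illustrative learning-rate schedule.
    scheduler = MultiStepLR({'optimizer': Adam,
                             'optim_args': {'lr': 0.1},
                             'milestones': [20, 100],
                             'gamma': 0.5})

    # length_model=True enables the new Poisson sequence-length model.
    model = ProfileHMM(dataset.max_length, dataset.alphabet_length,
                       length_model=True)
    # Pass batch_size explicitly; the constructor does not set one.
    losses = model.fit_svi(dataset, epochs=5, batch_size=2,
                           scheduler=scheduler)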
--- pyro/contrib/mue/models.py | 57 +++++++++++++++++++++++++++++--- tests/contrib/mue/test_models.py | 21 ++++++++++++ 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index ee882a5879..a27ef68e93 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -28,7 +28,7 @@ class ProfileHMM(nn.Module): """Model: Constant + MuE. """ def __init__(self, latent_seq_length, alphabet_length, - prior_scale=1., indel_prior_strength=10.): + length_model=False, prior_scale=1., indel_prior_strength=10.): super().__init__() assert isinstance(latent_seq_length, int) and latent_seq_length > 0 @@ -40,6 +40,8 @@ def __init__(self, latent_seq_length, alphabet_length, self.insert_seq_shape = (latent_seq_length+1, alphabet_length) self.indel_shape = (latent_seq_length, 3, 2) + assert isinstance(length_model, bool) + self.length_model = length_model assert isinstance(prior_scale, float) self.prior_scale = prior_scale assert isinstance(indel_prior_strength, float) @@ -76,13 +78,26 @@ def model(self, data): initial_logits, transition_logits, observation_logits = ( self.statearrange(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits)) + + # Length model. + if self.length_model: + length = pyro.sample("length", dist.Normal( + torch.tensor(200.), torch.tensor(1000.))) + L_mean = softplus(length) + # Draw samples. - for i in pyro.plate("data", len(data)): - pyro.sample("obs_{}".format(i), + with pyro.plate("batch", len(data), + subsample_size=self.batch_size) as ind: + + seq_data_ind, L_data_ind = data[ind] + if self.length_model: + pyro.sample("obs_L", dist.Poisson(L_mean), + obs=L_data_ind) + pyro.sample("obs_seq", MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits), - obs=data[i]) + obs=seq_data_ind) def guide(self, data): # Sequence. @@ -115,6 +130,40 @@ def guide(self, data): pyro.sample("delete", dist.Normal( delete_q_mn, softplus(delete_q_sd)).to_event(3)) + # Length. + if self.length_model: + length_q_mn = pyro.param("length_q_mn", torch.zeros(1)) + length_q_sd = pyro.param("length_q_sd", torch.zeros(1)) + pyro.sample("length", dist.Normal( + length_q_mn, softplus(length_q_sd))) + + def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): + """Infer model parameters with stochastic variational inference.""" + + # Setup. + if batch_size is not None: + self.batch_size = batch_size + if scheduler is None: + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.01}, + 'milestones': [], + 'gamma': 0.5}) + n_steps = int(np.ceil(torch.tensor(len(dataset)/self.batch_size)) + )*epochs + svi = SVI(self.model, self.guide, scheduler, loss=Trace_ELBO()) + + # Run inference. + losses = [] + t0 = datetime.datetime.now() + for step in range(n_steps): + loss = svi.step(dataset) + losses.append(loss) + scheduler.step() + if (step + 1) % (n_steps/epochs) == 0: + print(int(epochs*(step+1)/n_steps), loss, ' ', + datetime.datetime.now() - t0) + return losses + class Encoder(nn.Module): def __init__(self, data_length, alphabet_length, z_dim): diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 152423425f..189d615a4b 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -15,6 +15,27 @@ from pyro.optim import MultiStepLR +@pytest.mark.parametrize('length_model', [False, True]) +def test_ProfileHMM_smoke(length_model): + # Setup dataset. 
+ seqs = ['BABBA', 'BAAB', 'BABBB'] + alph = ['A', 'B'] + dataset = BiosequenceDataset(seqs, 'list', alph) + + # Infer. + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.1}, + 'milestones': [20, 100, 1000, 2000], + 'gamma': 0.5}) + model = ProfileHMM(dataset.max_length, dataset.alphabet_length, + length_model) + n_epochs = 5 + batch_size = 2 + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) + + assert not np.isnan(losses[-1]) + + @pytest.mark.parametrize('indel_factor_dependence', [False, True]) @pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) @pytest.mark.parametrize('ARD_prior', [False, True]) From 5d085cf46a7e6b9c861a3dc94780982baadbaa5c Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 16:56:25 -0500 Subject: [PATCH 27/91] FactorMuE example full input options. --- examples/contrib/mue/FactorMuE.py | 197 ++++++++++++++++++++++-------- 1 file changed, 143 insertions(+), 54 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 9cba423bbd..395362d56f 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -8,6 +8,8 @@ import argparse import datetime import matplotlib.pyplot as plt +import json +import os import torch from torch.optim import Adam @@ -19,72 +21,159 @@ from pyro.optim import MultiStepLR -def main(args): - - torch.manual_seed(9) - torch.set_default_tensor_type('torch.DoubleTensor') - - small_test = args.test - +def generate_data(small_test): + """Generate example dataset.""" if small_test: mult_dat = 1 - mult_step = 1 else: mult_dat = 10 - mult_step = 400 - # Construct example dataset. seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) - # Set up inference. - z_dim = 2 - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.1}, - 'milestones': [20, 100, 1000, 2000], - 'gamma': 0.5}) - model = FactorMuE(dataset.max_length, dataset.alphabet_length, z_dim, - substitution_matrix=False) - n_epochs = 10*mult_step - batch_size = len(dataset) + return dataset + + +def main(args): + + pyro.set_rng_seed(args.rng_seed) + + # Construct example dataset. + if args.test: + dataset = generate_data(args.small) + else: + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + + # Construct model. + model = FactorMuE(dataset.max_length, dataset.alphabet_length, args.z_dim, + batch_size=args.batch_size, + latent_seq_length=args.latent_seq_length, + indel_factor_dependence=args.indel_factor, + indel_prior_scale=args.indel_prior_scale, + indel_prior_bias=args.indel_prior_bias, + inverse_temp_prior=args.inverse_temp_prior, + weights_prior_scale=args.weights_prior_scale, + offset_prior_scale=args.offset_prior_scale, + z_prior_distribution=args.z_prior, + ARD_prior=args.ARD_prior, + substitution_matrix=args.substitution_matrix, + substitution_prior_scale=args.substitution_prior_scale, + latent_alphabet_length=args.latent_alphabet, + length_model=args.length_model) # Infer. - losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) - - # Plots. 
- time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - plt.figure(figsize=(6, 6)) - plt.plot(losses) - plt.xlabel('step') - plt.ylabel('loss') - plt.savefig('FactorMuE_plot.loss_{}.pdf'.format(time_stamp)) - - plt.figure(figsize=(6, 6)) - latent = model.encoder(dataset.seq_data)[0].detach() - plt.scatter(latent[:, 0], latent[:, 1]) - plt.xlabel('z_1') - plt.ylabel('z_2') - plt.savefig('FactorMuE_plot.latent_{}.pdf'.format(time_stamp)) - - plt.figure(figsize=(6, 6)) - insert = pyro.param("insert_q_mn").detach() - insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) - plt.plot(insert_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of insert') - plt.savefig('FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp)) - plt.figure(figsize=(6, 6)) - delete = pyro.param("delete_q_mn").detach() - delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) - plt.plot(delete_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of delete') - plt.savefig('FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp)) + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': args.learning_rate}, + 'milestones': json.loads(args.milestones), + 'gamma': args.learning_gamma}) + if args.test and not args.small: + n_epochs = 100 + else: + n_epochs = args.n_epochs + losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler) + + # Plot and save. + if args.plots: + time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) + + plt.figure(figsize=(6, 6)) + latent = model.encoder(dataset.seq_data)[0].detach() + plt.scatter(latent[:, 0], latent[:, 1]) + plt.xlabel('z_1') + plt.ylabel('z_2') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) + + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Basic Factor MuE model.") - parser.add_argument('-t', '--test', action='store_true', default=False, - help='small dataset, a few steps') + parser = argparse.ArgumentParser(description="Factor MuE model.") + parser.add_argument("--test", action='store_true', default=False, + help='Run with generated example dataset.') + parser.add_argument("--small", action='store_true', default=False, + help='Run with small example dataset.') + parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("-f", "--file", default=None, + help='Input file (fasta format).') + parser.add_argument("-a", "--alphabet", default='amino-acid', + help='Alphabet (amino-acid OR dna).') + parser.add_argument("-zdim", "--z-dim", 
default=2, type=int, + help='z space dimension.') + parser.add_argument("-b", "--batch-size", default=10, type=int, + help='Batch size.') + parser.add_argument("-M", "--latent-seq-length", default=None, + help='Latent sequence length.') + parser.add_argument("-idfac", "--indel-factor", default=False, type=bool, + help='Indel parameters depend on latent variable.') + parser.add_argument("-zdist", "--z-prior", default='Normal', + help='Latent prior distribution (normal or Laplace).') + parser.add_argument("-ard", "--ARD-prior", default=False, type=bool, + help='Use automatic relevance detection prior.') + parser.add_argument("-sub", "--substitution-matrix", default=True, type=bool, + help='Use substitution matrix.') + parser.add_argument("-D", "--latent-alphabet", default=None, + help='Latent alphabet length.') + parser.add_argument("-L", "--length-model", default=False, type=bool, + help='Model sequence length.') + parser.add_argument("--indel-prior-scale", default=1., type=float, + help=('Indel prior scale parameter ' + + '(when indel-factor=False).')) + parser.add_argument("--indel-prior-bias", default=10., type=float, + help='Indel prior bias parameter.') + parser.add_argument("--inverse-temp-prior", default=100., type=float, + help='Inverse temperature prior mean.') + parser.add_argument("--weights-prior-scale", default=1., type=float, + help='Factor parameter prior scale.') + parser.add_argument("--offset-prior-scale", default=1., type=float, + help='Offset parameter prior scale.') + parser.add_argument("--substitution-prior-scale", default=10., type=float, + help='Substitution matrix prior scale.') + parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, + help='Learning rate for Adam optimizer.') + parser.add_argument("--milestones", default='[]', type=str, + help='Milestones for multistage learning rate.') + parser.add_argument("--learning-gamma", default=0.5, type=float, + help='Gamma parameter for multistage learning rate.') + parser.add_argument("-e", "--n-epochs", default=10, type=int, + help='Number of epochs of training.') + parser.add_argument("-p", "--plots", default=True, type=bool, + help='Make plots.') + parser.add_argument("-s", "--save", default=True, type=bool, + help='Save plots and results.') + parser.add_argument("-outf", "--out-folder", default='.', + help='Folder to save plots.') args = parser.parse_args() + + torch.set_default_dtype(torch.float64) + main(args) From 921eebedf84c4d501d32092755711517b98bc38e Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 19:30:19 -0500 Subject: [PATCH 28/91] Debug saving. --- examples/contrib/mue/FactorMuE.py | 53 +++++++++++++++++++------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 395362d56f..89e4fd26d1 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -43,6 +43,7 @@ def main(args): dataset = generate_data(args.small) else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + args.batch_size = min([dataset.data_size, args.batch_size]) # Construct model. 
model = FactorMuE(dataset.max_length, dataset.alphabet_length, args.z_dim, @@ -94,26 +95,38 @@ def main(args): args.out_folder, 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) - plt.figure(figsize=(6, 6)) - insert = pyro.param("insert_q_mn").detach() - insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) - plt.plot(insert_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of insert') - if args.save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) - plt.figure(figsize=(6, 6)) - delete = pyro.param("delete_q_mn").detach() - delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) - plt.plot(delete_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of delete') - if args.save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) + if not args.indel_factor: + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) + if args.save: + pyro.get_param_store().save(os.path.join( + args.out_folder, + 'FactorMuE_results.params_{}.out'.format(time_stamp))) + with open(os.path.join( + args.out_folder, + 'FactorMuE_results.input_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('[args]\n') + for elem in list(args.__dict__.keys()): + ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) if __name__ == '__main__': From f07ce6d3d535d777920163c22b7eedd8ad44b011 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 20:05:26 -0500 Subject: [PATCH 29/91] Clean up --- examples/contrib/mue/FactorMuE.py | 7 +++---- examples/contrib/mue/{phmm.py => ProfileHMM.py} | 7 +++---- pyro/contrib/mue/__init__.py | 7 ------- pyro/contrib/mue/dataloaders.py | 2 -- pyro/contrib/mue/models.py | 12 ++++-------- pyro/contrib/mue/statearrangers.py | 1 - tests/contrib/mue/test_dataloaders.py | 2 +- tests/contrib/mue/test_missingdatahmm.py | 2 +- tests/contrib/mue/test_models.py | 8 ++------ tests/contrib/mue/test_statearrangers.py | 2 +- 10 files changed, 15 insertions(+), 35 deletions(-) rename examples/contrib/mue/{phmm.py => ProfileHMM.py} (98%) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 89e4fd26d1..91d463ef8a 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -2,22 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 """ -A PCA model with a MuE emission (FactorMuE). Uses the MuE package. +A PCA model with a MuE emission (FactorMuE). 
""" import argparse import datetime -import matplotlib.pyplot as plt import json import os +import matplotlib.pyplot as plt import torch from torch.optim import Adam -import pyro +import pyro from pyro.contrib.mue.dataloaders import BiosequenceDataset from pyro.contrib.mue.models import FactorMuE - from pyro.optim import MultiStepLR diff --git a/examples/contrib/mue/phmm.py b/examples/contrib/mue/ProfileHMM.py similarity index 98% rename from examples/contrib/mue/phmm.py rename to examples/contrib/mue/ProfileHMM.py index ef45d13f13..2b59316ae9 100644 --- a/examples/contrib/mue/phmm.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -2,18 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 """ -A standard profile HMM model example, using the MuE package. +A standard profile HMM model. """ import argparse import datetime -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import torch -import pyro +import pyro from pyro.contrib.mue.models import ProfileHMM - from pyro.infer import SVI, Trace_ELBO from pyro.optim import Adam diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index b01f59866b..d6960608d6 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,9 +1,2 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.contrib.mue.statearrangers import Profile -from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM - -__all__ = [ - "Profile" - "MissingDataDiscreteHMM" -] diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index f5169a5d4b..de03fccfb1 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -2,11 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np - import torch from torch.utils.data import Dataset - alphabets = {'amino-acid': np.array( ['R', 'H', 'K', 'D', 'E', 'S', 'T', 'N', 'Q', 'C', diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index a27ef68e93..a4cded1e2c 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -5,25 +5,21 @@ Example MuE observation models. """ +import datetime + +import numpy as np import torch import torch.nn as nn from torch.nn.functional import softplus from torch.optim import Adam -import datetime -import numpy as np - import pyro import pyro.distributions as dist - -from pyro.contrib.mue.statearrangers import Profile from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM - +from pyro.contrib.mue.statearrangers import Profile from pyro.infer import SVI, Trace_ELBO from pyro.optim import MultiStepLR -import pdb - class ProfileHMM(nn.Module): """Model: Constant + MuE. 
""" diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 2f1b799ff7..c8511b3585 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch import torch.nn as nn -import pdb class Profile(nn.Module): diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 067598caac..887b63c490 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -4,7 +4,7 @@ import pytest import torch -from pyro.contrib.mue.dataloaders import alphabets, BiosequenceDataset +from pyro.contrib.mue.dataloaders import BiosequenceDataset, alphabets @pytest.mark.parametrize('source_type', ['list', 'fasta']) diff --git a/tests/contrib/mue/test_missingdatahmm.py b/tests/contrib/mue/test_missingdatahmm.py index d012e4fd5c..ee19f4b31d 100644 --- a/tests/contrib/mue/test_missingdatahmm.py +++ b/tests/contrib/mue/test_missingdatahmm.py @@ -1,11 +1,11 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.distributions import DiscreteHMM, Categorical import pytest import torch from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM +from pyro.distributions import Categorical, DiscreteHMM def test_hmm_log_prob(): diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 189d615a4b..5f716a9c33 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -1,17 +1,13 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import pytest - import numpy as np -import torch +import pytest from torch.optim import Adam import pyro - from pyro.contrib.mue.dataloaders import BiosequenceDataset -from pyro.contrib.mue.models import ProfileHMM, FactorMuE - +from pyro.contrib.mue.models import FactorMuE, ProfileHMM from pyro.optim import MultiStepLR diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index a8c5c3a563..7215edc9c8 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -4,7 +4,7 @@ import pytest import torch -from pyro.contrib.mue.statearrangers import mg2k, Profile +from pyro.contrib.mue.statearrangers import Profile, mg2k def simpleprod(lst): From f9823166b79a7d8ed868b759fbcf0790ac2d238a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 20:05:43 -0500 Subject: [PATCH 30/91] Add FactorMuE to test_examples.py --- tests/test_examples.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_examples.py b/tests/test_examples.py index b0d0cb96c8..99cb2370df 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,6 +53,8 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', + 'contrib/mue/FactorMuE.py --test --small', + 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', From 42e97bcf1b5381fbdb07f19c813533d586572384 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 19 Feb 2021 20:06:47 -0500 Subject: [PATCH 31/91] Files autochanged by make format. 
--- docs/source/conf.py | 1 - docs/source/contrib.mue.rst | 9 +++- examples/air/air.py | 2 +- examples/air/main.py | 4 +- examples/capture_recapture/cjs.py | 3 +- examples/contrib/autoname/scoping_mixture.py | 7 ++- examples/contrib/funsor/hmm.py | 1 - examples/contrib/gp/sv-dkl.py | 2 +- examples/contrib/oed/ab_test.py | 14 +++-- examples/contrib/oed/gp_bayes_opt.py | 2 +- examples/contrib/timeseries/gp_models.py | 8 +-- examples/cvae/baseline.py | 5 +- examples/cvae/cvae.py | 10 ++-- examples/cvae/main.py | 10 ++-- examples/cvae/mnist.py | 2 +- examples/cvae/util.py | 12 +++-- examples/eight_schools/data.py | 1 - examples/eight_schools/mcmc.py | 2 +- examples/eight_schools/svi.py | 2 +- examples/hmm.py | 2 +- examples/lkj.py | 3 +- examples/minipyro.py | 2 +- examples/mixed_hmm/experiment.py | 9 ++-- examples/mixed_hmm/seal_data.py | 2 - examples/rsa/generics.py | 9 ++-- examples/rsa/hyperbole.py | 9 ++-- examples/rsa/schelling.py | 3 +- examples/rsa/schelling_false.py | 3 +- examples/rsa/search_inference.py | 4 +- examples/rsa/semantic_parsing.py | 7 ++- examples/scanvi/data.py | 6 +-- examples/scanvi/scanvi.py | 14 +++-- examples/sparse_gamma_def.py | 11 ++-- examples/sparse_regression.py | 9 ++-- examples/vae/ss_vae_M2.py | 6 +-- examples/vae/vae.py | 6 +-- examples/vae/vae_comparison.py | 2 +- profiler/hmm.py | 3 +- profiler/profiling_utils.py | 2 +- pyro/contrib/__init__.py | 1 + pyro/contrib/autoname/__init__.py | 3 +- pyro/contrib/bnn/hidden_layer.py | 2 +- pyro/contrib/bnn/utils.py | 3 +- pyro/contrib/conjugate/infer.py | 2 +- pyro/contrib/easyguide/__init__.py | 1 - pyro/contrib/easyguide/easyguide.py | 2 +- pyro/contrib/examples/bart.py | 2 +- pyro/contrib/examples/finance.py | 2 +- .../examples/polyphonic_data_loader.py | 3 +- pyro/contrib/forecast/util.py | 2 +- pyro/contrib/funsor/__init__.py | 12 ++--- pyro/contrib/funsor/handlers/__init__.py | 8 +-- .../contrib/funsor/handlers/enum_messenger.py | 9 ++-- .../funsor/handlers/named_messenger.py | 3 +- .../funsor/handlers/plate_messenger.py | 8 ++- pyro/contrib/funsor/handlers/primitives.py | 1 - .../funsor/handlers/replay_messenger.py | 2 +- .../funsor/handlers/trace_messenger.py | 5 +- pyro/contrib/funsor/infer/__init__.py | 2 +- pyro/contrib/funsor/infer/trace_elbo.py | 7 ++- pyro/contrib/funsor/infer/traceenum_elbo.py | 5 +- pyro/contrib/funsor/infer/tracetmc_elbo.py | 6 +-- pyro/contrib/gp/kernels/__init__.py | 7 ++- pyro/contrib/gp/likelihoods/binary.py | 1 - pyro/contrib/gp/likelihoods/gaussian.py | 1 - pyro/contrib/gp/likelihoods/multi_class.py | 1 - pyro/contrib/gp/likelihoods/poisson.py | 1 - pyro/contrib/oed/__init__.py | 2 +- pyro/contrib/oed/eig.py | 9 ++-- pyro/contrib/oed/glmm/__init__.py | 2 +- pyro/contrib/oed/glmm/glmm.py | 6 +-- pyro/contrib/oed/glmm/guides.py | 8 ++- pyro/contrib/oed/search.py | 3 +- pyro/contrib/oed/util.py | 3 +- .../contrib/randomvariable/random_variable.py | 13 ++--- pyro/contrib/timeseries/__init__.py | 2 +- pyro/contrib/tracking/distributions.py | 2 +- pyro/contrib/tracking/dynamic_models.py | 1 + pyro/contrib/tracking/measurements.py | 1 + pyro/contrib/util.py | 2 + pyro/distributions/__init__.py | 51 ++++--------------- pyro/distributions/ordered_logistic.py | 1 + pyro/distributions/spanning_tree.py | 1 + pyro/distributions/transforms/__init__.py | 5 +- pyro/distributions/transforms/ordered.py | 3 +- pyro/infer/__init__.py | 2 +- pyro/infer/abstract_infer.py | 2 +- pyro/infer/autoguide/initialization.py | 1 - pyro/infer/mcmc/__init__.py | 2 +- pyro/infer/mcmc/adaptation.py | 2 +- 
pyro/infer/mcmc/hmc.py | 3 +- pyro/infer/mcmc/util.py | 4 +- pyro/infer/predictive.py | 2 +- pyro/infer/reparam/neutra.py | 1 + pyro/infer/svgd.py | 6 +-- pyro/infer/trace_mean_field_elbo.py | 2 +- pyro/infer/trace_mmd.py | 2 +- pyro/infer/trace_tail_adaptive_elbo.py | 2 +- pyro/infer/traceenum_elbo.py | 2 +- pyro/infer/tracegraph_elbo.py | 3 +- pyro/infer/tracetmc_elbo.py | 1 - pyro/logger.py | 1 - pyro/ops/arrowhead.py | 1 - pyro/ops/einsum/torch_map.py | 1 - pyro/ops/einsum/torch_sample.py | 1 - pyro/ops/newton.py | 2 +- pyro/ops/ssm_gp.py | 2 +- pyro/ops/stats.py | 2 +- pyro/poutine/broadcast_messenger.py | 1 + pyro/poutine/indep_messenger.py | 1 + pyro/poutine/trace_struct.py | 2 +- tests/__init__.py | 1 - tests/conftest.py | 1 - tests/contrib/autoguide/test_inference.py | 4 +- .../autoguide/test_mean_field_entropy.py | 4 +- tests/contrib/autoname/test_scoping.py | 2 +- tests/contrib/bnn/test_hidden_layer.py | 2 +- tests/contrib/epidemiology/test_quant.py | 1 - tests/contrib/funsor/test_enum_funsor.py | 4 +- tests/contrib/funsor/test_named_handlers.py | 3 +- tests/contrib/funsor/test_pyroapi_funsor.py | 2 +- tests/contrib/funsor/test_tmc.py | 3 +- .../contrib/funsor/test_valid_models_enum.py | 5 +- .../contrib/funsor/test_valid_models_plate.py | 3 +- .../test_valid_models_sequential_plate.py | 3 +- .../contrib/funsor/test_vectorized_markov.py | 6 +-- tests/contrib/gp/test_kernels.py | 5 +- tests/contrib/gp/test_likelihoods.py | 1 - tests/contrib/gp/test_models.py | 7 ++- tests/contrib/oed/test_ewma.py | 2 +- tests/contrib/oed/test_finite_spaces_eig.py | 8 ++- tests/contrib/oed/test_glmm.py | 6 +-- tests/contrib/oed/test_linear_models_eig.py | 11 ++-- .../randomvariable/test_random_variable.py | 1 + tests/contrib/test_util.py | 5 +- tests/contrib/timeseries/test_gp.py | 9 ++-- tests/contrib/timeseries/test_lgssm.py | 4 +- tests/contrib/tracking/test_assignment.py | 4 +- tests/contrib/tracking/test_distributions.py | 3 +- tests/contrib/tracking/test_dynamic_models.py | 3 +- tests/contrib/tracking/test_ekf.py | 3 +- tests/contrib/tracking/test_em.py | 1 - tests/contrib/tracking/test_measurements.py | 1 + tests/distributions/test_empirical.py | 2 +- tests/distributions/test_gaussian_mixtures.py | 6 +-- tests/distributions/test_haar.py | 2 +- tests/distributions/test_ig.py | 2 +- tests/distributions/test_mask.py | 3 +- tests/distributions/test_mvt.py | 1 - tests/distributions/test_omt_mvn.py | 2 +- tests/distributions/test_ordered_logistic.py | 3 +- tests/distributions/test_pickle.py | 2 +- tests/distributions/test_transforms.py | 5 +- tests/doctest_fixtures.py | 5 +- tests/infer/mcmc/test_adaptation.py | 7 +-- tests/infer/mcmc/test_hmc.py | 4 +- tests/infer/mcmc/test_mcmc_api.py | 2 +- tests/infer/mcmc/test_nuts.py | 8 +-- tests/infer/test_abstract_infer.py | 3 +- tests/infer/test_autoguide.py | 5 +- tests/infer/test_predictive.py | 2 +- tests/infer/test_svgd.py | 6 +-- tests/infer/test_tmc.py | 3 +- tests/infer/test_util.py | 2 +- tests/ops/test_arrowhead.py | 1 - tests/ops/test_gamma_gaussian.py | 8 +-- tests/ops/test_newton.py | 1 - tests/perf/test_benchmark.py | 2 +- tests/poutine/test_nesting.py | 3 +- tests/poutine/test_poutines.py | 4 +- tests/poutine/test_trace_struct.py | 1 - tests/pyroapi/test_pyroapi.py | 1 - tests/test_generic.py | 4 +- tests/test_primitives.py | 3 +- tests/test_util.py | 3 +- 175 files changed, 302 insertions(+), 399 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 250af8e937..b66e2c7f46 100644 --- a/docs/source/conf.py +++ 
b/docs/source/conf.py @@ -6,7 +6,6 @@ import sphinx_rtd_theme - # import pkg_resources # -*- coding: utf-8 -*- diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 335bcd8965..9fd7e10254 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -14,6 +14,13 @@ Reference: MuE models were described in Weinstein and Marks (2020), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. +Example MuE Models +------------------ +.. automodule:: pyro.contrib.mue.models + :members: + :show-inheritance: + :member-order: bysource + State Arrangers for Parameterizing MuEs --------------------------------------- .. automodule:: pyro.contrib.mue.statearrangers @@ -23,7 +30,7 @@ State Arrangers for Parameterizing MuEs Variable Length/Missing Data HMM -------------------------------- -.. automodule:: pyro.contrib.mue.variablelengthhmm +.. automodule:: pyro.contrib.mue.missingdatahmm :members: :show-inheritance: :member-order: bysource diff --git a/examples/air/air.py b/examples/air/air.py index a9b8958d0f..985e2853bf 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -14,10 +14,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import MLP, Decoder, Encoder, Identity, Predict import pyro import pyro.distributions as dist -from modules import MLP, Decoder, Encoder, Identity, Predict # Default prior success probability for z_pres. diff --git a/examples/air/main.py b/examples/air/main.py index 8cf13b255e..34be8303b6 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -19,15 +19,15 @@ import numpy as np import torch import visdom +from air import AIR, latents_to_tensor +from viz import draw_many, tensor_to_objs import pyro import pyro.contrib.examples.multi_mnist as multi_mnist import pyro.optim as optim import pyro.poutine as poutine -from air import AIR, latents_to_tensor from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO -from viz import draw_many, tensor_to_objs def count_accuracy(X, true_counts, air, batch_size): diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index 65b709afca..fa868899d5 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -39,11 +39,10 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.infer.autoguide import AutoDiagonalNormal from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO +from pyro.infer.autoguide import AutoDiagonalNormal from pyro.optim import Adam - """ Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival probability phi; diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 842e9f03c0..1d4adb8ec8 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -2,16 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import argparse + import torch from torch.distributions import constraints import pyro -import pyro.optim import pyro.distributions as dist - -from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO - +import pyro.optim from pyro.contrib.autoname import scope +from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate def model(K, data): diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 5c860d1470..32f7331e77 100644 --- a/examples/contrib/funsor/hmm.py +++ 
b/examples/contrib/funsor/hmm.py @@ -64,7 +64,6 @@ from pyroapi import distributions as dist from pyroapi import handlers, infer, optim, pyro, pyro_backend - logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.DEBUG) # Add another handler for logging debugging events (e.g. for profiling) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 338ad6a19f..616cc297d0 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -39,7 +39,7 @@ import pyro import pyro.contrib.gp as gp import pyro.infer as infer -from pyro.contrib.examples.util import get_data_loader, get_data_directory +from pyro.contrib.examples.util import get_data_directory, get_data_loader class CNN(nn.Module): diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index 16842b7c05..b713d28306 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -3,20 +3,18 @@ import argparse from functools import partial + +import numpy as np import torch +from gp_bayes_opt import GPBayesOptimizer from torch.distributions import constraints -import numpy as np import pyro +import pyro.contrib.gp as gp from pyro import optim -from pyro.infer import TraceEnum_ELBO from pyro.contrib.oed.eig import vi_eig -import pyro.contrib.gp as gp -from pyro.contrib.oed.glmm import ( - zero_mean_unit_obs_sd_lm, group_assignment_matrix, analytic_posterior_cov -) - -from gp_bayes_opt import GPBayesOptimizer +from pyro.contrib.oed.glmm import analytic_posterior_cov, group_assignment_matrix, zero_mean_unit_obs_sd_lm +from pyro.infer import TraceEnum_ELBO """ Example builds on the Bayesian regression tutorial [1]. It demonstrates how diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index 3c114c9bcf..6132dee48a 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -7,8 +7,8 @@ from torch.distributions import transform_to import pyro.contrib.gp as gp -from pyro.infer import TraceEnum_ELBO import pyro.optim +from pyro.infer import TraceEnum_ELBO class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer): diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index 81d2e1316e..2ca6e03bc4 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -1,16 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import argparse +from os.path import exists +from urllib.request import urlopen + import numpy as np import torch import pyro from pyro.contrib.timeseries import IndependentMaternGP, LinearlyCoupledMaternGP -import argparse -from os.path import exists -from urllib.request import urlopen - # download dataset from UCI archive def download_data(): diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index cb5d279445..23e1591016 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import numpy as np from pathlib import Path -from tqdm import tqdm + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from tqdm import tqdm class BaselineNet(nn.Module): diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py index fb792cf4d3..f499aaf452 100644 --- a/examples/cvae/cvae.py +++ b/examples/cvae/cvae.py @@ -1,15 +1,17 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import numpy as np from pathlib import Path -import pyro -import pyro.distributions as dist -from pyro.infer import SVI, Trace_ELBO + +import numpy as np import torch import torch.nn as nn from tqdm import tqdm +import pyro +import pyro.distributions as dist +from pyro.infer import SVI, Trace_ELBO + class Encoder(nn.Module): def __init__(self, z_dim, hidden_1, hidden_2): diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 224b4f05af..dea506db5a 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import pandas as pd -import pyro -import torch + import baseline import cvae -from util import get_data, visualize, generate_table +import pandas as pd +import torch +from util import generate_table, get_data, visualize + +import pyro def main(args): diff --git a/examples/cvae/mnist.py b/examples/cvae/mnist.py index a98c667081..12dd7409f2 100644 --- a/examples/cvae/mnist.py +++ b/examples/cvae/mnist.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset from torchvision.datasets import MNIST from torchvision.transforms import Compose, functional diff --git a/examples/cvae/util.py b/examples/cvae/util.py index e578085946..87650298ef 100644 --- a/examples/cvae/util.py +++ b/examples/cvae/util.py @@ -1,17 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import numpy as np +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np import pandas as pd -from pathlib import Path -from pyro.infer import Predictive, Trace_ELBO import torch +from baseline import MaskedBCELoss +from mnist import get_data from torch.utils.data import DataLoader from torchvision.utils import make_grid from tqdm import tqdm -from baseline import MaskedBCELoss -from mnist import get_data + +from pyro.infer import Predictive, Trace_ELBO def imshow(inp, image_path=None): diff --git a/examples/eight_schools/data.py b/examples/eight_schools/data.py index 39529e798a..56158fa36e 100644 --- a/examples/eight_schools/data.py +++ b/examples/eight_schools/data.py @@ -3,7 +3,6 @@ import torch - J = 8 y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index 7b927d43e0..ec6ff9b5c6 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -4,9 +4,9 @@ import argparse import logging +import data import torch -import data import pyro import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index 14e7d32c33..7f70044707 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -5,11 +5,11 @@ import logging import torch +from data import J, sigma, y from torch.distributions import constraints, transforms import pyro import pyro.distributions as dist -from data import J, sigma, y from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam diff --git a/examples/hmm.py b/examples/hmm.py index 52086a5067..417b9e171f 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -47,8 +47,8 @@ import pyro.contrib.examples.polyphonic_data_loader as poly import pyro.distributions as dist from pyro import poutine -from pyro.infer.autoguide import AutoDelta 
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO +from pyro.infer.autoguide import AutoDelta from pyro.ops.indexing import Vindex from pyro.optim import Adam from pyro.util import ignore_jit_warnings diff --git a/examples/lkj.py b/examples/lkj.py index 56d26ab33d..2c8610ca23 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import argparse + import torch import pyro import pyro.distributions as dist -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc import NUTS +from pyro.infer.mcmc.api import MCMC """ This simple example is intended to demonstrate how to use an LKJ prior with diff --git a/examples/minipyro.py b/examples/minipyro.py index e12775dfda..9855b9694c 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -11,8 +11,8 @@ import torch -from pyro.generic import distributions as dist # We use the pyro.generic interface to support dynamic choice of backend. +from pyro.generic import distributions as dist from pyro.generic import infer, ops, optim, pyro, pyro_backend diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index 65584c6769..bd4bd02e6a 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -2,20 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import os +import functools import json +import os import uuid -import functools import torch +from model import guide_generic, model_generic +from seal_data import prepare_seal import pyro import pyro.poutine as poutine from pyro.infer import TraceEnum_ELBO -from model import model_generic, guide_generic -from seal_data import prepare_seal - def aic_num_parameters(model, guide=None): """ diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py index 390201a8d9..609fc69da3 100644 --- a/examples/mixed_hmm/seal_data.py +++ b/examples/mixed_hmm/seal_data.py @@ -5,10 +5,8 @@ from urllib.request import urlopen import pandas as pd - import torch - MISSING = 1e-6 diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 1e617316c4..4fb059b748 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -9,18 +9,17 @@ [1] https://gscontras.github.io/probLang/chapters/07-generics.html """ -import torch - import argparse -import numbers import collections +import numbers + +import torch +from search_inference import HashingMarginal, Search, memoize import pyro import pyro.distributions as dist import pyro.poutine as poutine -from search_inference import HashingMarginal, memoize, Search - torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index 04d878fa8b..429715137d 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -7,17 +7,16 @@ Taken from: https://gscontras.github.io/probLang/chapters/03-nonliteral.html """ -import torch - -import collections import argparse +import collections + +import torch +from search_inference import HashingMarginal, Search, memoize import pyro import pyro.distributions as dist import pyro.poutine as poutine -from search_inference import HashingMarginal, memoize, Search - torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index ea31af2811..ce8bfd219e 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -11,12 +11,13 @@ Taken from: 
http://forestdb.org/models/schelling.html """ import argparse + import torch +from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli -from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index 4a5ffcdf98..82ab6aedd0 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -12,12 +12,13 @@ Taken from: http://forestdb.org/models/schelling-falsebelief.html """ import argparse + import torch +from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli -from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/search_inference.py b/examples/rsa/search_inference.py index 14a49766f1..7e2cb8e142 100644 --- a/examples/rsa/search_inference.py +++ b/examples/rsa/search_inference.py @@ -8,10 +8,10 @@ """ import collections +import functools +import queue import torch -import queue -import functools import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index 0a998c6227..11a39ef10a 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -7,16 +7,15 @@ Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html """ -import torch - import argparse import collections +import torch +from search_inference import BestFirstSearch, HashingMarginal, memoize + import pyro import pyro.distributions as dist -from search_inference import HashingMarginal, BestFirstSearch, memoize - torch.set_default_dtype(torch.float64) diff --git a/examples/scanvi/data.py b/examples/scanvi/data.py index 690eab0717..429883d1a3 100644 --- a/examples/scanvi/data.py +++ b/examples/scanvi/data.py @@ -8,11 +8,11 @@ """ import math -import numpy as np -from scipy import sparse +import numpy as np import torch import torch.nn as nn +from scipy import sparse class BatchDataLoader(object): @@ -122,8 +122,8 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): return BatchDataLoader(X, Y, batch_size), num_genes, 2.0, 1.0, None - import scvi import scanpy as sc + import scvi adata = scvi.data.purified_pbmc_dataset(subset_datasets=["regulatory_t", "naive_t", "memory_t", "naive_cytotoxic"]) diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 2e4d4bc760..5a9e1cffb3 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -19,25 +19,22 @@ import argparse +import matplotlib.pyplot as plt import numpy as np - import torch import torch.nn as nn +from data import get_data +from matplotlib.patches import Patch from torch.distributions import constraints -from torch.nn.functional import softplus, softmax +from torch.nn.functional import softmax, softplus from torch.optim import Adam import pyro import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions.util import broadcast_shape +from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate from pyro.optim import MultiStepLR -from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO - -import matplotlib.pyplot as plt -from matplotlib.patches import Patch - -from data import get_data # Helper for making fully-connected neural networks @@ -300,6 +297,7 @@ def main(args): # Now that we're done training we'll inspect the latent representations we've learned if 
args.plot and args.dataset == 'pbmc': import scanpy as sc + # Compute latent representation (z2_loc) for each cell in the dataset latent_rep = scanvi.z2l_encoder(dataloader.data_x)[0] diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index 3af6153609..ae37f2cd0a 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -20,19 +20,16 @@ import numpy as np import torch +import wget from torch.nn.functional import softplus import pyro import pyro.optim as optim -import wget - +from pyro.contrib.easyguide import EasyGuide from pyro.contrib.examples.util import get_data_directory -from pyro.distributions import Gamma, Poisson, Normal +from pyro.distributions import Gamma, Normal, Poisson from pyro.infer import SVI, TraceMeanField_ELBO -from pyro.infer.autoguide import AutoDiagonalNormal -from pyro.infer.autoguide import init_to_feasible -from pyro.contrib.easyguide import EasyGuide - +from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible torch.set_default_tensor_type('torch.FloatTensor') pyro.util.set_rng_seed(0) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 9ae9421417..639ac44d73 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -2,20 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import math import numpy as np import torch -import math +from torch.optim import Adam import pyro import pyro.distributions as dist from pyro import poutine -from pyro.infer.autoguide import AutoDelta from pyro.infer import Trace_ELBO -from pyro.infer.autoguide import init_to_median - -from torch.optim import Adam - +from pyro.infer.autoguide import AutoDelta, init_to_median """ We demonstrate how to do sparse linear regression using a variant of the diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index 2f720a2ce0..627dded294 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn +from utils.custom_mlp import MLP, Exp +from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders +from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae from visdom import Visdom import pyro @@ -12,9 +15,6 @@ from pyro.contrib.examples.util import print_and_log from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate from pyro.optim import Adam -from utils.custom_mlp import MLP, Exp -from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders -from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae class SSVAE(nn.Module): diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 396ff788b8..98f19533dc 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -7,14 +7,14 @@ import torch import torch.nn as nn import visdom +from utils.mnist_cached import MNISTCached as MNIST +from utils.mnist_cached import setup_data_loaders +from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples import pyro import pyro.distributions as dist from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam -from utils.mnist_cached import MNISTCached as MNIST -from utils.mnist_cached import setup_data_loaders -from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples # define the PyTorch module that parameterizes the diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index f4291e5e35..60f9eddcb3 100644 --- 
a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -10,13 +10,13 @@ import torch.nn as nn from torch.nn import functional from torchvision.utils import save_image +from utils.mnist_cached import DATA_DIR, RESULTS_DIR import pyro from pyro.contrib.examples import util from pyro.distributions import Bernoulli, Normal from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam -from utils.mnist_cached import DATA_DIR, RESULTS_DIR """ Comparison of VAE implementation in PyTorch and Pyro. This example can be diff --git a/profiler/hmm.py b/profiler/hmm.py index 4308c3df56..1825c82c20 100644 --- a/profiler/hmm.py +++ b/profiler/hmm.py @@ -8,13 +8,12 @@ import subprocess import sys from collections import defaultdict -from os.path import join, abspath +from os.path import abspath, join from numpy import median from pyro.util import timed - EXAMPLES_DIR = join(abspath(__file__), os.pardir, os.pardir, "examples") diff --git a/profiler/profiling_utils.py b/profiler/profiling_utils.py index 8375132eb2..aee4f9b564 100644 --- a/profiler/profiling_utils.py +++ b/profiler/profiling_utils.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import cProfile -from io import StringIO import functools import os import pstats import timeit from contextlib import contextmanager +from io import StringIO from prettytable import ALL, PrettyTable diff --git a/pyro/contrib/__init__.py b/pyro/contrib/__init__.py index 3f14bd1862..045ec5435f 100644 --- a/pyro/contrib/__init__.py +++ b/pyro/contrib/__init__.py @@ -25,6 +25,7 @@ try: import funsor as funsor_ # noqa: F401 + from pyro.contrib import funsor __all__ += ["funsor"] except ImportError: diff --git a/pyro/contrib/autoname/__init__.py b/pyro/contrib/autoname/__init__.py index 6f396f55b6..d3e72366d9 100644 --- a/pyro/contrib/autoname/__init__.py +++ b/pyro/contrib/autoname/__init__.py @@ -6,8 +6,7 @@ generating unique, semantically meaningful names for sample sites. """ from pyro.contrib.autoname import named -from pyro.contrib.autoname.scoping import scope, name_count - +from pyro.contrib.autoname.scoping import name_count, scope __all__ = [ "named", diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index 6a4f679e29..cc97b051fa 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from torch.distributions.utils import lazy_property import torch.nn.functional as F +from torch.distributions.utils import lazy_property from pyro.contrib.bnn.utils import adjoin_ones_vector from pyro.distributions.torch_distribution import TorchDistribution diff --git a/pyro/contrib/bnn/utils.py b/pyro/contrib/bnn/utils.py index ec2f33623a..794f66f984 100644 --- a/pyro/contrib/bnn/utils.py +++ b/pyro/contrib/bnn/utils.py @@ -1,9 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
diff --git a/pyro/contrib/autoname/__init__.py b/pyro/contrib/autoname/__init__.py
index 6f396f55b6..d3e72366d9 100644
--- a/pyro/contrib/autoname/__init__.py
+++ b/pyro/contrib/autoname/__init__.py
@@ -6,8 +6,7 @@
 generating unique, semantically meaningful names for sample sites.
 """
 from pyro.contrib.autoname import named
-from pyro.contrib.autoname.scoping import scope, name_count
-
+from pyro.contrib.autoname.scoping import name_count, scope
 __all__ = [
     "named",
diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py
index 6a4f679e29..cc97b051fa 100644
--- a/pyro/contrib/bnn/hidden_layer.py
+++ b/pyro/contrib/bnn/hidden_layer.py
@@ -2,8 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-from torch.distributions.utils import lazy_property
 import torch.nn.functional as F
+from torch.distributions.utils import lazy_property
 from pyro.contrib.bnn.utils import adjoin_ones_vector
 from pyro.distributions.torch_distribution import TorchDistribution
diff --git a/pyro/contrib/bnn/utils.py b/pyro/contrib/bnn/utils.py
index ec2f33623a..794f66f984 100644
--- a/pyro/contrib/bnn/utils.py
+++ b/pyro/contrib/bnn/utils.py
@@ -1,9 +1,10 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-import torch
 import math
+
+import torch
+
 def xavier_uniform(D_in, D_out):
     scale = math.sqrt(6.0 / float(D_in + D_out))
diff --git a/pyro/contrib/conjugate/infer.py b/pyro/contrib/conjugate/infer.py
index 23a3fe791e..0c815c0126 100644
--- a/pyro/contrib/conjugate/infer.py
+++ b/pyro/contrib/conjugate/infer.py
@@ -6,8 +6,8 @@
 import torch
 import pyro.distributions as dist
-from pyro.distributions.util import sum_leftmost
 from pyro import poutine
+from pyro.distributions.util import sum_leftmost
 from pyro.poutine.messenger import Messenger
 from pyro.poutine.util import site_is_subsample
diff --git a/pyro/contrib/easyguide/__init__.py b/pyro/contrib/easyguide/__init__.py
index 9e2577841f..d26c63c9cf 100644
--- a/pyro/contrib/easyguide/__init__.py
+++ b/pyro/contrib/easyguide/__init__.py
@@ -3,7 +3,6 @@
 from pyro.contrib.easyguide.easyguide import EasyGuide, easy_guide
-
 __all__ = [
     "EasyGuide",
     "easy_guide",
diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py
index fbc204466b..55535ae72d 100644
--- a/pyro/contrib/easyguide/easyguide.py
+++ b/pyro/contrib/easyguide/easyguide.py
@@ -14,8 +14,8 @@
 import pyro.poutine as poutine
 import pyro.poutine.runtime as runtime
 from pyro.distributions.util import broadcast_shape, sum_rightmost
-from pyro.infer.autoguide.initialization import InitMessenger
 from pyro.infer.autoguide.guides import prototype_hide_fn
+from pyro.infer.autoguide.initialization import InitMessenger
 from pyro.nn.module import PyroModule, PyroParam
diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py
index 0d89fee5fc..0398ad137d 100644
--- a/pyro/contrib/examples/bart.py
+++ b/pyro/contrib/examples/bart.py
@@ -14,7 +14,7 @@
 import torch
-from pyro.contrib.examples.util import get_data_directory, _mkdir_p
+from pyro.contrib.examples.util import _mkdir_p, get_data_directory
 DATA = get_data_directory(__file__)
diff --git a/pyro/contrib/examples/finance.py b/pyro/contrib/examples/finance.py
index 03572c0289..c40a0b55e8 100644
--- a/pyro/contrib/examples/finance.py
+++ b/pyro/contrib/examples/finance.py
@@ -6,7 +6,7 @@
 import pandas as pd
-from pyro.contrib.examples.util import get_data_directory, _mkdir_p
+from pyro.contrib.examples.util import _mkdir_p, get_data_directory
 DATA = get_data_directory(__file__)
diff --git a/pyro/contrib/examples/polyphonic_data_loader.py b/pyro/contrib/examples/polyphonic_data_loader.py
index 132c6d953d..491ae0517f 100644
--- a/pyro/contrib/examples/polyphonic_data_loader.py
+++ b/pyro/contrib/examples/polyphonic_data_loader.py
@@ -17,9 +17,9 @@
 """
 import os
+import pickle
 from collections import namedtuple
 from urllib.request import urlopen
-import pickle
 import torch
 import torch.nn as nn
@@ -27,7 +27,6 @@
 from pyro.contrib.examples.util import get_data_directory
-
 dset = namedtuple("dset", ["name", "url", "filename"])
 JSB_CHORALES = dset("jsb_chorales",
diff --git a/pyro/contrib/forecast/util.py b/pyro/contrib/forecast/util.py
index 8918b470da..f2bd7034b5 100644
--- a/pyro/contrib/forecast/util.py
+++ b/pyro/contrib/forecast/util.py
@@ -7,7 +7,7 @@
 from torch.distributions import transform_to, transforms
 import pyro.distributions as dist
-from pyro.infer.reparam import HaarReparam, DiscreteCosineReparam
+from pyro.infer.reparam import DiscreteCosineReparam, HaarReparam
 from pyro.poutine.messenger import Messenger
 from pyro.poutine.util import site_is_subsample
 from pyro.primitives import get_param_store
diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py
index 30a23d5ca2..dcb9355e5e 100644
--- a/pyro/contrib/funsor/__init__.py
+++ b/pyro/contrib/funsor/__init__.py
@@ -3,14 +3,12 @@
 import pyroapi
-from pyro.primitives import (  # noqa: F401
-    clear_param_store, deterministic, enable_validation, factor, get_param_store,
-    module, param, random_module, sample, set_rng_seed, subsample,
-)
-
-from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor  # noqa: F401
-from pyro.contrib.funsor.handlers import condition, do, markov, vectorized_markov  # noqa: F401
+from pyro.contrib.funsor.handlers import condition, do, markov  # noqa: F401
 from pyro.contrib.funsor.handlers import plate as _plate
+from pyro.contrib.funsor.handlers import vectorized_markov  # noqa: F401
+from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor  # noqa: F401
+from pyro.primitives import (clear_param_store, deterministic, enable_validation, factor, get_param_store,  # noqa: F401
+                             module, param, random_module, sample, set_rng_seed, subsample)
 def plate(*args, **kwargs):
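In pyro/contrib/funsor/__init__.py above, plate is imported under a private alias (_plate) so the module can expose a wrapper function of the same name. A hedged sketch of that wrap-and-reexport idiom; the wrapper body is illustrative, since the hunk shows only the def line:

    from pyro.contrib.funsor.handlers import plate as _plate  # keep the original callable

    def plate(*args, **kwargs):
        # adjust or validate arguments here, then delegate to the handler
        return _plate(*args, **kwargs)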
diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py
index a98be1de94..724ec29c83 100644
--- a/pyro/contrib/funsor/handlers/__init__.py
+++ b/pyro/contrib/funsor/handlers/__init__.py
@@ -1,20 +1,16 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
+from pyro.poutine import (block, condition, do, escape, infer_config, mask, reparam, scale, seed,  # noqa: F401
+                          uncondition)
 from pyro.poutine.handlers import _make_handler
-from pyro.poutine import (  # noqa: F401
-    block, condition, do, escape, infer_config,
-    mask, reparam, scale, seed, uncondition,
-)
-
 from .enum_messenger import EnumMessenger, queue  # noqa: F401
 from .named_messenger import MarkovMessenger, NamedMessenger
 from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger
 from .replay_messenger import ReplayMessenger
 from .trace_messenger import TraceMessenger
-
 _msngrs = [
     EnumMessenger,
     MarkovMessenger,
diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index 15caf49078..befbeb2014 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -9,18 +9,17 @@
 import math
 from collections import OrderedDict
-import torch
 import funsor
+import torch
 import pyro.poutine.runtime
 import pyro.poutine.util
-from pyro.poutine.escape_messenger import EscapeMessenger
-from pyro.poutine.subsample_messenger import _Subsample
-
-from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
 from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger
+from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
 from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger
 from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger
+from pyro.poutine.escape_messenger import EscapeMessenger
+from pyro.poutine.subsample_messenger import _Subsample
 funsor.set_backend("torch")
diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py
index fb7667fab3..e7cf18ffbc 100644
--- a/pyro/contrib/funsor/handlers/named_messenger.py
+++ b/pyro/contrib/funsor/handlers/named_messenger.py
@@ -4,9 +4,8 @@
 from collections import OrderedDict
 from contextlib import ExitStack
-from pyro.poutine.reentrant_messenger import ReentrantMessenger
-
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame
+from pyro.poutine.reentrant_messenger import ReentrantMessenger
 class NamedMessenger(ReentrantMessenger):
diff --git a/pyro/contrib/funsor/handlers/plate_messenger.py b/pyro/contrib/funsor/handlers/plate_messenger.py
index f119db1de0..d180df021f 100644
--- a/pyro/contrib/funsor/handlers/plate_messenger.py
+++ b/pyro/contrib/funsor/handlers/plate_messenger.py
@@ -6,18 +6,16 @@
 import funsor
+from pyro.contrib.funsor.handlers.named_messenger import DimRequest, DimType, GlobalNamedMessenger, NamedMessenger
+from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
 from pyro.distributions.util import copy_docs_from
 from pyro.poutine.broadcast_messenger import BroadcastMessenger
 from pyro.poutine.indep_messenger import CondIndepStackFrame
 from pyro.poutine.messenger import Messenger
+from pyro.poutine.runtime import effectful
 from pyro.poutine.subsample_messenger import SubsampleMessenger as OrigSubsampleMessenger
 from pyro.util import ignore_jit_warnings
-from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
-from pyro.contrib.funsor.handlers.named_messenger import DimRequest, DimType, GlobalNamedMessenger, \
-    NamedMessenger
-from pyro.poutine.runtime import effectful
-
 funsor.set_backend("torch")
diff --git a/pyro/contrib/funsor/handlers/primitives.py b/pyro/contrib/funsor/handlers/primitives.py
index 3d8815eff0..0b7a4c4edb 100644
--- a/pyro/contrib/funsor/handlers/primitives.py
+++ b/pyro/contrib/funsor/handlers/primitives.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import pyro.poutine.runtime
-
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType
diff --git a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py
index 2389941049..ae672d2dd4 100644
--- a/pyro/contrib/funsor/handlers/replay_messenger.py
+++ b/pyro/contrib/funsor/handlers/replay_messenger.py
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
-from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger
 from pyro.contrib.funsor.handlers.primitives import to_data
+from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger
 class ReplayMessenger(OrigReplayMessenger):
diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py
index 4671901668..d6a995596d 100644
--- a/pyro/contrib/funsor/handlers/trace_messenger.py
+++ b/pyro/contrib/funsor/handlers/trace_messenger.py
@@ -3,11 +3,10 @@
 import funsor
-from pyro.poutine.subsample_messenger import _Subsample
-from pyro.poutine.trace_messenger import TraceMessenger as OrigTraceMessenger
-
 from pyro.contrib.funsor.handlers.primitives import to_funsor
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
+from pyro.poutine.subsample_messenger import _Subsample
+from pyro.poutine.trace_messenger import TraceMessenger as OrigTraceMessenger
 class TraceMessenger(OrigTraceMessenger):
diff --git a/pyro/contrib/funsor/infer/__init__.py b/pyro/contrib/funsor/infer/__init__.py
index 4525e2cef5..55f260e6a9 100644
--- a/pyro/contrib/funsor/infer/__init__.py
+++ b/pyro/contrib/funsor/infer/__init__.py
@@ -5,5 +5,5 @@
 from .elbo import ELBO  # noqa: F401
 from .trace_elbo import JitTrace_ELBO, Trace_ELBO  # noqa: F401
-from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO  # noqa: F401
 from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO  # noqa: F401
+from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO  # noqa: F401
diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 1912edca11..686926b772 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -5,14 +5,13 @@
 import funsor
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import Trace_ELBO as _OrigTrace_ELBO
-
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
 from pyro.contrib.funsor.infer import config_enumerate
+from pyro.distributions.util import copy_docs_from
+from pyro.infer import Trace_ELBO as _OrigTrace_ELBO
-from .elbo import Jit_ELBO, ELBO
+from .elbo import ELBO, Jit_ELBO
 from .traceenum_elbo import terms_from_trace
diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index d725da2c50..c9e7eb0242 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -5,12 +5,11 @@
 import funsor
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO
-
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
 from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO
+from pyro.distributions.util import copy_docs_from
+from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO
 def terms_from_trace(tr):
diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py
index 8a714d8c03..aae66ce5c0 100644
--- a/pyro/contrib/funsor/infer/tracetmc_elbo.py
+++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py
@@ -5,14 +5,12 @@
 import funsor
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO
-
 from pyro.contrib.funsor import to_data
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
-
 from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO
 from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace
+from pyro.distributions.util import copy_docs_from
+from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO
 @copy_docs_from(_OrigTraceTMC_ELBO)
diff --git a/pyro/contrib/gp/kernels/__init__.py b/pyro/contrib/gp/kernels/__init__.py
index 9874e73c17..c36ddd37fb 100644
--- a/pyro/contrib/gp/kernels/__init__.py
+++ b/pyro/contrib/gp/kernels/__init__.py
@@ -4,10 +4,9 @@
 from pyro.contrib.gp.kernels.brownian import Brownian
 from pyro.contrib.gp.kernels.coregionalize import Coregionalize
 from pyro.contrib.gp.kernels.dot_product import DotProduct, Linear, Polynomial
-from pyro.contrib.gp.kernels.isotropic import (RBF, Exponential, Isotropy, Matern32, Matern52,
-                                               RationalQuadratic)
-from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum,
-                                            Transforming, VerticalScaling, Warping)
+from pyro.contrib.gp.kernels.isotropic import RBF, Exponential, Isotropy, Matern32, Matern52, RationalQuadratic
+from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, Transforming, VerticalScaling,
+                                            Warping)
 from pyro.contrib.gp.kernels.periodic import Cosine, Periodic
 from pyro.contrib.gp.kernels.static import Constant, WhiteNoise
diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py
index 3041f5e92e..ef417f9e22 100644
--- a/pyro/contrib/gp/likelihoods/binary.py
+++ b/pyro/contrib/gp/likelihoods/binary.py
@@ -5,7 +5,6 @@
 import pyro
 import pyro.distributions as dist
-
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/gp/likelihoods/gaussian.py b/pyro/contrib/gp/likelihoods/gaussian.py
index cb5a15d8c7..b1b65ff95c 100644
--- a/pyro/contrib/gp/likelihoods/gaussian.py
+++ b/pyro/contrib/gp/likelihoods/gaussian.py
@@ -6,7 +6,6 @@
 import pyro
 import pyro.distributions as dist
-
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
 from pyro.nn.module import PyroParam
diff --git a/pyro/contrib/gp/likelihoods/multi_class.py b/pyro/contrib/gp/likelihoods/multi_class.py
index 9ff69e81f1..ed8463f8bd 100644
--- a/pyro/contrib/gp/likelihoods/multi_class.py
+++ b/pyro/contrib/gp/likelihoods/multi_class.py
@@ -5,7 +5,6 @@
 import pyro
 import pyro.distributions as dist
-
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/gp/likelihoods/poisson.py b/pyro/contrib/gp/likelihoods/poisson.py
index 48916e0634..8abed6fd2a 100644
--- a/pyro/contrib/gp/likelihoods/poisson.py
+++ b/pyro/contrib/gp/likelihoods/poisson.py
@@ -5,7 +5,6 @@
 import pyro
 import pyro.distributions as dist
-
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/oed/__init__.py b/pyro/contrib/oed/__init__.py
index 006c57a7a1..3afd3a440d 100644
--- a/pyro/contrib/oed/__init__.py
+++ b/pyro/contrib/oed/__init__.py
@@ -67,7 +67,7 @@
 """
-from pyro.contrib.oed import search, eig
+from pyro.contrib.oed import eig, search
 __all__ = [
     "search",
diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py
index 7faec28aa6..8d6c7ae22b 100644
--- a/pyro/contrib/oed/eig.py
+++ b/pyro/contrib/oed/eig.py
@@ -1,17 +1,18 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-import torch
 import math
 import warnings
+
+import torch
+
 import pyro
 from pyro import poutine
-from pyro.infer.autoguide.utils import mean_field_entropy
 from pyro.contrib.oed.search import Search
-from pyro.infer import EmpiricalMarginal, Importance, SVI
-from pyro.util import torch_isnan, torch_isinf
 from pyro.contrib.util import lexpand
+from pyro.infer import SVI, EmpiricalMarginal, Importance
+from pyro.infer.autoguide.utils import mean_field_entropy
+from pyro.util import torch_isinf, torch_isnan
 __all__ = [
     "laplace_eig",
diff --git a/pyro/contrib/oed/glmm/__init__.py b/pyro/contrib/oed/glmm/__init__.py
index f9d75643ca..c17c221213 100644
--- a/pyro/contrib/oed/glmm/__init__.py
+++ b/pyro/contrib/oed/glmm/__init__.py
@@ -36,5 +36,5 @@
 For random effects with a shared covariance matrix, see
 :meth:`pyro.contrib.oed.glmm.lmer_model`.
 """
-from pyro.contrib.oed.glmm.glmm import *  # noqa: F403,F401
 from pyro.contrib.oed.glmm import guides  # noqa: F401
+from pyro.contrib.oed.glmm.glmm import *  # noqa: F403,F401
diff --git a/pyro/contrib/oed/glmm/glmm.py b/pyro/contrib/oed/glmm/glmm.py
index 2c418391e3..68507be53f 100644
--- a/pyro/contrib/oed/glmm/glmm.py
+++ b/pyro/contrib/oed/glmm/glmm.py
@@ -3,17 +3,17 @@
 import warnings
 from collections import OrderedDict
-from functools import partial
 from contextlib import ExitStack
+from functools import partial
 import torch
-from torch.nn.functional import softplus
 from torch.distributions import constraints
 from torch.distributions.transforms import AffineTransform, SigmoidTransform
+from torch.nn.functional import softplus
 import pyro
 import pyro.distributions as dist
-from pyro.contrib.util import rmv, iter_plates_to_shape
+from pyro.contrib.util import iter_plates_to_shape, rmv
 # TODO read from torch float spec
 epsilon = torch.tensor(2**-24)
diff --git a/pyro/contrib/oed/glmm/guides.py b/pyro/contrib/oed/glmm/guides.py
index c71b06415c..d2425adff2 100644
--- a/pyro/contrib/oed/glmm/guides.py
+++ b/pyro/contrib/oed/glmm/guides.py
@@ -1,17 +1,15 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+from contextlib import ExitStack
+
 import torch
 from torch import nn
-from contextlib import ExitStack
-
 import pyro
 import pyro.distributions as dist
 from pyro import poutine
-from pyro.contrib.util import (
-    tensor_to_dict, rmv, rvv, rtril, lexpand, iter_plates_to_shape
-)
+from pyro.contrib.util import iter_plates_to_shape, lexpand, rmv, rtril, rvv, tensor_to_dict
 from pyro.ops.linalg import rinverse
diff --git a/pyro/contrib/oed/search.py b/pyro/contrib/oed/search.py
index 4bf8eb1816..721f6305c3 100644
--- a/pyro/contrib/oed/search.py
+++ b/pyro/contrib/oed/search.py
@@ -2,8 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import queue
-from pyro.infer.abstract_infer import TracePosterior
+
 import pyro.poutine as poutine
+from pyro.infer.abstract_infer import TracePosterior
 ###################################
 # Search borrowed from RSA example
diff --git a/pyro/contrib/oed/util.py b/pyro/contrib/oed/util.py
index d5c315f85a..50774ff0bd 100644
--- a/pyro/contrib/oed/util.py
+++ b/pyro/contrib/oed/util.py
@@ -2,10 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0
 import math
+
 import torch
-from pyro.contrib.util import get_indices
 from pyro.contrib.oed.glmm import analytic_posterior_cov
+from pyro.contrib.util import get_indices
 from pyro.infer.autoguide.utils import mean_field_entropy
diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py
index 1bc28f1caa..39c06748ae 100644
--- a/pyro/contrib/randomvariable/random_variable.py
+++ b/pyro/contrib/randomvariable/random_variable.py
@@ -4,17 +4,10 @@
 from typing import Union
 from torch import Tensor
+
 from pyro.distributions import TransformedDistribution
-from pyro.distributions.transforms import (
-    Transform,
-    AffineTransform,
-    AbsTransform,
-    PowerTransform,
-    ExpTransform,
-    TanhTransform,
-    SoftmaxTransform,
-    SigmoidTransform
-)
+from pyro.distributions.transforms import (AbsTransform, AffineTransform, ExpTransform, PowerTransform,
+                                           SigmoidTransform, SoftmaxTransform, TanhTransform, Transform)
 class RVMagicOps:
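The randomvariable and guides hunks above also show how multi-line imports are normalized: a vertical one-name-per-line list collapses into a single parenthesized import, wrapped at the line limit with the continuation aligned under the opening parenthesis, names alphabetized. The resulting style, quoting the line from the hunk above:

    from pyro.distributions.transforms import (AbsTransform, AffineTransform, ExpTransform, PowerTransform,
                                               SigmoidTransform, SoftmaxTransform, TanhTransform, Transform)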
""" from pyro.contrib.timeseries.base import TimeSeriesModel -from pyro.contrib.timeseries.gp import IndependentMaternGP, LinearlyCoupledMaternGP, DependentMaternGP +from pyro.contrib.timeseries.gp import DependentMaternGP, IndependentMaternGP, LinearlyCoupledMaternGP from pyro.contrib.timeseries.lgssm import GenericLGSSM from pyro.contrib.timeseries.lgssmgp import GenericLGSSMWithGPNoiseModel diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index 641c764d19..fc3e61c6c1 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -5,9 +5,9 @@ from torch.distributions import constraints import pyro.distributions as dist -from pyro.distributions.torch_distribution import TorchDistribution from pyro.contrib.tracking.extended_kalman_filter import EKFState from pyro.contrib.tracking.measurements import PositionMeasurement +from pyro.distributions.torch_distribution import TorchDistribution class EKFDistribution(TorchDistribution): diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py index bd26b344e2..7ea41ad3a7 100644 --- a/pyro/contrib/tracking/dynamic_models.py +++ b/pyro/contrib/tracking/dynamic_models.py @@ -6,6 +6,7 @@ import torch from torch import nn from torch.nn import Parameter + import pyro.distributions as dist from pyro.distributions.util import eye_like diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py index cb98c49ccd..8f24ea4360 100644 --- a/pyro/contrib/tracking/measurements.py +++ b/pyro/contrib/tracking/measurements.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod import torch + from pyro.distributions.util import eye_like diff --git a/pyro/contrib/util.py b/pyro/contrib/util.py index 44ff34832b..e250639ca7 100644 --- a/pyro/contrib/util.py +++ b/pyro/contrib/util.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict + import torch + import pyro diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 25e46dc8bb..c51ee45ded 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -3,40 +3,19 @@ import pyro.distributions.torch_patch # noqa F403 from pyro.distributions.avf_mvn import AVFMultivariateNormal -from pyro.distributions.coalescent import ( - CoalescentRateLikelihood, - CoalescentTimes, - CoalescentTimesWithRate, -) -from pyro.distributions.conditional import ( - ConditionalDistribution, - ConditionalTransform, - ConditionalTransformedDistribution, - ConditionalTransformModule, -) -from pyro.distributions.conjugate import ( - BetaBinomial, - DirichletMultinomial, - GammaPoisson, -) +from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimes, CoalescentTimesWithRate +from pyro.distributions.conditional import (ConditionalDistribution, ConditionalTransform, + ConditionalTransformedDistribution, ConditionalTransformModule) +from pyro.distributions.conjugate import BetaBinomial, DirichletMultinomial, GammaPoisson from pyro.distributions.delta import Delta from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals -from pyro.distributions.diag_normal_mixture_shared_cov import ( - MixtureOfDiagNormalsSharedCovariance, -) +from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance from pyro.distributions.distribution import Distribution from pyro.distributions.empirical import Empirical from pyro.distributions.extended import 
ExtendedBetaBinomial, ExtendedBinomial from pyro.distributions.folded import FoldedDistribution from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture -from pyro.distributions.hmm import ( - DiscreteHMM, - GammaGaussianHMM, - GaussianHMM, - GaussianMRF, - IndependentHMM, - LinearHMM, -) +from pyro.distributions.hmm import DiscreteHMM, GammaGaussianHMM, GaussianHMM, GaussianMRF, IndependentHMM, LinearHMM from pyro.distributions.improper_uniform import ImproperUniform from pyro.distributions.inverse_gamma import InverseGamma from pyro.distributions.lkj import LKJCorrCholesky @@ -48,10 +27,8 @@ from pyro.distributions.ordered_logistic import OrderedLogistic from pyro.distributions.polya_gamma import TruncatedPolyaGamma from pyro.distributions.rejector import Rejector -from pyro.distributions.relaxed_straight_through import ( - RelaxedBernoulliStraightThrough, - RelaxedOneHotCategoricalStraightThrough, -) +from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, + RelaxedOneHotCategoricalStraightThrough) from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable from pyro.distributions.torch import * # noqa F403 @@ -59,17 +36,9 @@ from pyro.distributions.torch_distribution import MaskedDistribution, TorchDistribution from pyro.distributions.torch_transform import ComposeTransformModule, TransformModule from pyro.distributions.unit import Unit -from pyro.distributions.util import ( - enable_validation, - is_validation_enabled, - validation_enabled, -) +from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled from pyro.distributions.von_mises_3d import VonMises3D -from pyro.distributions.zero_inflated import ( - ZeroInflatedDistribution, - ZeroInflatedNegativeBinomial, - ZeroInflatedPoisson, -) +from pyro.distributions.zero_inflated import ZeroInflatedDistribution, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson from . 
import constraints, kl, transforms diff --git a/pyro/distributions/ordered_logistic.py b/pyro/distributions/ordered_logistic.py index d6d288fafb..c8d4ef459d 100644 --- a/pyro/distributions/ordered_logistic.py +++ b/pyro/distributions/ordered_logistic.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch + from pyro.distributions import constraints from pyro.distributions.torch import Categorical diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py index f7b2f64238..ff5e370f05 100644 --- a/pyro/distributions/spanning_tree.py +++ b/pyro/distributions/spanning_tree.py @@ -185,6 +185,7 @@ def _get_cpp_module(): global _cpp_module if _cpp_module is None: import os + from torch.utils.cpp_extension import load path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp") _cpp_module = load(name="cpp_spanning_tree", diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index f926f52330..df68c5a5c0 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -5,10 +5,7 @@ from torch.distributions.transforms import * # noqa F403 from torch.distributions.transforms import __all__ as torch_transforms -from pyro.distributions.constraints import ( - IndependentConstraint, - corr_cholesky_constraint, - ordered_vector) +from pyro.distributions.constraints import IndependentConstraint, corr_cholesky_constraint, ordered_vector from pyro.distributions.torch_transform import ComposeTransformModule from pyro.distributions.transforms.affine_autoregressive import (AffineAutoregressive, ConditionalAffineAutoregressive, affine_autoregressive, diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py index 79aea6a261..95497e45fa 100644 --- a/pyro/distributions/transforms/ordered.py +++ b/pyro/distributions/transforms/ordered.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from pyro.distributions.transforms import Transform + from pyro.distributions import constraints +from pyro.distributions.transforms import Transform class OrderedTransform(Transform): diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 421e5a16fa..5485b7476e 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -17,13 +17,13 @@ from pyro.infer.smcfilter import SMCFilter from pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel from pyro.infer.svi import SVI -from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO from pyro.infer.trace_mmd import Trace_MMD from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO +from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.util import enable_validation, is_validation_enabled __all__ = [ diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py index 0424949e7d..8891c9cc64 100644 --- a/pyro/infer/abstract_infer.py +++ b/pyro/infer/abstract_infer.py @@ -10,8 +10,8 @@ import pyro.poutine as poutine from pyro.distributions import Categorical, Empirical -from pyro.poutine.util import site_is_subsample from pyro.ops.stats import waic +from pyro.poutine.util import site_is_subsample class EmpiricalMarginal(Empirical): diff --git 
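The spanning_tree.py hunk above touches a deferred import: the torch C++ extension loader is pulled in inside the function body, so the .cpp module is only compiled on first use and is cached in a module-level global. A sketch of that lazy-loading pattern; the sources= argument is an assumption, since the hunk cuts off mid-call:

    _cpp_module = None

    def _get_cpp_module():
        # compile and cache the C++ extension lazily, on first request
        global _cpp_module
        if _cpp_module is None:
            import os

            from torch.utils.cpp_extension import load
            path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp")
            _cpp_module = load(name="cpp_spanning_tree", sources=[path])
        return _cpp_module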
diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py
index f926f52330..df68c5a5c0 100644
--- a/pyro/distributions/transforms/__init__.py
+++ b/pyro/distributions/transforms/__init__.py
@@ -5,10 +5,7 @@
 from torch.distributions.transforms import *  # noqa F403
 from torch.distributions.transforms import __all__ as torch_transforms
-from pyro.distributions.constraints import (
-    IndependentConstraint,
-    corr_cholesky_constraint,
-    ordered_vector)
+from pyro.distributions.constraints import IndependentConstraint, corr_cholesky_constraint, ordered_vector
 from pyro.distributions.torch_transform import ComposeTransformModule
 from pyro.distributions.transforms.affine_autoregressive import (AffineAutoregressive,
                                                                  ConditionalAffineAutoregressive,
                                                                  affine_autoregressive,
diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py
index 79aea6a261..95497e45fa 100644
--- a/pyro/distributions/transforms/ordered.py
+++ b/pyro/distributions/transforms/ordered.py
@@ -2,8 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-from pyro.distributions.transforms import Transform
+
 from pyro.distributions import constraints
+from pyro.distributions.transforms import Transform
 class OrderedTransform(Transform):
diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index 421e5a16fa..5485b7476e 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -17,13 +17,13 @@
 from pyro.infer.smcfilter import SMCFilter
 from pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel
 from pyro.infer.svi import SVI
-from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
 from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
 from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO
 from pyro.infer.trace_mmd import Trace_MMD
 from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO
 from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
 from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
+from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
 from pyro.infer.util import enable_validation, is_validation_enabled
 __all__ = [
diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py
index 0424949e7d..8891c9cc64 100644
--- a/pyro/infer/abstract_infer.py
+++ b/pyro/infer/abstract_infer.py
@@ -10,8 +10,8 @@
 import pyro.poutine as poutine
 from pyro.distributions import Categorical, Empirical
-from pyro.poutine.util import site_is_subsample
 from pyro.ops.stats import waic
+from pyro.poutine.util import site_is_subsample
 class EmpiricalMarginal(Empirical):
diff --git a/pyro/infer/autoguide/initialization.py b/pyro/infer/autoguide/initialization.py
index 70b1a430f3..69ccdde4ca 100644
--- a/pyro/infer/autoguide/initialization.py
+++ b/pyro/infer/autoguide/initialization.py
@@ -20,7 +20,6 @@
 from pyro.poutine.messenger import Messenger
 from pyro.util import torch_isnan
-
 # TODO: move this file out of `autoguide` in a minor release
 def _is_multivariate(d):
diff --git a/pyro/infer/mcmc/__init__.py b/pyro/infer/mcmc/__init__.py
index e33cbec518..99d241b162 100644
--- a/pyro/infer/mcmc/__init__.py
+++ b/pyro/infer/mcmc/__init__.py
@@ -2,8 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix
-from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.api import MCMC
+from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.nuts import NUTS
 __all__ = [
diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py
index 46497a53b5..c8d41924f6 100644
--- a/pyro/infer/mcmc/adaptation.py
+++ b/pyro/infer/mcmc/adaptation.py
@@ -9,7 +9,7 @@
 import pyro
 from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul
 from pyro.ops.dual_averaging import DualAveraging
-from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance
+from pyro.ops.welford import WelfordArrowheadCovariance, WelfordCovariance
 adapt_window = namedtuple("adapt_window", ["start", "end"])
diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py
index 857004212f..139d7f4a0d 100644
--- a/pyro/infer/mcmc/hmc.py
+++ b/pyro/infer/mcmc/hmc.py
@@ -8,9 +8,8 @@
 import pyro
 import pyro.distributions as dist
-from pyro.distributions.util import scalar_like
 from pyro.distributions.testing.fakes import NonreparameterizedNormal
-
+from pyro.distributions.util import scalar_like
 from pyro.infer.autoguide import init_to_uniform
 from pyro.infer.mcmc.adaptation import WarmupAdapter
 from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py
index 372104c53f..89831a93f8 100644
--- a/pyro/infer/mcmc/util.py
+++ b/pyro/infer/mcmc/util.py
@@ -2,15 +2,15 @@
 # SPDX-License-Identifier: Apache-2.0
 import functools
+import traceback as tb
 import warnings
 from collections import OrderedDict, defaultdict
 from functools import partial, reduce
 from itertools import product
-import traceback as tb
 import torch
-from torch.distributions import biject_to
 from opt_einsum import shared_intermediates
+from torch.distributions import biject_to
 import pyro
 import pyro.poutine as poutine
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 1b946accfe..857b54df77 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-from functools import reduce
 import warnings
+from functools import reduce
 import torch
diff --git a/pyro/infer/reparam/neutra.py b/pyro/infer/reparam/neutra.py
index f9753d9051..2d6b9c8bb2 100644
--- a/pyro/infer/reparam/neutra.py
+++ b/pyro/infer/reparam/neutra.py
@@ -8,6 +8,7 @@
 from pyro import poutine
 from pyro.distributions.util import sum_rightmost
 from pyro.infer.autoguide.guides import AutoContinuous
+
 from .reparam import Reparam
diff --git a/pyro/infer/svgd.py b/pyro/infer/svgd.py
index ebc526a86d..9e722a745f 100644
--- a/pyro/infer/svgd.py
+++ b/pyro/infer/svgd.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-from abc import ABCMeta, abstractmethod
 import math
+from abc import ABCMeta, abstractmethod
 import torch
 from torch.distributions import biject_to
@@ -10,10 +10,10 @@
 import pyro
 from pyro import poutine
 from pyro.distributions import Delta
-from pyro.infer.trace_elbo import Trace_ELBO
+from pyro.distributions.util import copy_docs_from
 from pyro.infer.autoguide.guides import AutoContinuous
 from pyro.infer.autoguide.initialization import init_to_sample
-from pyro.distributions.util import copy_docs_from
+from pyro.infer.trace_elbo import Trace_ELBO
 def vectorize(fn, num_particles, max_plate_nesting):
diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py
index d04210ee56..6eab28596d 100644
--- a/pyro/infer/trace_mean_field_elbo.py
+++ b/pyro/infer/trace_mean_field_elbo.py
@@ -10,7 +10,7 @@
 import pyro.ops.jit
 from pyro.distributions.util import scale_and_mask
 from pyro.infer.trace_elbo import Trace_ELBO
-from pyro.infer.util import is_validation_enabled, torch_item, check_fully_reparametrized
+from pyro.infer.util import check_fully_reparametrized, is_validation_enabled, torch_item
 from pyro.util import warn_if_nan
diff --git a/pyro/infer/trace_mmd.py b/pyro/infer/trace_mmd.py
index 1cc71992b9..661ff727c2 100644
--- a/pyro/infer/trace_mmd.py
+++ b/pyro/infer/trace_mmd.py
@@ -9,8 +9,8 @@
 import pyro.ops.jit
 from pyro import poutine
 from pyro.infer.elbo import ELBO
-from pyro.infer.util import torch_item, is_validation_enabled
 from pyro.infer.enum import get_importance_trace
+from pyro.infer.util import is_validation_enabled, torch_item
 from pyro.util import check_if_enumerated, warn_if_nan
diff --git a/pyro/infer/trace_tail_adaptive_elbo.py b/pyro/infer/trace_tail_adaptive_elbo.py
index a69ea6d191..b05251a300 100644
--- a/pyro/infer/trace_tail_adaptive_elbo.py
+++ b/pyro/infer/trace_tail_adaptive_elbo.py
@@ -6,7 +6,7 @@
 import torch
 from pyro.infer.trace_elbo import Trace_ELBO
-from pyro.infer.util import is_validation_enabled, check_fully_reparametrized
+from pyro.infer.util import check_fully_reparametrized, is_validation_enabled
 class TraceTailAdaptive_ELBO(Trace_ELBO):
diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index e2e483eec7..504060963f 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -1,10 +1,10 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+import queue
 import warnings
 import weakref
 from collections import OrderedDict
-import queue
 import torch
 from opt_einsum import shared_intermediates
diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py
index 852ea6e658..8e7bc7ed6f 100644
--- a/pyro/infer/tracegraph_elbo.py
+++ b/pyro/infer/tracegraph_elbo.py
@@ -11,8 +11,7 @@
 from pyro.distributions.util import detach, is_identically_zero
 from pyro.infer import ELBO
 from pyro.infer.enum import get_importance_trace
-from pyro.infer.util import (MultiFrameTensor, get_plate_stacks,
-                             is_validation_enabled, torch_backward, torch_item)
+from pyro.infer.util import MultiFrameTensor, get_plate_stacks, is_validation_enabled, torch_backward, torch_item
 from pyro.util import check_if_enumerated, warn_if_nan
diff --git a/pyro/infer/tracetmc_elbo.py b/pyro/infer/tracetmc_elbo.py
index 51c3ba3b78..f78b277080 100644
--- a/pyro/infer/tracetmc_elbo.py
+++ b/pyro/infer/tracetmc_elbo.py
@@ -7,7 +7,6 @@
 import torch
 import pyro.poutine as poutine
-
 from pyro.distributions.util import is_identically_zero
 from pyro.infer.elbo import ELBO
 from pyro.infer.enum import get_importance_trace, iter_discrete_escape, iter_discrete_extend
diff --git a/pyro/logger.py b/pyro/logger.py
index 5bee771e4a..64a8c70c46 100644
--- a/pyro/logger.py
+++ b/pyro/logger.py
@@ -3,7 +3,6 @@
 import logging
-
 default_format = '%(levelname)s \t %(message)s'
 log = logging.getLogger("pyro")
 log.setLevel(logging.INFO)
diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py
index 9641ab4ec2..4714d3095e 100644
--- a/pyro/ops/arrowhead.py
+++ b/pyro/ops/arrowhead.py
@@ -5,7 +5,6 @@
 import torch
-
 SymmArrowhead = namedtuple("SymmArrowhead", ["top", "bottom_diag"])
 TriuArrowhead = namedtuple("TriuArrowhead", ["top", "bottom_diag"])
diff --git a/pyro/ops/einsum/torch_map.py b/pyro/ops/einsum/torch_map.py
index 6e2832bcff..e4293c1140 100644
--- a/pyro/ops/einsum/torch_map.py
+++ b/pyro/ops/einsum/torch_map.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import operator
-
 from functools import reduce
 from pyro.ops import packed
diff --git a/pyro/ops/einsum/torch_sample.py b/pyro/ops/einsum/torch_sample.py
index 06c8108886..5420c328ba 100644
--- a/pyro/ops/einsum/torch_sample.py
+++ b/pyro/ops/einsum/torch_sample.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import operator
-
 from functools import reduce
 import pyro.distributions as dist
diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py
index 6611d3a93c..e651b4284e 100644
--- a/pyro/ops/newton.py
+++ b/pyro/ops/newton.py
@@ -4,8 +4,8 @@
 import torch
 from torch.autograd import grad
+from pyro.ops.linalg import eig_3d, rinverse
 from pyro.util import warn_if_nan
-from pyro.ops.linalg import rinverse, eig_3d
 def newton_step(loss, x, trust_radius=None):
diff --git a/pyro/ops/ssm_gp.py b/pyro/ops/ssm_gp.py
index eb88ba9d70..89abcb2912 100644
--- a/pyro/ops/ssm_gp.py
+++ b/pyro/ops/ssm_gp.py
@@ -6,7 +6,7 @@
 import torch
 from torch.distributions import constraints
-from pyro.nn import PyroModule, pyro_method, PyroParam
+from pyro.nn import PyroModule, PyroParam, pyro_method
 root_three = math.sqrt(3.0)
 root_five = math.sqrt(5.0)
diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py
index f05886e22e..d1ae84281e 100644
--- a/pyro/ops/stats.py
+++ b/pyro/ops/stats.py
@@ -6,8 +6,8 @@
 import torch
+from .fft import irfft, rfft
 from .tensor_utils import next_fast_len
-from .fft import rfft, irfft
 def _compute_chain_variance_stats(input):
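The pyro/ops/stats.py hunk above shows where explicit relative imports land after sorting: absolute imports come first, and the from-dot imports form their own trailing group, alphabetized by module. Sketched with the names visible in that hunk:

    import torch

    from .fft import irfft, rfft
    from .tensor_utils import next_fast_len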
diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py
index 20a5b1d5e1..2ff2a57ae9 100644
--- a/pyro/poutine/broadcast_messenger.py
+++ b/pyro/poutine/broadcast_messenger.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from pyro.util import ignore_jit_warnings
+
 from .messenger import Messenger
diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py
index 9ed1575857..0b65987932 100644
--- a/pyro/poutine/indep_messenger.py
+++ b/pyro/poutine/indep_messenger.py
@@ -7,6 +7,7 @@
 import torch
 from pyro.util import ignore_jit_warnings
+
 from .messenger import Messenger
 from .runtime import _DIM_ALLOCATOR
diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py
index 39cd234bb9..176c7a772f 100644
--- a/pyro/poutine/trace_struct.py
+++ b/pyro/poutine/trace_struct.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-from collections import OrderedDict
 import sys
+from collections import OrderedDict
 import opt_einsum
diff --git a/tests/__init__.py b/tests/__init__.py
index 200bfc2d65..4056718ce7 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
-
 import os
 # create log handler for tests
diff --git a/tests/conftest.py b/tests/conftest.py
index 2cfcba39d4..699cca55c2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,6 @@
 import pyro
-
 torch.set_default_tensor_type(os.environ.get('PYRO_TENSOR_TYPE', 'torch.DoubleTensor'))
diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py
index 9e9002d507..767b763ff4 100644
--- a/tests/contrib/autoguide/test_inference.py
+++ b/tests/contrib/autoguide/test_inference.py
@@ -12,10 +12,10 @@
 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.distributions.transforms import iterated, block_autoregressive
+from pyro.distributions.transforms import block_autoregressive, iterated
+from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
 from pyro.infer.autoguide import (AutoDiagonalNormal, AutoIAFNormal, AutoLaplaceApproximation,
                                   AutoLowRankMultivariateNormal, AutoMultivariateNormal)
-from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
 from pyro.infer.autoguide.guides import AutoNormalizingFlow
 from tests.common import assert_equal
 from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain
diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/contrib/autoguide/test_mean_field_entropy.py
index 9f8c301b32..2f5cd163db 100644
--- a/tests/contrib/autoguide/test_mean_field_entropy.py
+++ b/tests/contrib/autoguide/test_mean_field_entropy.py
@@ -1,9 +1,9 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-import torch
-import scipy.special as sc
 import pytest
+import scipy.special as sc
+import torch
 import pyro
 import pyro.distributions as dist
diff --git a/tests/contrib/autoname/test_scoping.py b/tests/contrib/autoname/test_scoping.py
index d10d6f2d7f..aa7e44bae6 100644
--- a/tests/contrib/autoname/test_scoping.py
+++ b/tests/contrib/autoname/test_scoping.py
@@ -8,7 +8,7 @@
 import pyro
 import pyro.distributions.torch as dist
 import pyro.poutine as poutine
-from pyro.contrib.autoname import scope, name_count
+from pyro.contrib.autoname import name_count, scope
 logger = logging.getLogger(__name__)
diff --git a/tests/contrib/bnn/test_hidden_layer.py b/tests/contrib/bnn/test_hidden_layer.py
index 1067cc03f2..c688572d0f 100644
--- a/tests/contrib/bnn/test_hidden_layer.py
+++ b/tests/contrib/bnn/test_hidden_layer.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+import pytest
 import torch
 import torch.nn.functional as F
 from torch.distributions import Normal
-import pytest
 from pyro.contrib.bnn import HiddenLayer
 from tests.common import assert_equal
diff --git a/tests/contrib/epidemiology/test_quant.py b/tests/contrib/epidemiology/test_quant.py
index f9dc53bb64..d2a0edbc1d 100644
--- a/tests/contrib/epidemiology/test_quant.py
+++ b/tests/contrib/epidemiology/test_quant.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import pytest
-
 import torch
 from pyro.contrib.epidemiology.util import compute_bin_probs
diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py
index 75fff55463..9b273e5e2e 100644
--- a/tests/contrib/funsor/test_enum_funsor.py
+++ b/tests/contrib/funsor/test_enum_funsor.py
@@ -7,7 +7,6 @@
 import pyroapi
 import pytest
 import torch
-
 from torch.autograd import grad
 from torch.distributions import constraints
@@ -17,9 +16,10 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
-    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import handlers, infer, pyro
+
+    import pyro.contrib.funsor
     funsor.set_backend("torch")
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_named_handlers.py b/tests/contrib/funsor/test_named_handlers.py
index 48c464daa3..c4c57b7bd5 100644
--- a/tests/contrib/funsor/test_named_handlers.py
+++ b/tests/contrib/funsor/test_named_handlers.py
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
-from collections import OrderedDict
 import logging
+from collections import OrderedDict
 import pytest
 import torch
@@ -11,6 +11,7 @@
 try:
     import funsor
     from funsor.tensor import Tensor
+
     import pyro.contrib.funsor
     from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger
     funsor.set_backend("torch")
diff --git a/tests/contrib/funsor/test_pyroapi_funsor.py b/tests/contrib/funsor/test_pyroapi_funsor.py
index 74dbf972e3..9e050462e9 100644
--- a/tests/contrib/funsor/test_pyroapi_funsor.py
+++ b/tests/contrib/funsor/test_pyroapi_funsor.py
@@ -2,13 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 import pytest
-
 from pyroapi import pyro_backend
 from pyroapi.tests import *  # noqa F401
 try:
     # triggers backend registration
     import funsor
+
     import pyro.contrib.funsor  # noqa: F401
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py
index cc1ab52178..54d4eedaae 100644
--- a/tests/contrib/funsor/test_tmc.py
+++ b/tests/contrib/funsor/test_tmc.py
@@ -14,9 +14,10 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
-    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro, pyro_backend
+
+    import pyro.contrib.funsor
     funsor.set_backend("torch")
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_valid_models_enum.py b/tests/contrib/funsor/test_valid_models_enum.py
index 3ef3241ad2..7df1b23a90 100644
--- a/tests/contrib/funsor/test_valid_models_enum.py
+++ b/tests/contrib/funsor/test_valid_models_enum.py
@@ -1,10 +1,10 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
-from collections import defaultdict
 import contextlib
 import logging
 import os
+from collections import defaultdict
 from queue import LifoQueue
 import pytest
@@ -19,11 +19,12 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+
     import pyro.contrib.funsor
     from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
     funsor.set_backend("torch")
     from pyroapi import distributions as dist
-    from pyroapi import infer, handlers, pyro, pyro_backend
+    from pyroapi import handlers, infer, pyro, pyro_backend
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_valid_models_plate.py b/tests/contrib/funsor/test_valid_models_plate.py
index ed20ee4be4..f5d30fc1b7 100644
--- a/tests/contrib/funsor/test_valid_models_plate.py
+++ b/tests/contrib/funsor/test_valid_models_plate.py
@@ -12,9 +12,10 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
-    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro
+
+    import pyro.contrib.funsor
     from tests.contrib.funsor.test_valid_models_enum import assert_ok
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_valid_models_sequential_plate.py b/tests/contrib/funsor/test_valid_models_sequential_plate.py
index 1de6af5b08..40eeb79cb6 100644
--- a/tests/contrib/funsor/test_valid_models_sequential_plate.py
+++ b/tests/contrib/funsor/test_valid_models_sequential_plate.py
@@ -11,9 +11,10 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
-    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro
+
+    import pyro.contrib.funsor
     from tests.contrib.funsor.test_valid_models_enum import assert_ok
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py
index ac34469f26..d1a5008a10 100644
--- a/tests/contrib/funsor/test_vectorized_markov.py
+++ b/tests/contrib/funsor/test_vectorized_markov.py
@@ -3,7 +3,6 @@
 import pytest
 import torch
-
 from torch.distributions import constraints
 from pyro.ops.indexing import Vindex
@@ -12,10 +11,11 @@
 try:
     import funsor
     from funsor.testing import assert_close
-    import pyro.contrib.funsor
     from pyroapi import distributions as dist
+
+    import pyro.contrib.funsor
     funsor.set_backend("torch")
-    from pyroapi import handlers, pyro, pyro_backend, infer
+    from pyroapi import handlers, infer, pyro, pyro_backend
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
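The funsor test modules above all share one guard: every funsor-dependent import happens in a single try block at collection time, and if any import fails the whole module is skipped rather than erroring. The pattern, assembled from the hunks above:

    import pytest

    try:
        import funsor
        from pyroapi import distributions as dist
        from pyroapi import infer, pyro

        import pyro.contrib.funsor
        funsor.set_backend("torch")
    except ImportError:
        # skip every test in this module when the optional backend is absent
        pytestmark = pytest.mark.skip(reason="funsor is not installed")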
diff --git a/tests/contrib/gp/test_kernels.py b/tests/contrib/gp/test_kernels.py
index cc9797ff2c..db1c803786 100644
--- a/tests/contrib/gp/test_kernels.py
+++ b/tests/contrib/gp/test_kernels.py
@@ -6,9 +6,8 @@
 import pytest
 import torch
-from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent,
-                                     Exponential, Linear, Matern32, Matern52, Periodic,
-                                     Polynomial, Product, RationalQuadratic, Sum,
+from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, Exponential, Linear,
+                                     Matern32, Matern52, Periodic, Polynomial, Product, RationalQuadratic, Sum,
                                      VerticalScaling, Warping, WhiteNoise)
 from tests.common import assert_equal
diff --git a/tests/contrib/gp/test_likelihoods.py b/tests/contrib/gp/test_likelihoods.py
index 71fbe663ad..c63c1ce9a5 100644
--- a/tests/contrib/gp/test_likelihoods.py
+++ b/tests/contrib/gp/test_likelihoods.py
@@ -11,7 +11,6 @@
 from pyro.contrib.gp.models import VariationalGP, VariationalSparseGP
 from pyro.contrib.gp.util import train
-
 T = namedtuple("TestGPLikelihood", ["model_class", "X", "y", "kernel", "likelihood"])
 X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0], [3.0, 4.0, 6.0]])
diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py
index d5afa24ec2..711089025d 100644
--- a/tests/contrib/gp/test_models.py
+++ b/tests/contrib/gp/test_models.py
@@ -8,13 +8,12 @@
 import torch
 import pyro.distributions as dist
-from pyro.contrib.gp.kernels import Cosine, Matern32, RBF, WhiteNoise
+from pyro.contrib.gp.kernels import RBF, Cosine, Matern32, WhiteNoise
 from pyro.contrib.gp.likelihoods import Gaussian
-from pyro.contrib.gp.models import (GPLVM, GPRegression, SparseGPRegression,
-                                    VariationalGP, VariationalSparseGP)
+from pyro.contrib.gp.models import GPLVM, GPRegression, SparseGPRegression, VariationalGP, VariationalSparseGP
 from pyro.contrib.gp.util import train
-from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.api import MCMC
+from pyro.infer.mcmc.hmc import HMC
 from pyro.nn.module import PyroSample
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py
index 57e8e4e7da..aa8df92188 100644
--- a/tests/contrib/oed/test_ewma.py
+++ b/tests/contrib/oed/test_ewma.py
@@ -3,9 +3,9 @@
 import math
+import pytest
 import torch
-import pytest
 from pyro.contrib.oed.eig import EwmaLog
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_finite_spaces_eig.py b/tests/contrib/oed/test_finite_spaces_eig.py
index 49fc02493a..b6f69234d4 100644
--- a/tests/contrib/oed/test_finite_spaces_eig.py
+++ b/tests/contrib/oed/test_finite_spaces_eig.py
@@ -1,17 +1,15 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-import torch
 import pytest
+import torch
 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.contrib.oed.eig import (
-    nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, lfire_eig,
-    donsker_varadhan_eig)
+from pyro.contrib.oed.eig import (donsker_varadhan_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, nmc_eig,
+                                  posterior_eig, vnmc_eig)
 from pyro.contrib.util import iter_plates_to_shape
-
 from tests.common import assert_equal
 try:
diff --git a/tests/contrib/oed/test_glmm.py b/tests/contrib/oed/test_glmm.py
index 6e855525dd..cb3e95d169 100644
--- a/tests/contrib/oed/test_glmm.py
+++ b/tests/contrib/oed/test_glmm.py
@@ -8,10 +8,8 @@
 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
-from pyro.contrib.oed.glmm import (
-    known_covariance_linear_model, group_linear_model, zero_mean_unit_obs_sd_lm,
-    normal_inverse_gamma_linear_model, logistic_regression_model, sigmoid_model
-)
+from pyro.contrib.oed.glmm import (group_linear_model, known_covariance_linear_model, logistic_regression_model,
+                                   normal_inverse_gamma_linear_model, sigmoid_model, zero_mean_unit_obs_sd_lm)
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_linear_models_eig.py b/tests/contrib/oed/test_linear_models_eig.py
index 30280cb602..f84ba916e5 100644
--- a/tests/contrib/oed/test_linear_models_eig.py
+++ b/tests/contrib/oed/test_linear_models_eig.py
@@ -1,20 +1,19 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
-import torch
 import pytest
+import torch
 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.infer import Trace_ELBO
+from pyro.contrib.oed.eig import (donsker_varadhan_eig, laplace_eig, lfire_eig, marginal_eig, marginal_likelihood_eig,
+                                  nmc_eig, posterior_eig, vnmc_eig)
 from pyro.contrib.oed.glmm import known_covariance_linear_model
+from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide
 from pyro.contrib.oed.util import linear_model_ground_truth
-from pyro.contrib.oed.eig import (
-    nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, laplace_eig, lfire_eig,
-    donsker_varadhan_eig)
 from pyro.contrib.util import rmv, rvv
-from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide
+from pyro.infer import Trace_ELBO
 from tests.common import assert_equal
diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py
index 5a1a43c194..4c392bc997 100644
--- a/tests/contrib/randomvariable/test_random_variable.py
+++ b/tests/contrib/randomvariable/test_random_variable.py
@@ -4,6 +4,7 @@
 import math
 import torch.tensor as tt
+
 from pyro.distributions import Uniform
 N_SAMPLES = 100
diff --git a/tests/contrib/test_util.py b/tests/contrib/test_util.py
index 442ca61bec..60a3115dad 100644
--- a/tests/contrib/test_util.py
+++ b/tests/contrib/test_util.py
@@ -2,12 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0
 from collections import OrderedDict
+
 import pytest
 import torch
-from pyro.contrib.util import (
-    get_indices, tensor_to_dict, rmv, rvv, lexpand, rexpand, rdiag, rtril
-)
+from pyro.contrib.util import get_indices, lexpand, rdiag, rexpand, rmv, rtril, rvv, tensor_to_dict
 from tests.common import assert_equal
diff --git a/tests/contrib/timeseries/test_gp.py b/tests/contrib/timeseries/test_gp.py
index 2698faa01b..e2e39a0aba 100644
--- a/tests/contrib/timeseries/test_gp.py
+++ b/tests/contrib/timeseries/test_gp.py
@@ -2,14 +2,15 @@
 # SPDX-License-Identifier: Apache-2.0
 import math
+
+import pytest
 import torch
-from tests.common import assert_equal
 import pyro
-from pyro.contrib.timeseries import (IndependentMaternGP, LinearlyCoupledMaternGP, GenericLGSSM,
-                                     GenericLGSSMWithGPNoiseModel, DependentMaternGP)
+from pyro.contrib.timeseries import (DependentMaternGP, GenericLGSSM, GenericLGSSMWithGPNoiseModel, IndependentMaternGP,
+                                     LinearlyCoupledMaternGP)
 from pyro.ops.tensor_utils import block_diag_embed
-import pytest
+from tests.common import assert_equal
 @pytest.mark.parametrize('model,obs_dim,nu_statedim', [('ssmgp', 3, 1.5), ('ssmgp', 2, 2.5),
diff --git a/tests/contrib/timeseries/test_lgssm.py b/tests/contrib/timeseries/test_lgssm.py
index f5c2dac137..5b5ed9d339 100644
--- a/tests/contrib/timeseries/test_lgssm.py
+++ b/tests/contrib/timeseries/test_lgssm.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+import pytest
 import torch
-from tests.common import assert_equal
 from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel
-import pytest
+from tests.common import assert_equal
 @pytest.mark.parametrize('model_class', ['lgssm', 'lgssmgp'])
diff --git a/tests/contrib/tracking/test_assignment.py b/tests/contrib/tracking/test_assignment.py
index 554a373eb3..9c425dd502 100644
--- a/tests/contrib/tracking/test_assignment.py
+++ b/tests/contrib/tracking/test_assignment.py
@@ -1,12 +1,12 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+import logging
+
 import pytest
 import torch
 from torch.autograd import grad
-import logging
-
 import pyro
 import pyro.distributions as dist
 from pyro.contrib.tracking.assignment import MarginalAssignment, MarginalAssignmentPersistent, MarginalAssignmentSparse
diff --git a/tests/contrib/tracking/test_distributions.py b/tests/contrib/tracking/test_distributions.py
index 4c589ac221..fe4c149b49 100644
--- a/tests/contrib/tracking/test_distributions.py
+++ b/tests/contrib/tracking/test_distributions.py
@@ -1,13 +1,12 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
+import pytest
 import torch
 from pyro.contrib.tracking.distributions import EKFDistribution
 from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous
-import pytest
-
 @pytest.mark.parametrize('Model', [NcpContinuous, NcvContinuous])
 @pytest.mark.parametrize('dim', [2, 3])
diff --git a/tests/contrib/tracking/test_dynamic_models.py b/tests/contrib/tracking/test_dynamic_models.py
index 51df52e75d..4f93afe523 100644
--- a/tests/contrib/tracking/test_dynamic_models.py
+++ b/tests/contrib/tracking/test_dynamic_models.py
@@ -3,8 +3,7 @@
 import torch
-from pyro.contrib.tracking.dynamic_models import (NcpContinuous, NcvContinuous,
-                                                  NcvDiscrete, NcpDiscrete)
+from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcpDiscrete, NcvContinuous, NcvDiscrete
 from tests.common import assert_equal, assert_not_equal
diff --git a/tests/contrib/tracking/test_ekf.py b/tests/contrib/tracking/test_ekf.py
index 99cec4488c..35db1544d1 100644
--- a/tests/contrib/tracking/test_ekf.py
+++ b/tests/contrib/tracking/test_ekf.py
@@ -3,10 +3,9 @@
 import torch
-from pyro.contrib.tracking.extended_kalman_filter import EKFState
 from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous
+from pyro.contrib.tracking.extended_kalman_filter import EKFState
 from pyro.contrib.tracking.measurements import PositionMeasurement
-
 from tests.common import assert_equal, assert_not_equal
diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py
index 1d0fca7147..c3401f4114 100644
--- a/tests/contrib/tracking/test_em.py
+++ b/tests/contrib/tracking/test_em.py
@@ -16,7 +16,6 @@
 from pyro.optim import Adam
 from pyro.optim.multi import MixedMultiOptimizer, Newton
-
 logger = logging.getLogger(__name__)
diff --git a/tests/contrib/tracking/test_measurements.py b/tests/contrib/tracking/test_measurements.py
index 38f2afcd3d..373cad0e79 100644
--- a/tests/contrib/tracking/test_measurements.py
+++ b/tests/contrib/tracking/test_measurements.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
+
 from pyro.contrib.tracking.measurements import PositionMeasurement
diff --git a/tests/distributions/test_empirical.py b/tests/distributions/test_empirical.py
index 7d220aa95e..3f2d4435dd 100644
--- a/tests/distributions/test_empirical.py
+++ b/tests/distributions/test_empirical.py
@@ -5,7 +5,7 @@
 import torch
 from pyro.distributions.empirical import Empirical
-from tests.common import assert_equal, assert_close
+from tests.common import assert_close, assert_equal
 @pytest.mark.parametrize("size", [[], [1], [2, 3]])
diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py
index f02426696e..03737ecf63 100644
--- a/tests/distributions/test_gaussian_mixtures.py
+++ b/tests/distributions/test_gaussian_mixtures.py
@@ -4,14 +4,12 @@
 import logging
 import math
+import pytest
 import torch
-import pytest
-from pyro.distributions import MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture
-from pyro.distributions import MixtureOfDiagNormals
+from pyro.distributions import GaussianScaleMixture, MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance
 from tests.common import assert_equal
-
 logger = logging.getLogger(__name__)
diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py
index 63f3daea57..53857c2791 100644
--- a/tests/distributions/test_haar.py
+++ b/tests/distributions/test_haar.py
@@ -1,9 +1,9 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
+import pytest
 import torch
-import pytest
 from pyro.distributions.transforms import HaarTransform
 from tests.common import assert_equal
diff --git a/tests/distributions/test_ig.py b/tests/distributions/test_ig.py
index 215d00ed36..5091e02ad7 100644
--- a/tests/distributions/test_ig.py
+++ b/tests/distributions/test_ig.py
@@ -3,9 +3,9 @@
 import math
+import pytest
 import torch
-import pytest
 from pyro.distributions import Gamma, InverseGamma
 from tests.common import assert_equal
diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py
index e71336b2af..27cfdc4910 100644
--- a/tests/distributions/test_mask.py
+++ b/tests/distributions/test_mask.py
@@ -6,9 +6,8 @@
 from torch import tensor
 from torch.distributions import kl_divergence
-from pyro.distributions.util import broadcast_shape
 from pyro.distributions.torch import Bernoulli, Normal
-from pyro.distributions.util import scale_and_mask
+from pyro.distributions.util import broadcast_shape, scale_and_mask
 from tests.common import assert_equal
diff --git a/tests/distributions/test_mvt.py b/tests/distributions/test_mvt.py
index a61cb1b3f8..ab2dec09ad 100644
--- a/tests/distributions/test_mvt.py
+++ b/tests/distributions/test_mvt.py
@@ -4,7 +4,6 @@
 import math
 import pytest
-
 import torch
 from torch.distributions import Gamma, MultivariateNormal, StudentT
diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py
index f1d92bbb0b..eb04d455fb 100644
--- a/tests/distributions/test_omt_mvn.py
+++ b/tests/distributions/test_omt_mvn.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import numpy as np
+import pytest
 import torch
-import pytest
 from pyro.distributions import AVFMultivariateNormal, MultivariateNormal, OMTMultivariateNormal
 from tests.common import assert_equal
diff --git a/tests/distributions/test_ordered_logistic.py b/tests/distributions/test_ordered_logistic.py
index 6c6c3ae409..715db994fb 100644
--- a/tests/distributions/test_ordered_logistic.py
+++ b/tests/distributions/test_ordered_logistic.py
@@ -6,10 +6,9 @@
 import torch.tensor as tt
 from torch.autograd.functional import jacobian
-from pyro.distributions import OrderedLogistic, Normal
+from pyro.distributions import Normal, OrderedLogistic
 from pyro.distributions.transforms import OrderedTransform
-
 # Tests for the OrderedLogistic distribution
diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py
index fabb71b451..66f881bbc5 100644
--- a/tests/distributions/test_pickle.py
+++ b/tests/distributions/test_pickle.py
@@ -3,9 +3,9 @@
 import inspect
 import io
+import pickle
 import pytest
-import pickle
 import torch
 import pyro.distributions as dist
diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py
index 00eaf424a5..7a9bf18f53 100644
--- a/tests/distributions/test_transforms.py
+++ b/tests/distributions/test_transforms.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0 +import operator +from functools import partial, reduce from unittest import TestCase import pytest @@ -9,9 +11,6 @@ import pyro.distributions as dist import pyro.distributions.transforms as T -from functools import partial, reduce -import operator - pytestmark = pytest.mark.init(rng_seed=123) diff --git a/tests/doctest_fixtures.py b/tests/doctest_fixtures.py index 8be64b2948..0d4e785d84 100644 --- a/tests/doctest_fixtures.py +++ b/tests/doctest_fixtures.py @@ -6,16 +6,15 @@ import torch import pyro -import pyro.contrib.gp as gp import pyro.contrib.autoname.named as named +import pyro.contrib.gp as gp import pyro.distributions as dist import pyro.poutine as poutine from pyro.infer import EmpiricalMarginal -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc import HMC, NUTS +from pyro.infer.mcmc.api import MCMC from pyro.params import param_with_module_name - # Fix seed for all doctest runs. pyro.set_rng_seed(0) diff --git a/tests/infer/mcmc/test_adaptation.py b/tests/infer/mcmc/test_adaptation.py index 2fad237d90..675e43525d 100644 --- a/tests/infer/mcmc/test_adaptation.py +++ b/tests/infer/mcmc/test_adaptation.py @@ -4,12 +4,7 @@ import pytest import torch -from pyro.infer.mcmc.adaptation import ( - ArrowheadMassMatrix, - BlockMassMatrix, - WarmupAdapter, - adapt_window, -) +from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix, WarmupAdapter, adapt_window from tests.common import assert_close, assert_equal diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 58bbf0a76e..2f2f9f967d 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -11,9 +11,9 @@ import pyro import pyro.distributions as dist from pyro.infer.mcmc import NUTS -from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.api import MCMC -from tests.common import assert_equal, assert_close +from pyro.infer.mcmc.hmc import HMC +from tests.common import assert_close, assert_equal logger = logging.getLogger(__name__) diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py index a577da9d40..cb203fcea8 100644 --- a/tests/infer/mcmc/test_mcmc_api.py +++ b/tests/infer/mcmc/test_mcmc_api.py @@ -11,7 +11,7 @@ import pyro.distributions as dist from pyro import poutine from pyro.infer.mcmc import HMC, NUTS -from pyro.infer.mcmc.api import MCMC, _UnarySampler, _MultiSampler +from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model from pyro.util import optional diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 43630bcb16..106a4510e3 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -10,12 +10,12 @@ import pyro import pyro.distributions as dist -from pyro.infer.autoguide import AutoDelta -from pyro.contrib.conjugate.infer import BetaBinomialPair, collapse_conjugate, GammaPoissonPair, posterior_replay -from pyro.infer import TraceEnum_ELBO, SVI -from pyro.infer.mcmc import ArrowheadMassMatrix, MCMC, NUTS import pyro.optim as optim import pyro.poutine as poutine +from pyro.contrib.conjugate.infer import BetaBinomialPair, GammaPoissonPair, collapse_conjugate, posterior_replay +from pyro.infer import SVI, TraceEnum_ELBO +from pyro.infer.autoguide import AutoDelta +from pyro.infer.mcmc import MCMC, NUTS, ArrowheadMassMatrix from pyro.util import ignore_jit_warnings from tests.common import assert_close, assert_equal diff --git 
a/tests/infer/test_abstract_infer.py b/tests/infer/test_abstract_infer.py index 483bc4e854..bfacd142a2 100644 --- a/tests/infer/test_abstract_infer.py +++ b/tests/infer/test_abstract_infer.py @@ -8,12 +8,11 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer.autoguide import AutoLaplaceApproximation from pyro.infer import SVI, Trace_ELBO +from pyro.infer.autoguide import AutoLaplaceApproximation from pyro.infer.mcmc import MCMC, NUTS from tests.common import assert_equal - pytestmark = pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 01d431f169..4c5caaf6fb 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -15,11 +15,10 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine - -from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, Predictive +from pyro.infer import SVI, Predictive, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO from pyro.infer.autoguide import (AutoCallable, AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, AutoGuide, AutoGuideList, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, - AutoNormal, AutoMultivariateNormal, init_to_feasible, init_to_mean, init_to_median, + AutoMultivariateNormal, AutoNormal, init_to_feasible, init_to_mean, init_to_median, init_to_sample) from pyro.nn.module import PyroModule, PyroParam, PyroSample from pyro.optim import Adam diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 61f32b43f6..6b6ab5e115 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -8,8 +8,8 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine +from pyro.infer import SVI, Predictive, Trace_ELBO from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal -from pyro.infer import Predictive, SVI, Trace_ELBO from tests.common import assert_close diff --git a/tests/infer/test_svgd.py b/tests/infer/test_svgd.py index 2d10b53b55..c6944dedc2 100644 --- a/tests/infer/test_svgd.py +++ b/tests/infer/test_svgd.py @@ -6,11 +6,9 @@ import pyro import pyro.distributions as dist - -from pyro.infer import SVGD, RBFSteinKernel, IMQSteinKernel -from pyro.optim import Adam +from pyro.infer import SVGD, IMQSteinKernel, RBFSteinKernel from pyro.infer.autoguide.utils import _product - +from pyro.optim import Adam from tests.common import assert_equal diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py index cf55ed02ce..35667e56ef 100644 --- a/tests/infer/test_tmc.py +++ b/tests/infer/test_tmc.py @@ -15,11 +15,10 @@ from pyro.distributions.testing import fakes from pyro.infer import config_enumerate from pyro.infer.importance import vectorized_importance_weights -from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.traceenum_elbo import TraceEnum_ELBO +from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from tests.common import assert_equal - logger = logging.getLogger(__name__) diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py index 236b460050..7fda7a399f 100644 --- a/tests/infer/test_util.py +++ b/tests/infer/test_util.py @@ -3,12 +3,12 @@ import math +import pytest import torch import pyro import pyro.distributions as dist import pyro.poutine as poutine -import pytest from pyro.infer.importance import psis_diagnostic from pyro.infer.util import MultiFrameTensor from tests.common import assert_equal diff 
--git a/tests/ops/test_arrowhead.py b/tests/ops/test_arrowhead.py index 13feae5697..2ffa76bf78 100644 --- a/tests/ops/test_arrowhead.py +++ b/tests/ops/test_arrowhead.py @@ -5,7 +5,6 @@ import torch from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul - from tests.common import assert_close diff --git a/tests/ops/test_gamma_gaussian.py b/tests/ops/test_gamma_gaussian.py index 872a42e531..74c018bcc5 100644 --- a/tests/ops/test_gamma_gaussian.py +++ b/tests/ops/test_gamma_gaussian.py @@ -9,12 +9,8 @@ import pyro.distributions as dist from pyro.distributions.util import broadcast_shape -from pyro.ops.gamma_gaussian import ( - GammaGaussian, - gamma_gaussian_tensordot, - matrix_and_mvn_to_gamma_gaussian, - gamma_and_mvn_to_gamma_gaussian, -) +from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot, + matrix_and_mvn_to_gamma_gaussian) from tests.common import assert_close from tests.ops.gamma_gaussian import assert_close_gamma_gaussian, random_gamma, random_gamma_gaussian from tests.ops.gaussian import random_mvn diff --git a/tests/ops/test_newton.py b/tests/ops/test_newton.py index d502cde5d7..d264b3ae35 100644 --- a/tests/ops/test_newton.py +++ b/tests/ops/test_newton.py @@ -11,7 +11,6 @@ from pyro.ops.newton import newton_step from tests.common import assert_equal - logger = logging.getLogger(__name__) diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 53fb0213fb..32026d77a2 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -16,8 +16,8 @@ import pyro.optim as optim from pyro.distributions.testing import fakes from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO -from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS Model = namedtuple('TestModel', ['model', 'model_args', 'model_id']) diff --git a/tests/poutine/test_nesting.py b/tests/poutine/test_nesting.py index 6fd6f3614d..ede0456c32 100644 --- a/tests/poutine/test_nesting.py +++ b/tests/poutine/test_nesting.py @@ -4,11 +4,10 @@ import logging import pyro -import pyro.poutine as poutine import pyro.distributions as dist +import pyro.poutine as poutine import pyro.poutine.runtime - logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index f2f4eee025..99fbfb6336 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -6,12 +6,12 @@ import logging import pickle import warnings +from queue import Queue from unittest import TestCase import pytest import torch import torch.nn as nn -from queue import Queue import pyro import pyro.distributions as dist @@ -19,7 +19,7 @@ from pyro.distributions import Bernoulli, Categorical, Normal from pyro.poutine.runtime import _DIM_ALLOCATOR, NonlocalExit from pyro.poutine.util import all_escape, discrete_escape -from tests.common import assert_equal, assert_not_equal, assert_close +from tests.common import assert_close, assert_equal, assert_not_equal logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_trace_struct.py b/tests/poutine/test_trace_struct.py index 9ad7d351a6..4511ccbdf3 100644 --- a/tests/poutine/test_trace_struct.py +++ b/tests/poutine/test_trace_struct.py @@ -8,7 +8,6 @@ from pyro.poutine import Trace from tests.common import assert_equal - EDGE_SETS = [ # 1 # / \ diff --git a/tests/pyroapi/test_pyroapi.py b/tests/pyroapi/test_pyroapi.py index 
1fa1673b9f..271c38efab 100644 --- a/tests/pyroapi/test_pyroapi.py +++ b/tests/pyroapi/test_pyroapi.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import pytest - from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 diff --git a/tests/test_generic.py b/tests/test_generic.py index a3324b27c0..1ca5c77588 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest - -from pyro.generic import handlers, infer, pyro, pyro_backend, ops from pyroapi.testing import MODELS + +from pyro.generic import handlers, infer, ops, pyro, pyro_backend from tests.common import xfail_if_not_implemented pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_primitives.py b/tests/test_primitives.py index d285ad69b5..22f331a450 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch + import pyro import pyro.distributions as dist -import torch pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_util.py b/tests/test_util.py index f8b382b4ec..09ec92f4f7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -import pytest +import pytest import torch + from pyro import util pytestmark = pytest.mark.stage('unit') From cca23e122a426ca6d6701f615a325019c604d0b4 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 14:31:20 -0500 Subject: [PATCH 32/91] Revert "Add FactorMuE to test_examples.py" This reverts commit f9823166b79a7d8ed868b759fbcf0790ac2d238a. --- tests/test_examples.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 99cb2370df..b0d0cb96c8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,8 +53,6 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/mue/FactorMuE.py --test --small', - 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', From 9132ae51bacd61a0ef25a3ef51e4be1f30c5c00d Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 14:34:07 -0500 Subject: [PATCH 33/91] Revert "Files autochanged by make format." This reverts commit 42e97bcf1b5381fbdb07f19c813533d586572384. 
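The autochanged files had been rewritten to what is apparently isort's default section ordering: standard-library imports first, third-party packages second, first-party pyro/tests modules last, with each group alphabetized and separated by a blank line, and long from-imports wrapped in parentheses. A minimal sketch of that layout, on a hypothetical module mix (none of these lines come from the files below):

    # Standard library: first group, alphabetized.
    import logging
    import math

    # Third-party packages: second group (alphabetical, so pytest precedes torch).
    import pytest
    import torch

    # First-party modules: last group.
    import pyro
    import pyro.distributions as dist
    from tests.common import assert_equal

Reverting restores the previous hand-ordered import blocks in the files listed below.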
--- docs/source/conf.py | 1 + docs/source/contrib.mue.rst | 9 +--- examples/air/air.py | 2 +- examples/air/main.py | 4 +- examples/capture_recapture/cjs.py | 3 +- examples/contrib/autoname/scoping_mixture.py | 7 +-- examples/contrib/funsor/hmm.py | 1 + examples/contrib/gp/sv-dkl.py | 2 +- examples/contrib/oed/ab_test.py | 14 ++--- examples/contrib/oed/gp_bayes_opt.py | 2 +- examples/contrib/timeseries/gp_models.py | 8 +-- examples/cvae/baseline.py | 5 +- examples/cvae/cvae.py | 10 ++-- examples/cvae/main.py | 10 ++-- examples/cvae/mnist.py | 2 +- examples/cvae/util.py | 12 ++--- examples/eight_schools/data.py | 1 + examples/eight_schools/mcmc.py | 2 +- examples/eight_schools/svi.py | 2 +- examples/hmm.py | 2 +- examples/lkj.py | 3 +- examples/minipyro.py | 2 +- examples/mixed_hmm/experiment.py | 9 ++-- examples/mixed_hmm/seal_data.py | 2 + examples/rsa/generics.py | 9 ++-- examples/rsa/hyperbole.py | 9 ++-- examples/rsa/schelling.py | 3 +- examples/rsa/schelling_false.py | 3 +- examples/rsa/search_inference.py | 4 +- examples/rsa/semantic_parsing.py | 7 +-- examples/scanvi/data.py | 6 +-- examples/scanvi/scanvi.py | 14 ++--- examples/sparse_gamma_def.py | 11 ++-- examples/sparse_regression.py | 9 ++-- examples/vae/ss_vae_M2.py | 6 +-- examples/vae/vae.py | 6 +-- examples/vae/vae_comparison.py | 2 +- profiler/hmm.py | 3 +- profiler/profiling_utils.py | 2 +- pyro/contrib/__init__.py | 1 - pyro/contrib/autoname/__init__.py | 3 +- pyro/contrib/bnn/hidden_layer.py | 2 +- pyro/contrib/bnn/utils.py | 3 +- pyro/contrib/conjugate/infer.py | 2 +- pyro/contrib/easyguide/__init__.py | 1 + pyro/contrib/easyguide/easyguide.py | 2 +- pyro/contrib/examples/bart.py | 2 +- pyro/contrib/examples/finance.py | 2 +- .../examples/polyphonic_data_loader.py | 3 +- pyro/contrib/forecast/util.py | 2 +- pyro/contrib/funsor/__init__.py | 12 +++-- pyro/contrib/funsor/handlers/__init__.py | 8 ++- .../contrib/funsor/handlers/enum_messenger.py | 9 ++-- .../funsor/handlers/named_messenger.py | 3 +- .../funsor/handlers/plate_messenger.py | 8 +-- pyro/contrib/funsor/handlers/primitives.py | 1 + .../funsor/handlers/replay_messenger.py | 2 +- .../funsor/handlers/trace_messenger.py | 5 +- pyro/contrib/funsor/infer/__init__.py | 2 +- pyro/contrib/funsor/infer/trace_elbo.py | 7 +-- pyro/contrib/funsor/infer/traceenum_elbo.py | 5 +- pyro/contrib/funsor/infer/tracetmc_elbo.py | 6 ++- pyro/contrib/gp/kernels/__init__.py | 7 +-- pyro/contrib/gp/likelihoods/binary.py | 1 + pyro/contrib/gp/likelihoods/gaussian.py | 1 + pyro/contrib/gp/likelihoods/multi_class.py | 1 + pyro/contrib/gp/likelihoods/poisson.py | 1 + pyro/contrib/oed/__init__.py | 2 +- pyro/contrib/oed/eig.py | 9 ++-- pyro/contrib/oed/glmm/__init__.py | 2 +- pyro/contrib/oed/glmm/glmm.py | 6 +-- pyro/contrib/oed/glmm/guides.py | 8 +-- pyro/contrib/oed/search.py | 3 +- pyro/contrib/oed/util.py | 3 +- .../contrib/randomvariable/random_variable.py | 13 +++-- pyro/contrib/timeseries/__init__.py | 2 +- pyro/contrib/tracking/distributions.py | 2 +- pyro/contrib/tracking/dynamic_models.py | 1 - pyro/contrib/tracking/measurements.py | 1 - pyro/contrib/util.py | 2 - pyro/distributions/__init__.py | 51 +++++++++++++++---- pyro/distributions/ordered_logistic.py | 1 - pyro/distributions/spanning_tree.py | 1 - pyro/distributions/transforms/__init__.py | 5 +- pyro/distributions/transforms/ordered.py | 3 +- pyro/infer/__init__.py | 2 +- pyro/infer/abstract_infer.py | 2 +- pyro/infer/autoguide/initialization.py | 1 + pyro/infer/mcmc/__init__.py | 2 +- pyro/infer/mcmc/adaptation.py | 2 +- 
pyro/infer/mcmc/hmc.py | 3 +- pyro/infer/mcmc/util.py | 4 +- pyro/infer/predictive.py | 2 +- pyro/infer/reparam/neutra.py | 1 - pyro/infer/svgd.py | 6 +-- pyro/infer/trace_mean_field_elbo.py | 2 +- pyro/infer/trace_mmd.py | 2 +- pyro/infer/trace_tail_adaptive_elbo.py | 2 +- pyro/infer/traceenum_elbo.py | 2 +- pyro/infer/tracegraph_elbo.py | 3 +- pyro/infer/tracetmc_elbo.py | 1 + pyro/logger.py | 1 + pyro/ops/arrowhead.py | 1 + pyro/ops/einsum/torch_map.py | 1 + pyro/ops/einsum/torch_sample.py | 1 + pyro/ops/newton.py | 2 +- pyro/ops/ssm_gp.py | 2 +- pyro/ops/stats.py | 2 +- pyro/poutine/broadcast_messenger.py | 1 - pyro/poutine/indep_messenger.py | 1 - pyro/poutine/trace_struct.py | 2 +- tests/__init__.py | 1 + tests/conftest.py | 1 + tests/contrib/autoguide/test_inference.py | 4 +- .../autoguide/test_mean_field_entropy.py | 4 +- tests/contrib/autoname/test_scoping.py | 2 +- tests/contrib/bnn/test_hidden_layer.py | 2 +- tests/contrib/epidemiology/test_quant.py | 1 + tests/contrib/funsor/test_enum_funsor.py | 4 +- tests/contrib/funsor/test_named_handlers.py | 3 +- tests/contrib/funsor/test_pyroapi_funsor.py | 2 +- tests/contrib/funsor/test_tmc.py | 3 +- .../contrib/funsor/test_valid_models_enum.py | 5 +- .../contrib/funsor/test_valid_models_plate.py | 3 +- .../test_valid_models_sequential_plate.py | 3 +- .../contrib/funsor/test_vectorized_markov.py | 6 +-- tests/contrib/gp/test_kernels.py | 5 +- tests/contrib/gp/test_likelihoods.py | 1 + tests/contrib/gp/test_models.py | 7 +-- tests/contrib/oed/test_ewma.py | 2 +- tests/contrib/oed/test_finite_spaces_eig.py | 8 +-- tests/contrib/oed/test_glmm.py | 6 ++- tests/contrib/oed/test_linear_models_eig.py | 11 ++-- .../randomvariable/test_random_variable.py | 1 - tests/contrib/test_util.py | 5 +- tests/contrib/timeseries/test_gp.py | 9 ++-- tests/contrib/timeseries/test_lgssm.py | 4 +- tests/contrib/tracking/test_assignment.py | 4 +- tests/contrib/tracking/test_distributions.py | 3 +- tests/contrib/tracking/test_dynamic_models.py | 3 +- tests/contrib/tracking/test_ekf.py | 3 +- tests/contrib/tracking/test_em.py | 1 + tests/contrib/tracking/test_measurements.py | 1 - tests/distributions/test_empirical.py | 2 +- tests/distributions/test_gaussian_mixtures.py | 6 ++- tests/distributions/test_haar.py | 2 +- tests/distributions/test_ig.py | 2 +- tests/distributions/test_mask.py | 3 +- tests/distributions/test_mvt.py | 1 + tests/distributions/test_omt_mvn.py | 2 +- tests/distributions/test_ordered_logistic.py | 3 +- tests/distributions/test_pickle.py | 2 +- tests/distributions/test_transforms.py | 5 +- tests/doctest_fixtures.py | 5 +- tests/infer/mcmc/test_adaptation.py | 7 ++- tests/infer/mcmc/test_hmc.py | 4 +- tests/infer/mcmc/test_mcmc_api.py | 2 +- tests/infer/mcmc/test_nuts.py | 8 +-- tests/infer/test_abstract_infer.py | 3 +- tests/infer/test_autoguide.py | 5 +- tests/infer/test_predictive.py | 2 +- tests/infer/test_svgd.py | 6 ++- tests/infer/test_tmc.py | 3 +- tests/infer/test_util.py | 2 +- tests/ops/test_arrowhead.py | 1 + tests/ops/test_gamma_gaussian.py | 8 ++- tests/ops/test_newton.py | 1 + tests/perf/test_benchmark.py | 2 +- tests/poutine/test_nesting.py | 3 +- tests/poutine/test_poutines.py | 4 +- tests/poutine/test_trace_struct.py | 1 + tests/pyroapi/test_pyroapi.py | 1 + tests/test_generic.py | 4 +- tests/test_primitives.py | 3 +- tests/test_util.py | 3 +- 175 files changed, 399 insertions(+), 302 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b66e2c7f46..250af8e937 100644 --- a/docs/source/conf.py +++ 
b/docs/source/conf.py @@ -6,6 +6,7 @@ import sphinx_rtd_theme + # import pkg_resources # -*- coding: utf-8 -*- diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 9fd7e10254..335bcd8965 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -14,13 +14,6 @@ Reference: MuE models were described in Weinstein and Marks (2020), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. -Example MuE Models ------------------- -.. automodule:: pyro.contrib.mue.models - :members: - :show-inheritance: - :member-order: bysource - State Arrangers for Parameterizing MuEs --------------------------------------- .. automodule:: pyro.contrib.mue.statearrangers @@ -30,7 +23,7 @@ State Arrangers for Parameterizing MuEs Variable Length/Missing Data HMM -------------------------------- -.. automodule:: pyro.contrib.mue.missingdatahmm +.. automodule:: pyro.contrib.mue.variablelengthhmm :members: :show-inheritance: :member-order: bysource diff --git a/examples/air/air.py b/examples/air/air.py index 985e2853bf..a9b8958d0f 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -14,10 +14,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from modules import MLP, Decoder, Encoder, Identity, Predict import pyro import pyro.distributions as dist +from modules import MLP, Decoder, Encoder, Identity, Predict # Default prior success probability for z_pres. diff --git a/examples/air/main.py b/examples/air/main.py index 34be8303b6..8cf13b255e 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -19,15 +19,15 @@ import numpy as np import torch import visdom -from air import AIR, latents_to_tensor -from viz import draw_many, tensor_to_objs import pyro import pyro.contrib.examples.multi_mnist as multi_mnist import pyro.optim as optim import pyro.poutine as poutine +from air import AIR, latents_to_tensor from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO +from viz import draw_many, tensor_to_objs def count_accuracy(X, true_counts, air, batch_size): diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index fa868899d5..65b709afca 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -39,10 +39,11 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO from pyro.infer.autoguide import AutoDiagonalNormal +from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO from pyro.optim import Adam + """ Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival probability phi; diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 1d4adb8ec8..842e9f03c0 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -2,15 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - import torch from torch.distributions import constraints import pyro -import pyro.distributions as dist import pyro.optim +import pyro.distributions as dist + +from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO + from pyro.contrib.autoname import scope -from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate def model(K, data): diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 32f7331e77..5c860d1470 100644 --- a/examples/contrib/funsor/hmm.py 
+++ b/examples/contrib/funsor/hmm.py @@ -64,6 +64,7 @@ from pyroapi import distributions as dist from pyroapi import handlers, infer, optim, pyro, pyro_backend + logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.DEBUG) # Add another handler for logging debugging events (e.g. for profiling) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 616cc297d0..338ad6a19f 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -39,7 +39,7 @@ import pyro import pyro.contrib.gp as gp import pyro.infer as infer -from pyro.contrib.examples.util import get_data_directory, get_data_loader +from pyro.contrib.examples.util import get_data_loader, get_data_directory class CNN(nn.Module): diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index b713d28306..16842b7c05 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -3,18 +3,20 @@ import argparse from functools import partial - -import numpy as np import torch -from gp_bayes_opt import GPBayesOptimizer from torch.distributions import constraints +import numpy as np import pyro -import pyro.contrib.gp as gp from pyro import optim -from pyro.contrib.oed.eig import vi_eig -from pyro.contrib.oed.glmm import analytic_posterior_cov, group_assignment_matrix, zero_mean_unit_obs_sd_lm from pyro.infer import TraceEnum_ELBO +from pyro.contrib.oed.eig import vi_eig +import pyro.contrib.gp as gp +from pyro.contrib.oed.glmm import ( + zero_mean_unit_obs_sd_lm, group_assignment_matrix, analytic_posterior_cov +) + +from gp_bayes_opt import GPBayesOptimizer """ Example builds on the Bayesian regression tutorial [1]. It demonstrates how diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index 6132dee48a..3c114c9bcf 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -7,8 +7,8 @@ from torch.distributions import transform_to import pyro.contrib.gp as gp -import pyro.optim from pyro.infer import TraceEnum_ELBO +import pyro.optim class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer): diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index 2ca6e03bc4..81d2e1316e 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -1,16 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import argparse -from os.path import exists -from urllib.request import urlopen - import numpy as np import torch import pyro from pyro.contrib.timeseries import IndependentMaternGP, LinearlyCoupledMaternGP +import argparse +from os.path import exists +from urllib.request import urlopen + # download dataset from UCI archive def download_data(): diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index 23e1591016..cb5d279445 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from pathlib import Path - import numpy as np +from pathlib import Path +from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F -from tqdm import tqdm class BaselineNet(nn.Module): diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py index f499aaf452..fb792cf4d3 100644 --- a/examples/cvae/cvae.py +++ b/examples/cvae/cvae.py @@ -1,16 +1,14 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path - import numpy as np -import torch -import torch.nn as nn -from tqdm import tqdm - +from pathlib import Path import pyro import pyro.distributions as dist from pyro.infer import SVI, Trace_ELBO +import torch +import torch.nn as nn +from tqdm import tqdm class Encoder(nn.Module): diff --git a/examples/cvae/main.py b/examples/cvae/main.py index dea506db5a..224b4f05af 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -2,14 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - -import baseline -import cvae import pandas as pd -import torch -from util import generate_table, get_data, visualize - import pyro +import torch +import baseline +import cvae +from util import get_data, visualize, generate_table def main(args): diff --git a/examples/cvae/mnist.py b/examples/cvae/mnist.py index 12dd7409f2..a98c667081 100644 --- a/examples/cvae/mnist.py +++ b/examples/cvae/mnist.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset, DataLoader from torchvision.datasets import MNIST from torchvision.transforms import Compose, functional diff --git a/examples/cvae/util.py b/examples/cvae/util.py index 87650298ef..e578085946 100644 --- a/examples/cvae/util.py +++ b/examples/cvae/util.py @@ -1,19 +1,17 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path - -import matplotlib.pyplot as plt import numpy as np +import matplotlib.pyplot as plt import pandas as pd +from pathlib import Path +from pyro.infer import Predictive, Trace_ELBO import torch -from baseline import MaskedBCELoss -from mnist import get_data from torch.utils.data import DataLoader from torchvision.utils import make_grid from tqdm import tqdm - -from pyro.infer import Predictive, Trace_ELBO +from baseline import MaskedBCELoss +from mnist import get_data def imshow(inp, image_path=None): diff --git a/examples/eight_schools/data.py b/examples/eight_schools/data.py index 56158fa36e..39529e798a 100644 --- a/examples/eight_schools/data.py +++ b/examples/eight_schools/data.py @@ -3,6 +3,7 @@ import torch + J = 8 y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index ec6ff9b5c6..7b927d43e0 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -4,9 +4,9 @@ import argparse import logging -import data import torch +import data import pyro import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index 7f70044707..14e7d32c33 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -5,11 +5,11 @@ import logging import torch -from data import J, sigma, y from torch.distributions import constraints, transforms import pyro import pyro.distributions as dist +from data import J, sigma, y from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam diff --git a/examples/hmm.py b/examples/hmm.py index 417b9e171f..52086a5067 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -47,8 +47,8 @@ import pyro.contrib.examples.polyphonic_data_loader as poly import pyro.distributions as dist from pyro import poutine -from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO from pyro.infer.autoguide 
import AutoDelta +from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO from pyro.ops.indexing import Vindex from pyro.optim import Adam from pyro.util import ignore_jit_warnings diff --git a/examples/lkj.py b/examples/lkj.py index 2c8610ca23..56d26ab33d 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - import torch import pyro import pyro.distributions as dist -from pyro.infer.mcmc import NUTS from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc import NUTS """ This simple example is intended to demonstrate how to use an LKJ prior with diff --git a/examples/minipyro.py b/examples/minipyro.py index 9855b9694c..e12775dfda 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -11,8 +11,8 @@ import torch -# We use the pyro.generic interface to support dynamic choice of backend. from pyro.generic import distributions as dist +# We use the pyro.generic interface to support dynamic choice of backend. from pyro.generic import infer, ops, optim, pyro, pyro_backend diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index bd4bd02e6a..65584c6769 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -2,19 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import functools -import json import os +import json import uuid +import functools import torch -from model import guide_generic, model_generic -from seal_data import prepare_seal import pyro import pyro.poutine as poutine from pyro.infer import TraceEnum_ELBO +from model import model_generic, guide_generic +from seal_data import prepare_seal + def aic_num_parameters(model, guide=None): """ diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py index 609fc69da3..390201a8d9 100644 --- a/examples/mixed_hmm/seal_data.py +++ b/examples/mixed_hmm/seal_data.py @@ -5,8 +5,10 @@ from urllib.request import urlopen import pandas as pd + import torch + MISSING = 1e-6 diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 4fb059b748..1e617316c4 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -9,17 +9,18 @@ [1] https://gscontras.github.io/probLang/chapters/07-generics.html """ +import torch + import argparse -import collections import numbers - -import torch -from search_inference import HashingMarginal, Search, memoize +import collections import pyro import pyro.distributions as dist import pyro.poutine as poutine +from search_inference import HashingMarginal, memoize, Search + torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index 429715137d..04d878fa8b 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -7,16 +7,17 @@ Taken from: https://gscontras.github.io/probLang/chapters/03-nonliteral.html """ -import argparse -import collections - import torch -from search_inference import HashingMarginal, Search, memoize + +import collections +import argparse import pyro import pyro.distributions as dist import pyro.poutine as poutine +from search_inference import HashingMarginal, memoize, Search + torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index ce8bfd219e..ea31af2811 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -11,13 +11,12 @@ Taken from: 
http://forestdb.org/models/schelling.html """ import argparse - import torch -from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli +from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index 82ab6aedd0..4a5ffcdf98 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -12,13 +12,12 @@ Taken from: http://forestdb.org/models/schelling-falsebelief.html """ import argparse - import torch -from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli +from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/search_inference.py b/examples/rsa/search_inference.py index 7e2cb8e142..14a49766f1 100644 --- a/examples/rsa/search_inference.py +++ b/examples/rsa/search_inference.py @@ -8,10 +8,10 @@ """ import collections -import functools -import queue import torch +import queue +import functools import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index 11a39ef10a..0a998c6227 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -7,15 +7,16 @@ Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html """ +import torch + import argparse import collections -import torch -from search_inference import BestFirstSearch, HashingMarginal, memoize - import pyro import pyro.distributions as dist +from search_inference import HashingMarginal, BestFirstSearch, memoize + torch.set_default_dtype(torch.float64) diff --git a/examples/scanvi/data.py b/examples/scanvi/data.py index 429883d1a3..690eab0717 100644 --- a/examples/scanvi/data.py +++ b/examples/scanvi/data.py @@ -8,11 +8,11 @@ """ import math - import numpy as np +from scipy import sparse + import torch import torch.nn as nn -from scipy import sparse class BatchDataLoader(object): @@ -122,8 +122,8 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): return BatchDataLoader(X, Y, batch_size), num_genes, 2.0, 1.0, None - import scanpy as sc import scvi + import scanpy as sc adata = scvi.data.purified_pbmc_dataset(subset_datasets=["regulatory_t", "naive_t", "memory_t", "naive_cytotoxic"]) diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 5a9e1cffb3..2e4d4bc760 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -19,22 +19,25 @@ import argparse -import matplotlib.pyplot as plt import numpy as np + import torch import torch.nn as nn -from data import get_data -from matplotlib.patches import Patch from torch.distributions import constraints -from torch.nn.functional import softmax, softplus +from torch.nn.functional import softplus, softmax from torch.optim import Adam import pyro import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions.util import broadcast_shape -from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate from pyro.optim import MultiStepLR +from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO + +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +from data import get_data # Helper for making fully-connected neural networks @@ -297,7 +300,6 @@ def main(args): # Now that we're done training we'll inspect the latent representations we've learned if args.plot and 
args.dataset == 'pbmc': import scanpy as sc - # Compute latent representation (z2_loc) for each cell in the dataset latent_rep = scanvi.z2l_encoder(dataloader.data_x)[0] diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index ae37f2cd0a..3af6153609 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -20,16 +20,19 @@ import numpy as np import torch -import wget from torch.nn.functional import softplus import pyro import pyro.optim as optim -from pyro.contrib.easyguide import EasyGuide +import wget + from pyro.contrib.examples.util import get_data_directory -from pyro.distributions import Gamma, Normal, Poisson +from pyro.distributions import Gamma, Poisson, Normal from pyro.infer import SVI, TraceMeanField_ELBO -from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible +from pyro.infer.autoguide import AutoDiagonalNormal +from pyro.infer.autoguide import init_to_feasible +from pyro.contrib.easyguide import EasyGuide + torch.set_default_tensor_type('torch.FloatTensor') pyro.util.set_rng_seed(0) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 639ac44d73..9ae9421417 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -2,17 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import math import numpy as np import torch -from torch.optim import Adam +import math import pyro import pyro.distributions as dist from pyro import poutine +from pyro.infer.autoguide import AutoDelta from pyro.infer import Trace_ELBO -from pyro.infer.autoguide import AutoDelta, init_to_median +from pyro.infer.autoguide import init_to_median + +from torch.optim import Adam + """ We demonstrate how to do sparse linear regression using a variant of the diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index 627dded294..2f720a2ce0 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -5,9 +5,6 @@ import torch import torch.nn as nn -from utils.custom_mlp import MLP, Exp -from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders -from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae from visdom import Visdom import pyro @@ -15,6 +12,9 @@ from pyro.contrib.examples.util import print_and_log from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate from pyro.optim import Adam +from utils.custom_mlp import MLP, Exp +from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders +from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae class SSVAE(nn.Module): diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 98f19533dc..396ff788b8 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -7,14 +7,14 @@ import torch import torch.nn as nn import visdom -from utils.mnist_cached import MNISTCached as MNIST -from utils.mnist_cached import setup_data_loaders -from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples import pyro import pyro.distributions as dist from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam +from utils.mnist_cached import MNISTCached as MNIST +from utils.mnist_cached import setup_data_loaders +from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples # define the PyTorch module that parameterizes the diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index 60f9eddcb3..f4291e5e35 100644 --- a/examples/vae/vae_comparison.py 
+++ b/examples/vae/vae_comparison.py @@ -10,13 +10,13 @@ import torch.nn as nn from torch.nn import functional from torchvision.utils import save_image -from utils.mnist_cached import DATA_DIR, RESULTS_DIR import pyro from pyro.contrib.examples import util from pyro.distributions import Bernoulli, Normal from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam +from utils.mnist_cached import DATA_DIR, RESULTS_DIR """ Comparison of VAE implementation in PyTorch and Pyro. This example can be diff --git a/profiler/hmm.py b/profiler/hmm.py index 1825c82c20..4308c3df56 100644 --- a/profiler/hmm.py +++ b/profiler/hmm.py @@ -8,12 +8,13 @@ import subprocess import sys from collections import defaultdict -from os.path import abspath, join +from os.path import join, abspath from numpy import median from pyro.util import timed + EXAMPLES_DIR = join(abspath(__file__), os.pardir, os.pardir, "examples") diff --git a/profiler/profiling_utils.py b/profiler/profiling_utils.py index aee4f9b564..8375132eb2 100644 --- a/profiler/profiling_utils.py +++ b/profiler/profiling_utils.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import cProfile +from io import StringIO import functools import os import pstats import timeit from contextlib import contextmanager -from io import StringIO from prettytable import ALL, PrettyTable diff --git a/pyro/contrib/__init__.py b/pyro/contrib/__init__.py index 045ec5435f..3f14bd1862 100644 --- a/pyro/contrib/__init__.py +++ b/pyro/contrib/__init__.py @@ -25,7 +25,6 @@ try: import funsor as funsor_ # noqa: F401 - from pyro.contrib import funsor __all__ += ["funsor"] except ImportError: diff --git a/pyro/contrib/autoname/__init__.py b/pyro/contrib/autoname/__init__.py index d3e72366d9..6f396f55b6 100644 --- a/pyro/contrib/autoname/__init__.py +++ b/pyro/contrib/autoname/__init__.py @@ -6,7 +6,8 @@ generating unique, semantically meaningful names for sample sites. """ from pyro.contrib.autoname import named -from pyro.contrib.autoname.scoping import name_count, scope +from pyro.contrib.autoname.scoping import scope, name_count + __all__ = [ "named", diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index cc97b051fa..6a4f679e29 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import torch.nn.functional as F from torch.distributions.utils import lazy_property +import torch.nn.functional as F from pyro.contrib.bnn.utils import adjoin_ones_vector from pyro.distributions.torch_distribution import TorchDistribution diff --git a/pyro/contrib/bnn/utils.py b/pyro/contrib/bnn/utils.py index 794f66f984..ec2f33623a 100644 --- a/pyro/contrib/bnn/utils.py +++ b/pyro/contrib/bnn/utils.py @@ -1,9 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
 # SPDX-License-Identifier: Apache-2.0

-import math
-
 import torch
+import math


 def xavier_uniform(D_in, D_out):
diff --git a/pyro/contrib/conjugate/infer.py b/pyro/contrib/conjugate/infer.py
index 0c815c0126..23a3fe791e 100644
--- a/pyro/contrib/conjugate/infer.py
+++ b/pyro/contrib/conjugate/infer.py
@@ -6,8 +6,8 @@
 import torch

 import pyro.distributions as dist
-from pyro import poutine
 from pyro.distributions.util import sum_leftmost
+from pyro import poutine
 from pyro.poutine.messenger import Messenger
 from pyro.poutine.util import site_is_subsample
diff --git a/pyro/contrib/easyguide/__init__.py b/pyro/contrib/easyguide/__init__.py
index d26c63c9cf..9e2577841f 100644
--- a/pyro/contrib/easyguide/__init__.py
+++ b/pyro/contrib/easyguide/__init__.py
@@ -3,6 +3,7 @@
 from pyro.contrib.easyguide.easyguide import EasyGuide, easy_guide

+
 __all__ = [
     "EasyGuide",
     "easy_guide",
diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py
index 55535ae72d..fbc204466b 100644
--- a/pyro/contrib/easyguide/easyguide.py
+++ b/pyro/contrib/easyguide/easyguide.py
@@ -14,8 +14,8 @@
 import pyro.poutine as poutine
 import pyro.poutine.runtime as runtime
 from pyro.distributions.util import broadcast_shape, sum_rightmost
-from pyro.infer.autoguide.guides import prototype_hide_fn
 from pyro.infer.autoguide.initialization import InitMessenger
+from pyro.infer.autoguide.guides import prototype_hide_fn
 from pyro.nn.module import PyroModule, PyroParam
diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py
index 0398ad137d..0d89fee5fc 100644
--- a/pyro/contrib/examples/bart.py
+++ b/pyro/contrib/examples/bart.py
@@ -14,7 +14,7 @@
 import torch

-from pyro.contrib.examples.util import _mkdir_p, get_data_directory
+from pyro.contrib.examples.util import get_data_directory, _mkdir_p

 DATA = get_data_directory(__file__)
diff --git a/pyro/contrib/examples/finance.py b/pyro/contrib/examples/finance.py
index c40a0b55e8..03572c0289 100644
--- a/pyro/contrib/examples/finance.py
+++ b/pyro/contrib/examples/finance.py
@@ -6,7 +6,7 @@
 import pandas as pd

-from pyro.contrib.examples.util import _mkdir_p, get_data_directory
+from pyro.contrib.examples.util import get_data_directory, _mkdir_p

 DATA = get_data_directory(__file__)
diff --git a/pyro/contrib/examples/polyphonic_data_loader.py b/pyro/contrib/examples/polyphonic_data_loader.py
index 491ae0517f..132c6d953d 100644
--- a/pyro/contrib/examples/polyphonic_data_loader.py
+++ b/pyro/contrib/examples/polyphonic_data_loader.py
@@ -17,9 +17,9 @@
 """

 import os
-import pickle
 from collections import namedtuple
 from urllib.request import urlopen
+import pickle

 import torch
 import torch.nn as nn
@@ -27,6 +27,7 @@
 from pyro.contrib.examples.util import get_data_directory

+
 dset = namedtuple("dset", ["name", "url", "filename"])

 JSB_CHORALES = dset("jsb_chorales",
diff --git a/pyro/contrib/forecast/util.py b/pyro/contrib/forecast/util.py
index f2bd7034b5..8918b470da 100644
--- a/pyro/contrib/forecast/util.py
+++ b/pyro/contrib/forecast/util.py
@@ -7,7 +7,7 @@
 from torch.distributions import transform_to, transforms

 import pyro.distributions as dist
-from pyro.infer.reparam import DiscreteCosineReparam, HaarReparam
+from pyro.infer.reparam import HaarReparam, DiscreteCosineReparam
 from pyro.poutine.messenger import Messenger
 from pyro.poutine.util import site_is_subsample
 from pyro.primitives import get_param_store
diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py
index dcb9355e5e..30a23d5ca2 100644
--- a/pyro/contrib/funsor/__init__.py
+++ b/pyro/contrib/funsor/__init__.py
@@ -3,12 +3,14 @@
 import pyroapi

-from pyro.contrib.funsor.handlers import condition, do, markov  # noqa: F401
-from pyro.contrib.funsor.handlers import plate as _plate
-from pyro.contrib.funsor.handlers import vectorized_markov  # noqa: F401
+from pyro.primitives import (  # noqa: F401
+    clear_param_store, deterministic, enable_validation, factor, get_param_store,
+    module, param, random_module, sample, set_rng_seed, subsample,
+)
+
 from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor  # noqa: F401
-from pyro.primitives import (clear_param_store, deterministic, enable_validation, factor, get_param_store,  # noqa: F401
-                             module, param, random_module, sample, set_rng_seed, subsample)
+from pyro.contrib.funsor.handlers import condition, do, markov, vectorized_markov  # noqa: F401
+from pyro.contrib.funsor.handlers import plate as _plate


 def plate(*args, **kwargs):
diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py
index 724ec29c83..a98be1de94 100644
--- a/pyro/contrib/funsor/handlers/__init__.py
+++ b/pyro/contrib/funsor/handlers/__init__.py
@@ -1,16 +1,20 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

-from pyro.poutine import (block, condition, do, escape, infer_config, mask, reparam, scale, seed,  # noqa: F401
-                          uncondition)
 from pyro.poutine.handlers import _make_handler
+from pyro.poutine import (  # noqa: F401
+    block, condition, do, escape, infer_config,
+    mask, reparam, scale, seed, uncondition,
+)
+
 from .enum_messenger import EnumMessenger, queue  # noqa: F401
 from .named_messenger import MarkovMessenger, NamedMessenger
 from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger
 from .replay_messenger import ReplayMessenger
 from .trace_messenger import TraceMessenger

+
 _msngrs = [
     EnumMessenger,
     MarkovMessenger,
diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py
index befbeb2014..15caf49078 100644
--- a/pyro/contrib/funsor/handlers/enum_messenger.py
+++ b/pyro/contrib/funsor/handlers/enum_messenger.py
@@ -9,17 +9,18 @@
 import math
 from collections import OrderedDict

-import funsor
 import torch
+import funsor

 import pyro.poutine.runtime
 import pyro.poutine.util
-from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger
+from pyro.poutine.escape_messenger import EscapeMessenger
+from pyro.poutine.subsample_messenger import _Subsample
+
 from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
+from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger
 from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger
 from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger
-from pyro.poutine.escape_messenger import EscapeMessenger
-from pyro.poutine.subsample_messenger import _Subsample

 funsor.set_backend("torch")
diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py
index e7cf18ffbc..fb7667fab3 100644
--- a/pyro/contrib/funsor/handlers/named_messenger.py
+++ b/pyro/contrib/funsor/handlers/named_messenger.py
@@ -4,9 +4,10 @@
 from collections import OrderedDict
 from contextlib import ExitStack

-from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame
 from pyro.poutine.reentrant_messenger import ReentrantMessenger

+from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame
+

 class NamedMessenger(ReentrantMessenger):
     """
diff --git a/pyro/contrib/funsor/handlers/plate_messenger.py b/pyro/contrib/funsor/handlers/plate_messenger.py
index d180df021f..f119db1de0 100644
--- a/pyro/contrib/funsor/handlers/plate_messenger.py
+++ b/pyro/contrib/funsor/handlers/plate_messenger.py
@@ -6,16 +6,18 @@
 import funsor

-from pyro.contrib.funsor.handlers.named_messenger import DimRequest, DimType, GlobalNamedMessenger, NamedMessenger
-from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
 from pyro.distributions.util import copy_docs_from
 from pyro.poutine.broadcast_messenger import BroadcastMessenger
 from pyro.poutine.indep_messenger import CondIndepStackFrame
 from pyro.poutine.messenger import Messenger
-from pyro.poutine.runtime import effectful
 from pyro.poutine.subsample_messenger import SubsampleMessenger as OrigSubsampleMessenger
 from pyro.util import ignore_jit_warnings

+from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
+from pyro.contrib.funsor.handlers.named_messenger import DimRequest, DimType, GlobalNamedMessenger, \
+    NamedMessenger
+from pyro.poutine.runtime import effectful
+

 funsor.set_backend("torch")
diff --git a/pyro/contrib/funsor/handlers/primitives.py b/pyro/contrib/funsor/handlers/primitives.py
index 0b7a4c4edb..3d8815eff0 100644
--- a/pyro/contrib/funsor/handlers/primitives.py
+++ b/pyro/contrib/funsor/handlers/primitives.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import pyro.poutine.runtime
+
 from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType
diff --git a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py
index ae672d2dd4..2389941049 100644
--- a/pyro/contrib/funsor/handlers/replay_messenger.py
+++ b/pyro/contrib/funsor/handlers/replay_messenger.py
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

-from pyro.contrib.funsor.handlers.primitives import to_data
 from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger
+from pyro.contrib.funsor.handlers.primitives import to_data


 class ReplayMessenger(OrigReplayMessenger):
diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py
index d6a995596d..4671901668 100644
--- a/pyro/contrib/funsor/handlers/trace_messenger.py
+++ b/pyro/contrib/funsor/handlers/trace_messenger.py
@@ -3,11 +3,12 @@
 import funsor

-from pyro.contrib.funsor.handlers.primitives import to_funsor
-from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
 from pyro.poutine.subsample_messenger import _Subsample
 from pyro.poutine.trace_messenger import TraceMessenger as OrigTraceMessenger

+from pyro.contrib.funsor.handlers.primitives import to_funsor
+from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
+

 class TraceMessenger(OrigTraceMessenger):
     """
diff --git a/pyro/contrib/funsor/infer/__init__.py b/pyro/contrib/funsor/infer/__init__.py
index 55f260e6a9..4525e2cef5 100644
--- a/pyro/contrib/funsor/infer/__init__.py
+++ b/pyro/contrib/funsor/infer/__init__.py
@@ -5,5 +5,5 @@
 from .elbo import ELBO  # noqa: F401
 from .trace_elbo import JitTrace_ELBO, Trace_ELBO  # noqa: F401
-from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO  # noqa: F401
 from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO  # noqa: F401
+from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO  # noqa: F401
diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py
index 686926b772..1912edca11 100644
--- a/pyro/contrib/funsor/infer/trace_elbo.py
+++ b/pyro/contrib/funsor/infer/trace_elbo.py
@@ -5,13 +5,14 @@
 import funsor

+from pyro.distributions.util import copy_docs_from
+from pyro.infer import Trace_ELBO as _OrigTrace_ELBO
+
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
 from pyro.contrib.funsor.infer import config_enumerate
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import Trace_ELBO as _OrigTrace_ELBO

-from .elbo import ELBO, Jit_ELBO
+from .elbo import Jit_ELBO, ELBO
 from .traceenum_elbo import terms_from_trace
diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py
index c9e7eb0242..d725da2c50 100644
--- a/pyro/contrib/funsor/infer/traceenum_elbo.py
+++ b/pyro/contrib/funsor/infer/traceenum_elbo.py
@@ -5,11 +5,12 @@
 import funsor

+from pyro.distributions.util import copy_docs_from
+from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO
+
 from pyro.contrib.funsor import to_data, to_funsor
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
 from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO


 def terms_from_trace(tr):
diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py
index aae66ce5c0..8a714d8c03 100644
--- a/pyro/contrib/funsor/infer/tracetmc_elbo.py
+++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py
@@ -5,12 +5,14 @@
 import funsor

+from pyro.distributions.util import copy_docs_from
+from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO
+
 from pyro.contrib.funsor import to_data
 from pyro.contrib.funsor.handlers import enum, plate, replay, trace
+
 from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO
 from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace
-from pyro.distributions.util import copy_docs_from
-from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO


 @copy_docs_from(_OrigTraceTMC_ELBO)
diff --git a/pyro/contrib/gp/kernels/__init__.py b/pyro/contrib/gp/kernels/__init__.py
index c36ddd37fb..9874e73c17 100644
--- a/pyro/contrib/gp/kernels/__init__.py
+++ b/pyro/contrib/gp/kernels/__init__.py
@@ -4,9 +4,10 @@
 from pyro.contrib.gp.kernels.brownian import Brownian
 from pyro.contrib.gp.kernels.coregionalize import Coregionalize
 from pyro.contrib.gp.kernels.dot_product import DotProduct, Linear, Polynomial
-from pyro.contrib.gp.kernels.isotropic import RBF, Exponential, Isotropy, Matern32, Matern52, RationalQuadratic
-from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, Transforming, VerticalScaling,
-                                            Warping)
+from pyro.contrib.gp.kernels.isotropic import (RBF, Exponential, Isotropy, Matern32, Matern52,
+                                               RationalQuadratic)
+from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum,
+                                            Transforming, VerticalScaling, Warping)
 from pyro.contrib.gp.kernels.periodic import Cosine, Periodic
 from pyro.contrib.gp.kernels.static import Constant, WhiteNoise
diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py
index ef417f9e22..3041f5e92e 100644
--- a/pyro/contrib/gp/likelihoods/binary.py
+++ b/pyro/contrib/gp/likelihoods/binary.py
@@ -5,6 +5,7 @@
 import pyro
 import pyro.distributions as dist
+
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/gp/likelihoods/gaussian.py b/pyro/contrib/gp/likelihoods/gaussian.py
index b1b65ff95c..cb5a15d8c7 100644
--- a/pyro/contrib/gp/likelihoods/gaussian.py
+++ b/pyro/contrib/gp/likelihoods/gaussian.py
@@ -6,6 +6,7 @@
 import pyro
 import pyro.distributions as dist
+
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
 from pyro.nn.module import PyroParam
diff --git a/pyro/contrib/gp/likelihoods/multi_class.py b/pyro/contrib/gp/likelihoods/multi_class.py
index ed8463f8bd..9ff69e81f1 100644
--- a/pyro/contrib/gp/likelihoods/multi_class.py
+++ b/pyro/contrib/gp/likelihoods/multi_class.py
@@ -5,6 +5,7 @@
 import pyro
 import pyro.distributions as dist
+
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/gp/likelihoods/poisson.py b/pyro/contrib/gp/likelihoods/poisson.py
index 8abed6fd2a..48916e0634 100644
--- a/pyro/contrib/gp/likelihoods/poisson.py
+++ b/pyro/contrib/gp/likelihoods/poisson.py
@@ -5,6 +5,7 @@
 import pyro
 import pyro.distributions as dist
+
 from pyro.contrib.gp.likelihoods.likelihood import Likelihood
diff --git a/pyro/contrib/oed/__init__.py b/pyro/contrib/oed/__init__.py
index 3afd3a440d..006c57a7a1 100644
--- a/pyro/contrib/oed/__init__.py
+++ b/pyro/contrib/oed/__init__.py
@@ -67,7 +67,7 @@
 """

-from pyro.contrib.oed import eig, search
+from pyro.contrib.oed import search, eig

 __all__ = [
     "search",
diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py
index 8d6c7ae22b..7faec28aa6 100644
--- a/pyro/contrib/oed/eig.py
+++ b/pyro/contrib/oed/eig.py
@@ -1,18 +1,17 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

+import torch
 import math
 import warnings

-import torch
-
 import pyro
 from pyro import poutine
+from pyro.infer.autoguide.utils import mean_field_entropy
 from pyro.contrib.oed.search import Search
+from pyro.infer import EmpiricalMarginal, Importance, SVI
+from pyro.util import torch_isnan, torch_isinf
 from pyro.contrib.util import lexpand
-from pyro.infer import SVI, EmpiricalMarginal, Importance
-from pyro.infer.autoguide.utils import mean_field_entropy
-from pyro.util import torch_isinf, torch_isnan

 __all__ = [
     "laplace_eig",
diff --git a/pyro/contrib/oed/glmm/__init__.py b/pyro/contrib/oed/glmm/__init__.py
index c17c221213..f9d75643ca 100644
--- a/pyro/contrib/oed/glmm/__init__.py
+++ b/pyro/contrib/oed/glmm/__init__.py
@@ -36,5 +36,5 @@
 For random effects with a shared covariance matrix, see
 :meth:`pyro.contrib.oed.glmm.lmer_model`.
 """
-from pyro.contrib.oed.glmm import guides  # noqa: F401
 from pyro.contrib.oed.glmm.glmm import *  # noqa: F403,F401
+from pyro.contrib.oed.glmm import guides  # noqa: F401
diff --git a/pyro/contrib/oed/glmm/glmm.py b/pyro/contrib/oed/glmm/glmm.py
index 68507be53f..2c418391e3 100644
--- a/pyro/contrib/oed/glmm/glmm.py
+++ b/pyro/contrib/oed/glmm/glmm.py
@@ -3,17 +3,17 @@
 import warnings
 from collections import OrderedDict
-from contextlib import ExitStack
 from functools import partial
+from contextlib import ExitStack

 import torch
+from torch.nn.functional import softplus
 from torch.distributions import constraints
 from torch.distributions.transforms import AffineTransform, SigmoidTransform
-from torch.nn.functional import softplus

 import pyro
 import pyro.distributions as dist
-from pyro.contrib.util import iter_plates_to_shape, rmv
+from pyro.contrib.util import rmv, iter_plates_to_shape

 # TODO read from torch float spec
 epsilon = torch.tensor(2**-24)
diff --git a/pyro/contrib/oed/glmm/guides.py b/pyro/contrib/oed/glmm/guides.py
index d2425adff2..c71b06415c 100644
--- a/pyro/contrib/oed/glmm/guides.py
+++ b/pyro/contrib/oed/glmm/guides.py
@@ -1,15 +1,17 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-from contextlib import ExitStack
-
 import torch
 from torch import nn

+from contextlib import ExitStack
+
 import pyro
 import pyro.distributions as dist
 from pyro import poutine
-from pyro.contrib.util import iter_plates_to_shape, lexpand, rmv, rtril, rvv, tensor_to_dict
+from pyro.contrib.util import (
+    tensor_to_dict, rmv, rvv, rtril, lexpand, iter_plates_to_shape
+)
 from pyro.ops.linalg import rinverse
diff --git a/pyro/contrib/oed/search.py b/pyro/contrib/oed/search.py
index 721f6305c3..4bf8eb1816 100644
--- a/pyro/contrib/oed/search.py
+++ b/pyro/contrib/oed/search.py
@@ -2,9 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import queue
-
-import pyro.poutine as poutine
 from pyro.infer.abstract_infer import TracePosterior
+import pyro.poutine as poutine

 ###################################
 # Search borrowed from RSA example
diff --git a/pyro/contrib/oed/util.py b/pyro/contrib/oed/util.py
index 50774ff0bd..d5c315f85a 100644
--- a/pyro/contrib/oed/util.py
+++ b/pyro/contrib/oed/util.py
@@ -2,11 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
-
 import torch

-from pyro.contrib.oed.glmm import analytic_posterior_cov
 from pyro.contrib.util import get_indices
+from pyro.contrib.oed.glmm import analytic_posterior_cov
 from pyro.infer.autoguide.utils import mean_field_entropy
diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py
index 39c06748ae..1bc28f1caa 100644
--- a/pyro/contrib/randomvariable/random_variable.py
+++ b/pyro/contrib/randomvariable/random_variable.py
@@ -4,10 +4,17 @@
 from typing import Union

 from torch import Tensor
-
 from pyro.distributions import TransformedDistribution
-from pyro.distributions.transforms import (AbsTransform, AffineTransform, ExpTransform, PowerTransform,
-                                           SigmoidTransform, SoftmaxTransform, TanhTransform, Transform)
+from pyro.distributions.transforms import (
+    Transform,
+    AffineTransform,
+    AbsTransform,
+    PowerTransform,
+    ExpTransform,
+    TanhTransform,
+    SoftmaxTransform,
+    SigmoidTransform
+)


 class RVMagicOps:
diff --git a/pyro/contrib/timeseries/__init__.py b/pyro/contrib/timeseries/__init__.py
index 517c3f9550..f119203f04 100644
--- a/pyro/contrib/timeseries/__init__.py
+++ b/pyro/contrib/timeseries/__init__.py
@@ -6,7 +6,7 @@
 models useful for forecasting applications.
 """
 from pyro.contrib.timeseries.base import TimeSeriesModel
-from pyro.contrib.timeseries.gp import DependentMaternGP, IndependentMaternGP, LinearlyCoupledMaternGP
+from pyro.contrib.timeseries.gp import IndependentMaternGP, LinearlyCoupledMaternGP, DependentMaternGP
 from pyro.contrib.timeseries.lgssm import GenericLGSSM
 from pyro.contrib.timeseries.lgssmgp import GenericLGSSMWithGPNoiseModel
diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py
index fc3e61c6c1..641c764d19 100644
--- a/pyro/contrib/tracking/distributions.py
+++ b/pyro/contrib/tracking/distributions.py
@@ -5,9 +5,9 @@
 from torch.distributions import constraints

 import pyro.distributions as dist
+from pyro.distributions.torch_distribution import TorchDistribution
 from pyro.contrib.tracking.extended_kalman_filter import EKFState
 from pyro.contrib.tracking.measurements import PositionMeasurement
-from pyro.distributions.torch_distribution import TorchDistribution


 class EKFDistribution(TorchDistribution):
diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py
index 7ea41ad3a7..bd26b344e2 100644
--- a/pyro/contrib/tracking/dynamic_models.py
+++ b/pyro/contrib/tracking/dynamic_models.py
@@ -6,7 +6,6 @@
 import torch
 from torch import nn
 from torch.nn import Parameter
-
 import pyro.distributions as dist
 from pyro.distributions.util import eye_like
diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py
index 8f24ea4360..cb98c49ccd 100644
--- a/pyro/contrib/tracking/measurements.py
+++ b/pyro/contrib/tracking/measurements.py
@@ -4,7 +4,6 @@
 from abc import ABCMeta, abstractmethod

 import torch
-
 from pyro.distributions.util import eye_like
diff --git a/pyro/contrib/util.py b/pyro/contrib/util.py
index e250639ca7..44ff34832b 100644
--- a/pyro/contrib/util.py
+++ b/pyro/contrib/util.py
@@ -2,9 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections import OrderedDict
-
 import torch
-
 import pyro
diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py
index c51ee45ded..25e46dc8bb 100644
--- a/pyro/distributions/__init__.py
+++ b/pyro/distributions/__init__.py
@@ -3,19 +3,40 @@
 import pyro.distributions.torch_patch  # noqa F403
 from pyro.distributions.avf_mvn import AVFMultivariateNormal
-from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimes, CoalescentTimesWithRate
-from pyro.distributions.conditional import (ConditionalDistribution, ConditionalTransform,
-                                            ConditionalTransformedDistribution, ConditionalTransformModule)
-from pyro.distributions.conjugate import BetaBinomial, DirichletMultinomial, GammaPoisson
+from pyro.distributions.coalescent import (
+    CoalescentRateLikelihood,
+    CoalescentTimes,
+    CoalescentTimesWithRate,
+)
+from pyro.distributions.conditional import (
+    ConditionalDistribution,
+    ConditionalTransform,
+    ConditionalTransformedDistribution,
+    ConditionalTransformModule,
+)
+from pyro.distributions.conjugate import (
+    BetaBinomial,
+    DirichletMultinomial,
+    GammaPoisson,
+)
 from pyro.distributions.delta import Delta
 from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals
-from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance
+from pyro.distributions.diag_normal_mixture_shared_cov import (
+    MixtureOfDiagNormalsSharedCovariance,
+)
 from pyro.distributions.distribution import Distribution
 from pyro.distributions.empirical import Empirical
 from pyro.distributions.extended import ExtendedBetaBinomial, ExtendedBinomial
 from pyro.distributions.folded import FoldedDistribution
 from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture
-from pyro.distributions.hmm import DiscreteHMM, GammaGaussianHMM, GaussianHMM, GaussianMRF, IndependentHMM, LinearHMM
+from pyro.distributions.hmm import (
+    DiscreteHMM,
+    GammaGaussianHMM,
+    GaussianHMM,
+    GaussianMRF,
+    IndependentHMM,
+    LinearHMM,
+)
 from pyro.distributions.improper_uniform import ImproperUniform
 from pyro.distributions.inverse_gamma import InverseGamma
 from pyro.distributions.lkj import LKJCorrCholesky
@@ -27,8 +48,10 @@
 from pyro.distributions.ordered_logistic import OrderedLogistic
 from pyro.distributions.polya_gamma import TruncatedPolyaGamma
 from pyro.distributions.rejector import Rejector
-from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough,
-                                                         RelaxedOneHotCategoricalStraightThrough)
+from pyro.distributions.relaxed_straight_through import (
+    RelaxedBernoulliStraightThrough,
+    RelaxedOneHotCategoricalStraightThrough,
+)
 from pyro.distributions.spanning_tree import SpanningTree
 from pyro.distributions.stable import Stable
 from pyro.distributions.torch import *  # noqa F403
@@ -36,9 +59,17 @@
 from pyro.distributions.torch_distribution import MaskedDistribution, TorchDistribution
 from pyro.distributions.torch_transform import ComposeTransformModule, TransformModule
 from pyro.distributions.unit import Unit
-from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled
+from pyro.distributions.util import (
+    enable_validation,
+    is_validation_enabled,
+    validation_enabled,
+)
 from pyro.distributions.von_mises_3d import VonMises3D
-from pyro.distributions.zero_inflated import ZeroInflatedDistribution, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson
+from pyro.distributions.zero_inflated import (
+    ZeroInflatedDistribution,
+    ZeroInflatedNegativeBinomial,
+    ZeroInflatedPoisson,
+)
 from . import constraints, kl, transforms
diff --git a/pyro/distributions/ordered_logistic.py b/pyro/distributions/ordered_logistic.py
index c8d4ef459d..d6d288fafb 100644
--- a/pyro/distributions/ordered_logistic.py
+++ b/pyro/distributions/ordered_logistic.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-
 from pyro.distributions import constraints
 from pyro.distributions.torch import Categorical
diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py
index ff5e370f05..f7b2f64238 100644
--- a/pyro/distributions/spanning_tree.py
+++ b/pyro/distributions/spanning_tree.py
@@ -185,7 +185,6 @@ def _get_cpp_module():
     global _cpp_module
     if _cpp_module is None:
         import os
-
         from torch.utils.cpp_extension import load
         path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp")
         _cpp_module = load(name="cpp_spanning_tree",
diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py
index df68c5a5c0..f926f52330 100644
--- a/pyro/distributions/transforms/__init__.py
+++ b/pyro/distributions/transforms/__init__.py
@@ -5,7 +5,10 @@
 from torch.distributions.transforms import *  # noqa F403
 from torch.distributions.transforms import __all__ as torch_transforms

-from pyro.distributions.constraints import IndependentConstraint, corr_cholesky_constraint, ordered_vector
+from pyro.distributions.constraints import (
+    IndependentConstraint,
+    corr_cholesky_constraint,
+    ordered_vector)
 from pyro.distributions.torch_transform import ComposeTransformModule
 from pyro.distributions.transforms.affine_autoregressive import (AffineAutoregressive, ConditionalAffineAutoregressive,
                                                                  affine_autoregressive,
diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py
index 95497e45fa..79aea6a261 100644
--- a/pyro/distributions/transforms/ordered.py
+++ b/pyro/distributions/transforms/ordered.py
@@ -2,9 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-
-from pyro.distributions import constraints
 from pyro.distributions.transforms import Transform
+from pyro.distributions import constraints


 class OrderedTransform(Transform):
diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index 5485b7476e..421e5a16fa 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -17,13 +17,13 @@
 from pyro.infer.smcfilter import SMCFilter
 from pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel
 from pyro.infer.svi import SVI
+from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
 from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
 from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO
 from pyro.infer.trace_mmd import Trace_MMD
 from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO
 from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
 from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
-from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
 from pyro.infer.util import enable_validation, is_validation_enabled

 __all__ = [
diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py
index 8891c9cc64..0424949e7d 100644
--- a/pyro/infer/abstract_infer.py
+++ b/pyro/infer/abstract_infer.py
@@ -10,8 +10,8 @@
 import pyro.poutine as poutine
 from pyro.distributions import Categorical, Empirical
-from pyro.ops.stats import waic
 from pyro.poutine.util import site_is_subsample
+from pyro.ops.stats import waic


 class EmpiricalMarginal(Empirical):
diff --git a/pyro/infer/autoguide/initialization.py b/pyro/infer/autoguide/initialization.py
index 69ccdde4ca..70b1a430f3 100644
--- a/pyro/infer/autoguide/initialization.py
+++ b/pyro/infer/autoguide/initialization.py
@@ -20,6 +20,7 @@
 from pyro.poutine.messenger import Messenger
 from pyro.util import torch_isnan

+
 # TODO: move this file out of `autoguide` in a minor release


 def _is_multivariate(d):
diff --git a/pyro/infer/mcmc/__init__.py b/pyro/infer/mcmc/__init__.py
index 99d241b162..e33cbec518 100644
--- a/pyro/infer/mcmc/__init__.py
+++ b/pyro/infer/mcmc/__init__.py
@@ -2,8 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix
-from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.hmc import HMC
+from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.nuts import NUTS

 __all__ = [
diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py
index c8d41924f6..46497a53b5 100644
--- a/pyro/infer/mcmc/adaptation.py
+++ b/pyro/infer/mcmc/adaptation.py
@@ -9,7 +9,7 @@
 import pyro
 from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul
 from pyro.ops.dual_averaging import DualAveraging
-from pyro.ops.welford import WelfordArrowheadCovariance, WelfordCovariance
+from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance

 adapt_window = namedtuple("adapt_window", ["start", "end"])
diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py
index 139d7f4a0d..857004212f 100644
--- a/pyro/infer/mcmc/hmc.py
+++ b/pyro/infer/mcmc/hmc.py
@@ -8,8 +8,9 @@
 import pyro
 import pyro.distributions as dist
-from pyro.distributions.testing.fakes import NonreparameterizedNormal
 from pyro.distributions.util import scalar_like
+from pyro.distributions.testing.fakes import NonreparameterizedNormal
+
 from pyro.infer.autoguide import init_to_uniform
 from pyro.infer.mcmc.adaptation import WarmupAdapter
 from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py
index 89831a93f8..372104c53f 100644
--- a/pyro/infer/mcmc/util.py
+++ b/pyro/infer/mcmc/util.py
@@ -2,15 +2,15 @@
 # SPDX-License-Identifier: Apache-2.0

 import functools
-import traceback as tb
 import warnings
 from collections import OrderedDict, defaultdict
 from functools import partial, reduce
 from itertools import product
+import traceback as tb

 import torch
-from opt_einsum import shared_intermediates
 from torch.distributions import biject_to
+from opt_einsum import shared_intermediates

 import pyro
 import pyro.poutine as poutine
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 857b54df77..1b946accfe 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import warnings
 from functools import reduce
+import warnings

 import torch
diff --git a/pyro/infer/reparam/neutra.py b/pyro/infer/reparam/neutra.py
index 2d6b9c8bb2..f9753d9051 100644
--- a/pyro/infer/reparam/neutra.py
+++ b/pyro/infer/reparam/neutra.py
@@ -8,7 +8,6 @@
 from pyro import poutine
 from pyro.distributions.util import sum_rightmost
 from pyro.infer.autoguide.guides import AutoContinuous
-
 from .reparam import Reparam
diff --git a/pyro/infer/svgd.py b/pyro/infer/svgd.py
index 9e722a745f..ebc526a86d 100644
--- a/pyro/infer/svgd.py
+++ b/pyro/infer/svgd.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import math
 from abc import ABCMeta, abstractmethod
+import math

 import torch
 from torch.distributions import biject_to
@@ -10,10 +10,10 @@
 import pyro
 from pyro import poutine
 from pyro.distributions import Delta
-from pyro.distributions.util import copy_docs_from
+from pyro.infer.trace_elbo import Trace_ELBO
 from pyro.infer.autoguide.guides import AutoContinuous
 from pyro.infer.autoguide.initialization import init_to_sample
-from pyro.infer.trace_elbo import Trace_ELBO
+from pyro.distributions.util import copy_docs_from


 def vectorize(fn, num_particles, max_plate_nesting):
diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py
index 6eab28596d..d04210ee56 100644
--- a/pyro/infer/trace_mean_field_elbo.py
+++ b/pyro/infer/trace_mean_field_elbo.py
@@ -10,7 +10,7 @@
 import pyro.ops.jit
 from pyro.distributions.util import scale_and_mask
 from pyro.infer.trace_elbo import Trace_ELBO
-from pyro.infer.util import check_fully_reparametrized, is_validation_enabled, torch_item
+from pyro.infer.util import is_validation_enabled, torch_item, check_fully_reparametrized
 from pyro.util import warn_if_nan
diff --git a/pyro/infer/trace_mmd.py b/pyro/infer/trace_mmd.py
index 661ff727c2..1cc71992b9 100644
--- a/pyro/infer/trace_mmd.py
+++ b/pyro/infer/trace_mmd.py
@@ -9,8 +9,8 @@
 import pyro.ops.jit
 from pyro import poutine
 from pyro.infer.elbo import ELBO
+from pyro.infer.util import torch_item, is_validation_enabled
 from pyro.infer.enum import get_importance_trace
-from pyro.infer.util import is_validation_enabled, torch_item
 from pyro.util import check_if_enumerated, warn_if_nan
diff --git a/pyro/infer/trace_tail_adaptive_elbo.py b/pyro/infer/trace_tail_adaptive_elbo.py
index b05251a300..a69ea6d191 100644
--- a/pyro/infer/trace_tail_adaptive_elbo.py
+++ b/pyro/infer/trace_tail_adaptive_elbo.py
@@ -6,7 +6,7 @@
 import torch

 from pyro.infer.trace_elbo import Trace_ELBO
-from pyro.infer.util import check_fully_reparametrized, is_validation_enabled
+from pyro.infer.util import is_validation_enabled, check_fully_reparametrized


 class TraceTailAdaptive_ELBO(Trace_ELBO):
diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index 504060963f..e2e483eec7 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -1,10 +1,10 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import queue
 import warnings
 import weakref
 from collections import OrderedDict
+import queue

 import torch
 from opt_einsum import shared_intermediates
diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py
index 8e7bc7ed6f..852ea6e658 100644
--- a/pyro/infer/tracegraph_elbo.py
+++ b/pyro/infer/tracegraph_elbo.py
@@ -11,7 +11,8 @@
 from pyro.distributions.util import detach, is_identically_zero
 from pyro.infer import ELBO
 from pyro.infer.enum import get_importance_trace
-from pyro.infer.util import MultiFrameTensor, get_plate_stacks, is_validation_enabled, torch_backward, torch_item
+from pyro.infer.util import (MultiFrameTensor, get_plate_stacks,
+                             is_validation_enabled, torch_backward, torch_item)
 from pyro.util import check_if_enumerated, warn_if_nan
diff --git a/pyro/infer/tracetmc_elbo.py b/pyro/infer/tracetmc_elbo.py
index f78b277080..51c3ba3b78 100644
--- a/pyro/infer/tracetmc_elbo.py
+++ b/pyro/infer/tracetmc_elbo.py
@@ -7,6 +7,7 @@
 import torch

 import pyro.poutine as poutine
+
 from pyro.distributions.util import is_identically_zero
 from pyro.infer.elbo import ELBO
 from pyro.infer.enum import get_importance_trace, iter_discrete_escape, iter_discrete_extend
diff --git a/pyro/logger.py b/pyro/logger.py
index 64a8c70c46..5bee771e4a 100644
--- a/pyro/logger.py
+++ b/pyro/logger.py
@@ -3,6 +3,7 @@

 import logging

+
 default_format = '%(levelname)s \t %(message)s'
 log = logging.getLogger("pyro")
 log.setLevel(logging.INFO)
diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py
index 4714d3095e..9641ab4ec2 100644
--- a/pyro/ops/arrowhead.py
+++ b/pyro/ops/arrowhead.py
@@ -5,6 +5,7 @@

 import torch

+
 SymmArrowhead = namedtuple("SymmArrowhead", ["top", "bottom_diag"])
 TriuArrowhead = namedtuple("TriuArrowhead", ["top", "bottom_diag"])
diff --git a/pyro/ops/einsum/torch_map.py b/pyro/ops/einsum/torch_map.py
index e4293c1140..6e2832bcff 100644
--- a/pyro/ops/einsum/torch_map.py
+++ b/pyro/ops/einsum/torch_map.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import operator
+
 from functools import reduce

 from pyro.ops import packed
diff --git a/pyro/ops/einsum/torch_sample.py b/pyro/ops/einsum/torch_sample.py
index 5420c328ba..06c8108886 100644
--- a/pyro/ops/einsum/torch_sample.py
+++ b/pyro/ops/einsum/torch_sample.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import operator
+
 from functools import reduce

 import pyro.distributions as dist
diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py
index e651b4284e..6611d3a93c 100644
--- a/pyro/ops/newton.py
+++ b/pyro/ops/newton.py
@@ -4,8 +4,8 @@
 import torch
 from torch.autograd import grad

-from pyro.ops.linalg import eig_3d, rinverse
 from pyro.util import warn_if_nan
+from pyro.ops.linalg import rinverse, eig_3d


 def newton_step(loss, x, trust_radius=None):
diff --git a/pyro/ops/ssm_gp.py b/pyro/ops/ssm_gp.py
index 89abcb2912..eb88ba9d70 100644
--- a/pyro/ops/ssm_gp.py
+++ b/pyro/ops/ssm_gp.py
@@ -6,7 +6,7 @@
 import torch
 from torch.distributions import constraints

-from pyro.nn import PyroModule, PyroParam, pyro_method
+from pyro.nn import PyroModule, pyro_method, PyroParam

 root_three = math.sqrt(3.0)
 root_five = math.sqrt(5.0)
diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py
index d1ae84281e..f05886e22e 100644
--- a/pyro/ops/stats.py
+++ b/pyro/ops/stats.py
@@ -6,8 +6,8 @@

 import torch

-from .fft import irfft, rfft
 from .tensor_utils import next_fast_len
+from .fft import rfft, irfft


 def _compute_chain_variance_stats(input):
diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py
index 2ff2a57ae9..20a5b1d5e1 100644
--- a/pyro/poutine/broadcast_messenger.py
+++ b/pyro/poutine/broadcast_messenger.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from pyro.util import ignore_jit_warnings
-
 from .messenger import Messenger
diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py
index 0b65987932..9ed1575857 100644
--- a/pyro/poutine/indep_messenger.py
+++ b/pyro/poutine/indep_messenger.py
@@ -7,7 +7,6 @@
 import torch

 from pyro.util import ignore_jit_warnings
-
 from .messenger import Messenger
 from .runtime import _DIM_ALLOCATOR
diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py
index 176c7a772f..39cd234bb9 100644
--- a/pyro/poutine/trace_struct.py
+++ b/pyro/poutine/trace_struct.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import sys
 from collections import OrderedDict
+import sys

 import opt_einsum
diff --git a/tests/__init__.py b/tests/__init__.py
index 4056718ce7..200bfc2d65 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
+
 import os

 # create log handler for tests
diff --git a/tests/conftest.py b/tests/conftest.py
index 699cca55c2..2cfcba39d4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,7 @@

 import pyro

+
 torch.set_default_tensor_type(os.environ.get('PYRO_TENSOR_TYPE', 'torch.DoubleTensor'))
diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py
index 767b763ff4..9e9002d507 100644
--- a/tests/contrib/autoguide/test_inference.py
+++ b/tests/contrib/autoguide/test_inference.py
@@ -12,10 +12,10 @@
 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.distributions.transforms import block_autoregressive, iterated
-from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
+from pyro.distributions.transforms import iterated, block_autoregressive
 from pyro.infer.autoguide import (AutoDiagonalNormal, AutoIAFNormal, AutoLaplaceApproximation,
                                   AutoLowRankMultivariateNormal, AutoMultivariateNormal)
+from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
 from pyro.infer.autoguide.guides import AutoNormalizingFlow
 from tests.common import assert_equal
 from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain
diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/contrib/autoguide/test_mean_field_entropy.py
index 2f5cd163db..9f8c301b32 100644
--- a/tests/contrib/autoguide/test_mean_field_entropy.py
+++ b/tests/contrib/autoguide/test_mean_field_entropy.py
@@ -1,9 +1,9 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
-import scipy.special as sc
 import torch
+import scipy.special as sc
+import pytest

 import pyro
 import pyro.distributions as dist
diff --git a/tests/contrib/autoname/test_scoping.py b/tests/contrib/autoname/test_scoping.py
index aa7e44bae6..d10d6f2d7f 100644
--- a/tests/contrib/autoname/test_scoping.py
+++ b/tests/contrib/autoname/test_scoping.py
@@ -8,7 +8,7 @@
 import pyro
 import pyro.distributions.torch as dist
 import pyro.poutine as poutine
-from pyro.contrib.autoname import name_count, scope
+from pyro.contrib.autoname import scope, name_count

 logger = logging.getLogger(__name__)
diff --git a/tests/contrib/bnn/test_hidden_layer.py b/tests/contrib/bnn/test_hidden_layer.py
index c688572d0f..1067cc03f2 100644
--- a/tests/contrib/bnn/test_hidden_layer.py
+++ b/tests/contrib/bnn/test_hidden_layer.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch
 import torch.nn.functional as F
 from torch.distributions import Normal

+import pytest
 from pyro.contrib.bnn import HiddenLayer
 from tests.common import assert_equal
diff --git a/tests/contrib/epidemiology/test_quant.py b/tests/contrib/epidemiology/test_quant.py
index d2a0edbc1d..f9dc53bb64 100644
--- a/tests/contrib/epidemiology/test_quant.py
+++ b/tests/contrib/epidemiology/test_quant.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
+
 import torch

 from pyro.contrib.epidemiology.util import compute_bin_probs
diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py
index 9b273e5e2e..75fff55463 100644
--- a/tests/contrib/funsor/test_enum_funsor.py
+++ b/tests/contrib/funsor/test_enum_funsor.py
@@ -7,6 +7,7 @@
 import pyroapi
 import pytest
 import torch
+
 from torch.autograd import grad
 from torch.distributions import constraints

@@ -16,10 +17,9 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import handlers, infer, pyro
-
-    import pyro.contrib.funsor
     funsor.set_backend("torch")
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_named_handlers.py b/tests/contrib/funsor/test_named_handlers.py
index c4c57b7bd5..48c464daa3 100644
--- a/tests/contrib/funsor/test_named_handlers.py
+++ b/tests/contrib/funsor/test_named_handlers.py
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

-import logging
 from collections import OrderedDict
+import logging

 import pytest
 import torch
@@ -11,7 +11,6 @@
 try:
     import funsor
     from funsor.tensor import Tensor
-
     import pyro.contrib.funsor
     from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger
     funsor.set_backend("torch")
diff --git a/tests/contrib/funsor/test_pyroapi_funsor.py b/tests/contrib/funsor/test_pyroapi_funsor.py
index 9e050462e9..74dbf972e3 100644
--- a/tests/contrib/funsor/test_pyroapi_funsor.py
+++ b/tests/contrib/funsor/test_pyroapi_funsor.py
@@ -2,13 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
+
 from pyroapi import pyro_backend
 from pyroapi.tests import *  # noqa F401

 try:
     # triggers backend registration
     import funsor
-
     import pyro.contrib.funsor  # noqa: F401
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py
index 54d4eedaae..cc1ab52178 100644
--- a/tests/contrib/funsor/test_tmc.py
+++ b/tests/contrib/funsor/test_tmc.py
@@ -14,10 +14,9 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro, pyro_backend
-
-    import pyro.contrib.funsor
     funsor.set_backend("torch")
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_valid_models_enum.py b/tests/contrib/funsor/test_valid_models_enum.py
index 7df1b23a90..3ef3241ad2 100644
--- a/tests/contrib/funsor/test_valid_models_enum.py
+++ b/tests/contrib/funsor/test_valid_models_enum.py
@@ -1,10 +1,10 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

+from collections import defaultdict
 import contextlib
 import logging
 import os
-from collections import defaultdict
 from queue import LifoQueue

 import pytest
@@ -19,12 +19,11 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
-
     import pyro.contrib.funsor
     from pyro.contrib.funsor.handlers.runtime import _DIM_STACK
     funsor.set_backend("torch")
     from pyroapi import distributions as dist
-    from pyroapi import handlers, infer, pyro, pyro_backend
+    from pyroapi import infer, handlers, pyro, pyro_backend
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/funsor/test_valid_models_plate.py b/tests/contrib/funsor/test_valid_models_plate.py
index f5d30fc1b7..ed20ee4be4 100644
--- a/tests/contrib/funsor/test_valid_models_plate.py
+++ b/tests/contrib/funsor/test_valid_models_plate.py
@@ -12,10 +12,9 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro
-
-    import pyro.contrib.funsor
     from tests.contrib.funsor.test_valid_models_enum import assert_ok
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_valid_models_sequential_plate.py b/tests/contrib/funsor/test_valid_models_sequential_plate.py
index 40eeb79cb6..1de6af5b08 100644
--- a/tests/contrib/funsor/test_valid_models_sequential_plate.py
+++ b/tests/contrib/funsor/test_valid_models_sequential_plate.py
@@ -11,10 +11,9 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+    import pyro.contrib.funsor
     from pyroapi import distributions as dist
     from pyroapi import infer, pyro
-
-    import pyro.contrib.funsor
     from tests.contrib.funsor.test_valid_models_enum import assert_ok
     funsor.set_backend("torch")
 except ImportError:
diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py
index d1a5008a10..ac34469f26 100644
--- a/tests/contrib/funsor/test_vectorized_markov.py
+++ b/tests/contrib/funsor/test_vectorized_markov.py
@@ -3,6 +3,7 @@

 import pytest
 import torch
+
 from torch.distributions import constraints

 from pyro.ops.indexing import Vindex
@@ -11,11 +12,10 @@
 try:
     import funsor
     from funsor.testing import assert_close
-    from pyroapi import distributions as dist
-
     import pyro.contrib.funsor
+    from pyroapi import distributions as dist
     funsor.set_backend("torch")
-    from pyroapi import handlers, infer, pyro, pyro_backend
+    from pyroapi import handlers, pyro, pyro_backend, infer
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")
diff --git a/tests/contrib/gp/test_kernels.py b/tests/contrib/gp/test_kernels.py
index db1c803786..cc9797ff2c 100644
--- a/tests/contrib/gp/test_kernels.py
+++ b/tests/contrib/gp/test_kernels.py
@@ -6,8 +6,9 @@
 import pytest
 import torch

-from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, Exponential, Linear,
-                                     Matern32, Matern52, Periodic, Polynomial, Product, RationalQuadratic, Sum,
+from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent,
+                                     Exponential, Linear, Matern32, Matern52, Periodic,
+                                     Polynomial, Product, RationalQuadratic, Sum,
                                      VerticalScaling, Warping, WhiteNoise)
 from tests.common import assert_equal
diff --git a/tests/contrib/gp/test_likelihoods.py b/tests/contrib/gp/test_likelihoods.py
index c63c1ce9a5..71fbe663ad 100644
--- a/tests/contrib/gp/test_likelihoods.py
+++ b/tests/contrib/gp/test_likelihoods.py
@@ -11,6 +11,7 @@
 from pyro.contrib.gp.models import VariationalGP, VariationalSparseGP
 from pyro.contrib.gp.util import train

+
 T = namedtuple("TestGPLikelihood", ["model_class", "X", "y", "kernel", "likelihood"])

 X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0], [3.0, 4.0, 6.0]])
diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py
index 711089025d..d5afa24ec2 100644
--- a/tests/contrib/gp/test_models.py
+++ b/tests/contrib/gp/test_models.py
@@ -8,12 +8,13 @@
 import torch

 import pyro.distributions as dist
-from pyro.contrib.gp.kernels import RBF, Cosine, Matern32, WhiteNoise
+from pyro.contrib.gp.kernels import Cosine, Matern32, RBF, WhiteNoise
 from pyro.contrib.gp.likelihoods import Gaussian
-from pyro.contrib.gp.models import GPLVM, GPRegression, SparseGPRegression, VariationalGP, VariationalSparseGP
+from pyro.contrib.gp.models import (GPLVM, GPRegression, SparseGPRegression,
+                                    VariationalGP, VariationalSparseGP)
 from pyro.contrib.gp.util import train
-from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.hmc import HMC
+from pyro.infer.mcmc.api import MCMC
 from pyro.nn.module import PyroSample
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py
index aa8df92188..57e8e4e7da 100644
--- a/tests/contrib/oed/test_ewma.py
+++ b/tests/contrib/oed/test_ewma.py
@@ -3,9 +3,9 @@

 import math

-import pytest
 import torch

+import pytest
 from pyro.contrib.oed.eig import EwmaLog
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_finite_spaces_eig.py b/tests/contrib/oed/test_finite_spaces_eig.py
index b6f69234d4..49fc02493a 100644
--- a/tests/contrib/oed/test_finite_spaces_eig.py
+++ b/tests/contrib/oed/test_finite_spaces_eig.py
@@ -1,15 +1,17 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch
+import pytest

 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.contrib.oed.eig import (donsker_varadhan_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, nmc_eig,
-                                  posterior_eig, vnmc_eig)
+from pyro.contrib.oed.eig import (
+    nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, lfire_eig,
+    donsker_varadhan_eig)
 from pyro.contrib.util import iter_plates_to_shape
+
 from tests.common import assert_equal

 try:
diff --git a/tests/contrib/oed/test_glmm.py b/tests/contrib/oed/test_glmm.py
index cb3e95d169..6e855525dd 100644
--- a/tests/contrib/oed/test_glmm.py
+++ b/tests/contrib/oed/test_glmm.py
@@ -8,8 +8,10 @@
 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
-from pyro.contrib.oed.glmm import (group_linear_model, known_covariance_linear_model, logistic_regression_model,
-                                   normal_inverse_gamma_linear_model, sigmoid_model, zero_mean_unit_obs_sd_lm)
+from pyro.contrib.oed.glmm import (
+    known_covariance_linear_model, group_linear_model, zero_mean_unit_obs_sd_lm,
+    normal_inverse_gamma_linear_model, logistic_regression_model, sigmoid_model
+)
 from tests.common import assert_equal
diff --git a/tests/contrib/oed/test_linear_models_eig.py b/tests/contrib/oed/test_linear_models_eig.py
index f84ba916e5..30280cb602 100644
--- a/tests/contrib/oed/test_linear_models_eig.py
+++ b/tests/contrib/oed/test_linear_models_eig.py
@@ -1,19 +1,20 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch
+import pytest

 import pyro
 import pyro.distributions as dist
 import pyro.optim as optim
-from pyro.contrib.oed.eig import (donsker_varadhan_eig, laplace_eig, lfire_eig, marginal_eig, marginal_likelihood_eig,
-                                  nmc_eig, posterior_eig, vnmc_eig)
+from pyro.infer import Trace_ELBO
 from pyro.contrib.oed.glmm import known_covariance_linear_model
-from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide
 from pyro.contrib.oed.util import linear_model_ground_truth
+from pyro.contrib.oed.eig import (
+    nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, laplace_eig, lfire_eig,
+    donsker_varadhan_eig)
 from pyro.contrib.util import rmv, rvv
-from pyro.infer import Trace_ELBO
+from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide
 from tests.common import assert_equal
diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py
index 4c392bc997..5a1a43c194 100644
--- a/tests/contrib/randomvariable/test_random_variable.py
+++ b/tests/contrib/randomvariable/test_random_variable.py
@@ -4,7 +4,6 @@
 import math

 import torch.tensor as tt
-
 from pyro.distributions import Uniform

 N_SAMPLES = 100
diff --git a/tests/contrib/test_util.py b/tests/contrib/test_util.py
index 60a3115dad..442ca61bec 100644
--- a/tests/contrib/test_util.py
+++ b/tests/contrib/test_util.py
@@ -2,11 +2,12 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections import OrderedDict
-
 import pytest
 import torch

-from pyro.contrib.util import get_indices, lexpand, rdiag, rexpand, rmv, rtril, rvv, tensor_to_dict
+from pyro.contrib.util import (
+    get_indices, tensor_to_dict, rmv, rvv, lexpand, rexpand, rdiag, rtril
+)
 from tests.common import assert_equal
diff --git a/tests/contrib/timeseries/test_gp.py b/tests/contrib/timeseries/test_gp.py
index e2e39a0aba..2698faa01b 100644
--- a/tests/contrib/timeseries/test_gp.py
+++ b/tests/contrib/timeseries/test_gp.py
@@ -2,15 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
-
-import pytest
 import torch
+from tests.common import assert_equal

 import pyro
-from pyro.contrib.timeseries import (DependentMaternGP, GenericLGSSM, GenericLGSSMWithGPNoiseModel, IndependentMaternGP,
-                                     LinearlyCoupledMaternGP)
+from pyro.contrib.timeseries import (IndependentMaternGP, LinearlyCoupledMaternGP, GenericLGSSM,
+                                     GenericLGSSMWithGPNoiseModel, DependentMaternGP)
 from pyro.ops.tensor_utils import block_diag_embed
-from tests.common import assert_equal
+import pytest


 @pytest.mark.parametrize('model,obs_dim,nu_statedim', [('ssmgp', 3, 1.5), ('ssmgp', 2, 2.5),
diff --git a/tests/contrib/timeseries/test_lgssm.py b/tests/contrib/timeseries/test_lgssm.py
index 5b5ed9d339..f5c2dac137 100644
--- a/tests/contrib/timeseries/test_lgssm.py
+++ b/tests/contrib/timeseries/test_lgssm.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch

-from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel
 from tests.common import assert_equal
+from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel
+import pytest


 @pytest.mark.parametrize('model_class', ['lgssm', 'lgssmgp'])
diff --git a/tests/contrib/tracking/test_assignment.py b/tests/contrib/tracking/test_assignment.py
index 9c425dd502..554a373eb3 100644
--- a/tests/contrib/tracking/test_assignment.py
+++ b/tests/contrib/tracking/test_assignment.py
@@ -1,12 +1,12 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import logging
-
 import pytest
 import torch
 from torch.autograd import grad

+import logging
+
 import pyro
 import pyro.distributions as dist
 from pyro.contrib.tracking.assignment import MarginalAssignment, MarginalAssignmentPersistent, MarginalAssignmentSparse
diff --git a/tests/contrib/tracking/test_distributions.py b/tests/contrib/tracking/test_distributions.py
index fe4c149b49..4c589ac221 100644
--- a/tests/contrib/tracking/test_distributions.py
+++ b/tests/contrib/tracking/test_distributions.py
@@ -1,12 +1,13 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch

 from pyro.contrib.tracking.distributions import EKFDistribution
 from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous

+import pytest
+

 @pytest.mark.parametrize('Model', [NcpContinuous, NcvContinuous])
 @pytest.mark.parametrize('dim', [2, 3])
diff --git a/tests/contrib/tracking/test_dynamic_models.py b/tests/contrib/tracking/test_dynamic_models.py
index 4f93afe523..51df52e75d 100644
--- a/tests/contrib/tracking/test_dynamic_models.py
+++ b/tests/contrib/tracking/test_dynamic_models.py
@@ -3,7 +3,8 @@

 import torch

-from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcpDiscrete, NcvContinuous, NcvDiscrete
+from pyro.contrib.tracking.dynamic_models import (NcpContinuous, NcvContinuous,
+                                                  NcvDiscrete, NcpDiscrete)
 from tests.common import assert_equal, assert_not_equal
diff --git a/tests/contrib/tracking/test_ekf.py b/tests/contrib/tracking/test_ekf.py
index 35db1544d1..99cec4488c 100644
--- a/tests/contrib/tracking/test_ekf.py
+++ b/tests/contrib/tracking/test_ekf.py
@@ -3,9 +3,10 @@

 import torch

-from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous
 from pyro.contrib.tracking.extended_kalman_filter import EKFState
+from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous
 from pyro.contrib.tracking.measurements import PositionMeasurement
+
 from tests.common import assert_equal, assert_not_equal
diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py
index c3401f4114..1d0fca7147 100644
--- a/tests/contrib/tracking/test_em.py
+++ b/tests/contrib/tracking/test_em.py
@@ -16,6 +16,7 @@
 from pyro.optim import Adam
 from pyro.optim.multi import MixedMultiOptimizer, Newton

+
 logger = logging.getLogger(__name__)
diff --git a/tests/contrib/tracking/test_measurements.py b/tests/contrib/tracking/test_measurements.py
index 373cad0e79..38f2afcd3d 100644
--- a/tests/contrib/tracking/test_measurements.py
+++ b/tests/contrib/tracking/test_measurements.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-
 from pyro.contrib.tracking.measurements import PositionMeasurement
diff --git a/tests/distributions/test_empirical.py b/tests/distributions/test_empirical.py
index 3f2d4435dd..7d220aa95e 100644
--- a/tests/distributions/test_empirical.py
+++ b/tests/distributions/test_empirical.py
@@ -5,7 +5,7 @@
 import torch

 from pyro.distributions.empirical import Empirical
-from tests.common import assert_close, assert_equal
+from tests.common import assert_equal, assert_close


 @pytest.mark.parametrize("size", [[], [1], [2, 3]])
diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py
index 03737ecf63..f02426696e 100644
--- a/tests/distributions/test_gaussian_mixtures.py
+++ b/tests/distributions/test_gaussian_mixtures.py
@@ -4,12 +4,14 @@
 import logging
 import math

-import pytest
 import torch

-from pyro.distributions import GaussianScaleMixture, MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance
+import pytest
+from pyro.distributions import MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture
+from pyro.distributions import MixtureOfDiagNormals
 from tests.common import assert_equal

+
 logger = logging.getLogger(__name__)
diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py
index 53857c2791..63f3daea57 100644
--- a/tests/distributions/test_haar.py
+++ b/tests/distributions/test_haar.py
@@ -1,9 +1,9 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import torch

+import pytest
 from pyro.distributions.transforms import HaarTransform
 from tests.common import assert_equal
diff --git a/tests/distributions/test_ig.py b/tests/distributions/test_ig.py
index 5091e02ad7..215d00ed36 100644
--- a/tests/distributions/test_ig.py
+++ b/tests/distributions/test_ig.py
@@ -3,9 +3,9 @@

 import math

-import pytest
 import torch

+import pytest
 from pyro.distributions import Gamma, InverseGamma
 from tests.common import assert_equal
diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py
index 27cfdc4910..e71336b2af 100644
--- a/tests/distributions/test_mask.py
+++ b/tests/distributions/test_mask.py
@@ -6,8 +6,9 @@
 from torch import tensor
 from torch.distributions import kl_divergence

+from pyro.distributions.util import broadcast_shape
 from pyro.distributions.torch import Bernoulli, Normal
-from pyro.distributions.util import broadcast_shape, scale_and_mask
+from pyro.distributions.util import scale_and_mask
 from tests.common import assert_equal
diff --git a/tests/distributions/test_mvt.py b/tests/distributions/test_mvt.py
index ab2dec09ad..a61cb1b3f8 100644
--- a/tests/distributions/test_mvt.py
+++ b/tests/distributions/test_mvt.py
@@ -4,6 +4,7 @@
 import math

 import pytest
+
 import torch
 from torch.distributions import Gamma, MultivariateNormal, StudentT
diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py
index eb04d455fb..f1d92bbb0b 100644
--- a/tests/distributions/test_omt_mvn.py
+++ b/tests/distributions/test_omt_mvn.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import numpy as np
-import pytest
 import torch

+import pytest
 from pyro.distributions import AVFMultivariateNormal, MultivariateNormal, OMTMultivariateNormal
 from tests.common import assert_equal
diff --git a/tests/distributions/test_ordered_logistic.py b/tests/distributions/test_ordered_logistic.py
index 715db994fb..6c6c3ae409 100644
--- a/tests/distributions/test_ordered_logistic.py
+++ b/tests/distributions/test_ordered_logistic.py
@@ -6,9 +6,10 @@
 import torch.tensor as tt
 from torch.autograd.functional import jacobian

-from pyro.distributions import Normal, OrderedLogistic
+from pyro.distributions import OrderedLogistic, Normal
 from pyro.distributions.transforms import OrderedTransform

+
 # Tests for the OrderedLogistic distribution
diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py
index 66f881bbc5..fabb71b451 100644
--- a/tests/distributions/test_pickle.py
+++ b/tests/distributions/test_pickle.py
@@ -3,9 +3,9 @@

 import inspect
 import io
-import pickle

 import pytest
+import pickle
 import torch

 import pyro.distributions as dist
diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py
index 7a9bf18f53..00eaf424a5 100644
--- a/tests/distributions/test_transforms.py
+++ b/tests/distributions/test_transforms.py
@@ -1,8 +1,6 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

-import operator
-from functools import partial, reduce
 from unittest import TestCase

 import pytest
@@ -11,6 +9,9 @@
 import pyro.distributions as dist
 import pyro.distributions.transforms as T

+from functools import partial, reduce
+import operator
+
 pytestmark = pytest.mark.init(rng_seed=123)
diff --git a/tests/doctest_fixtures.py b/tests/doctest_fixtures.py
index 0d4e785d84..8be64b2948 100644
--- a/tests/doctest_fixtures.py
+++ b/tests/doctest_fixtures.py
@@ -6,15 +6,16 @@
 import torch

 import pyro
-import pyro.contrib.autoname.named as named
 import pyro.contrib.gp as gp
+import pyro.contrib.autoname.named as named
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.infer import EmpiricalMarginal
-from pyro.infer.mcmc import HMC, NUTS
 from pyro.infer.mcmc.api import MCMC
+from pyro.infer.mcmc import HMC, NUTS
 from pyro.params import param_with_module_name

+
 # Fix seed for all doctest runs.
 pyro.set_rng_seed(0)
diff --git a/tests/infer/mcmc/test_adaptation.py b/tests/infer/mcmc/test_adaptation.py
index 675e43525d..2fad237d90 100644
--- a/tests/infer/mcmc/test_adaptation.py
+++ b/tests/infer/mcmc/test_adaptation.py
@@ -4,7 +4,12 @@
 import pytest
 import torch

-from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix, WarmupAdapter, adapt_window
+from pyro.infer.mcmc.adaptation import (
+    ArrowheadMassMatrix,
+    BlockMassMatrix,
+    WarmupAdapter,
+    adapt_window,
+)
 from tests.common import assert_close, assert_equal
diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py
index 2f2f9f967d..58bbf0a76e 100644
--- a/tests/infer/mcmc/test_hmc.py
+++ b/tests/infer/mcmc/test_hmc.py
@@ -11,9 +11,9 @@
 import pyro
 import pyro.distributions as dist
 from pyro.infer.mcmc import NUTS
-from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.hmc import HMC
-from tests.common import assert_close, assert_equal
+from pyro.infer.mcmc.api import MCMC
+from tests.common import assert_equal, assert_close

 logger = logging.getLogger(__name__)
diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py
index cb203fcea8..a577da9d40 100644
--- a/tests/infer/mcmc/test_mcmc_api.py
+++ b/tests/infer/mcmc/test_mcmc_api.py
@@ -11,7 +11,7 @@
 import pyro.distributions as dist
 from pyro import poutine
 from pyro.infer.mcmc import HMC, NUTS
-from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler
+from pyro.infer.mcmc.api import MCMC, _UnarySampler, _MultiSampler
 from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
 from pyro.infer.mcmc.util import initialize_model
 from pyro.util import optional
diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py
index 106a4510e3..43630bcb16 100644
--- a/tests/infer/mcmc/test_nuts.py
+++ b/tests/infer/mcmc/test_nuts.py
@@ -10,12 +10,12 @@
 import pyro
 import pyro.distributions as dist
+from pyro.infer.autoguide import AutoDelta
+from pyro.contrib.conjugate.infer import BetaBinomialPair, collapse_conjugate, GammaPoissonPair, posterior_replay
+from pyro.infer import TraceEnum_ELBO, SVI
+from pyro.infer.mcmc import ArrowheadMassMatrix, MCMC, NUTS
 import pyro.optim as optim
 import pyro.poutine as poutine
-from pyro.contrib.conjugate.infer import BetaBinomialPair, GammaPoissonPair, collapse_conjugate, posterior_replay
-from pyro.infer import SVI, TraceEnum_ELBO
-from pyro.infer.autoguide import AutoDelta
-from pyro.infer.mcmc import MCMC, NUTS, ArrowheadMassMatrix
 from pyro.util import ignore_jit_warnings
 from tests.common import assert_close, assert_equal
diff --git a/tests/infer/test_abstract_infer.py b/tests/infer/test_abstract_infer.py
index bfacd142a2..483bc4e854 100644
--- a/tests/infer/test_abstract_infer.py
+++ b/tests/infer/test_abstract_infer.py
@@ -8,11 +8,12 @@
 import pyro.distributions as dist
 import pyro.optim as optim
 import pyro.poutine as poutine
-from pyro.infer import SVI, Trace_ELBO
 from pyro.infer.autoguide import AutoLaplaceApproximation
+from pyro.infer import SVI, Trace_ELBO
 from pyro.infer.mcmc import MCMC, NUTS
 from tests.common import assert_equal

+
 pytestmark = pytest.mark.filterwarnings("ignore::PendingDeprecationWarning")
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index 4c5caaf6fb..01d431f169 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -15,10 +15,11 @@
 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
-from pyro.infer import SVI, Predictive, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
+
+from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, Predictive
 from pyro.infer.autoguide import (AutoCallable, AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, AutoGuide,
                                   AutoGuideList, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal,
-                                  AutoMultivariateNormal, AutoNormal, init_to_feasible, init_to_mean, init_to_median,
+                                  AutoNormal, AutoMultivariateNormal, init_to_feasible, init_to_mean, init_to_median,
                                   init_to_sample)
 from pyro.nn.module import PyroModule, PyroParam, PyroSample
 from pyro.optim import Adam
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index 6b6ab5e115..61f32b43f6 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -8,8 +8,8 @@
 import pyro.distributions as dist
 import pyro.optim as optim
 import pyro.poutine as poutine
-from pyro.infer import SVI, Predictive, Trace_ELBO
 from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
+from pyro.infer import Predictive, SVI, Trace_ELBO
 from tests.common import assert_close
diff --git a/tests/infer/test_svgd.py b/tests/infer/test_svgd.py
index c6944dedc2..2d10b53b55 100644
--- a/tests/infer/test_svgd.py
+++ b/tests/infer/test_svgd.py
@@ -6,9 +6,11 @@
 import pyro
 import pyro.distributions as dist
-from pyro.infer import SVGD, IMQSteinKernel, RBFSteinKernel
-from pyro.infer.autoguide.utils import _product
+
+from pyro.infer import SVGD, RBFSteinKernel, IMQSteinKernel
 from pyro.optim import Adam
+from pyro.infer.autoguide.utils import _product
+
 from tests.common import assert_equal
diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py
index 35667e56ef..cf55ed02ce 100644
--- a/tests/infer/test_tmc.py
+++ b/tests/infer/test_tmc.py
@@ -15,10 +15,11 @@
 from pyro.distributions.testing import fakes
 from pyro.infer import config_enumerate
 from pyro.infer.importance import vectorized_importance_weights
-from pyro.infer.traceenum_elbo import TraceEnum_ELBO
 from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
+from pyro.infer.traceenum_elbo import TraceEnum_ELBO
 from tests.common import assert_equal

+
 logger = logging.getLogger(__name__)
diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py
index 7fda7a399f..236b460050 100644
--- a/tests/infer/test_util.py
+++ b/tests/infer/test_util.py
@@ -3,12 +3,12 @@
 import math

-import pytest
 import torch

 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
+import pytest
 from pyro.infer.importance import psis_diagnostic
 from pyro.infer.util import MultiFrameTensor
 from tests.common
import assert_equal diff --git a/tests/ops/test_arrowhead.py b/tests/ops/test_arrowhead.py index 2ffa76bf78..13feae5697 100644 --- a/tests/ops/test_arrowhead.py +++ b/tests/ops/test_arrowhead.py @@ -5,6 +5,7 @@ import torch from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul + from tests.common import assert_close diff --git a/tests/ops/test_gamma_gaussian.py b/tests/ops/test_gamma_gaussian.py index 74c018bcc5..872a42e531 100644 --- a/tests/ops/test_gamma_gaussian.py +++ b/tests/ops/test_gamma_gaussian.py @@ -9,8 +9,12 @@ import pyro.distributions as dist from pyro.distributions.util import broadcast_shape -from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot, - matrix_and_mvn_to_gamma_gaussian) +from pyro.ops.gamma_gaussian import ( + GammaGaussian, + gamma_gaussian_tensordot, + matrix_and_mvn_to_gamma_gaussian, + gamma_and_mvn_to_gamma_gaussian, +) from tests.common import assert_close from tests.ops.gamma_gaussian import assert_close_gamma_gaussian, random_gamma, random_gamma_gaussian from tests.ops.gaussian import random_mvn diff --git a/tests/ops/test_newton.py b/tests/ops/test_newton.py index d264b3ae35..d502cde5d7 100644 --- a/tests/ops/test_newton.py +++ b/tests/ops/test_newton.py @@ -11,6 +11,7 @@ from pyro.ops.newton import newton_step from tests.common import assert_equal + logger = logging.getLogger(__name__) diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 32026d77a2..53fb0213fb 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -16,8 +16,8 @@ import pyro.optim as optim from pyro.distributions.testing import fakes from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.hmc import HMC +from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.nuts import NUTS Model = namedtuple('TestModel', ['model', 'model_args', 'model_id']) diff --git a/tests/poutine/test_nesting.py b/tests/poutine/test_nesting.py index ede0456c32..6fd6f3614d 100644 --- a/tests/poutine/test_nesting.py +++ b/tests/poutine/test_nesting.py @@ -4,10 +4,11 @@ import logging import pyro -import pyro.distributions as dist import pyro.poutine as poutine +import pyro.distributions as dist import pyro.poutine.runtime + logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 99fbfb6336..f2f4eee025 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -6,12 +6,12 @@ import logging import pickle import warnings -from queue import Queue from unittest import TestCase import pytest import torch import torch.nn as nn +from queue import Queue import pyro import pyro.distributions as dist @@ -19,7 +19,7 @@ from pyro.distributions import Bernoulli, Categorical, Normal from pyro.poutine.runtime import _DIM_ALLOCATOR, NonlocalExit from pyro.poutine.util import all_escape, discrete_escape -from tests.common import assert_close, assert_equal, assert_not_equal +from tests.common import assert_equal, assert_not_equal, assert_close logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_trace_struct.py b/tests/poutine/test_trace_struct.py index 4511ccbdf3..9ad7d351a6 100644 --- a/tests/poutine/test_trace_struct.py +++ b/tests/poutine/test_trace_struct.py @@ -8,6 +8,7 @@ from pyro.poutine import Trace from tests.common import assert_equal + EDGE_SETS = [ # 1 # / \ diff --git a/tests/pyroapi/test_pyroapi.py 
b/tests/pyroapi/test_pyroapi.py index 271c38efab..1fa1673b9f 100644 --- a/tests/pyroapi/test_pyroapi.py +++ b/tests/pyroapi/test_pyroapi.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest + from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 diff --git a/tests/test_generic.py b/tests/test_generic.py index 1ca5c77588..a3324b27c0 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from pyroapi.testing import MODELS -from pyro.generic import handlers, infer, ops, pyro, pyro_backend +from pyro.generic import handlers, infer, pyro, pyro_backend, ops +from pyroapi.testing import MODELS from tests.common import xfail_if_not_implemented pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 22f331a450..d285ad69b5 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -2,10 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -import torch - import pyro import pyro.distributions as dist +import torch pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_util.py b/tests/test_util.py index 09ec92f4f7..f8b382b4ec 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,10 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import warnings - import pytest -import torch +import torch from pyro import util pytestmark = pytest.mark.stage('unit') From f696c16fec90e8f6d4d5a0ab481470ac89b3a899 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 14:39:07 -0500 Subject: [PATCH 34/91] Add test_examples.py lines back in. --- docs/source/contrib.mue.rst | 11 +++++++++-- tests/test_examples.py | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 335bcd8965..9dbc3d05b2 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -14,6 +14,13 @@ Reference: MuE models were described in Weinstein and Marks (2020), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. +Example MuE Models +------------------ +.. automodule:: pyro.contrib.mue.models + :members: + :show-inheritance: + :member-order: bysource + State Arrangers for Parameterizing MuEs --------------------------------------- .. automodule:: pyro.contrib.mue.statearrangers :members: :show-inheritance: :member-order: bysource -Variable Length/Missing Data HMM +Missing/Variable Length Data HMM -------------------------------- -.. automodule:: pyro.contrib.mue.variablelengthhmm +..
automodule:: pyro.contrib.mue.missingdatahmm :members: :show-inheritance: :member-order: bysource diff --git a/tests/test_examples.py b/tests/test_examples.py index b0d0cb96c8..99cb2370df 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,6 +53,8 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', + 'contrib/mue/FactorMuE.py --test --small', + 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', From cb31d2d0230187f61a7bb227b9b46857cb5a928a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 14:51:56 -0500 Subject: [PATCH 35/91] Add parameter constraints to MissingDataDiscreteHMM. --- pyro/contrib/mue/dataloaders.py | 2 ++ pyro/contrib/mue/missingdatahmm.py | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index de03fccfb1..9f3df9580c 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -18,6 +18,8 @@ class BiosequenceDataset(Dataset): def __init__(self, source, source_type='list', alphabet='amino-acid'): + super().__init__() + # Get sequences. if source_type == 'list': seqs = source diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index 201811bcd1..a6396d2044 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 import torch -from torch.distributions import constraints +from pyro.distributions import constraints from pyro.distributions.hmm import _sequential_logmatmulexp from pyro.distributions.torch_distribution import TorchDistribution from pyro.distributions.util import broadcast_shape @@ -34,9 +34,12 @@ class MissingDataDiscreteHMM(TorchDistribution): dimension of the categorical output, and be broadcastable to ``(batch_size, state_dim, categorical_size)``. """ - arg_constraints = {"initial_logits": constraints.real, - "transition_logits": constraints.real, - "observation_logits": constraints.real} + arg_constraints = {"initial_logits": constraints.real_vector, + "transition_logits": constraints.independent( + constraints.real, 2), + "observation_logits": constraints.independent( + constraints.real, 2)} + support = constraints.independent(constraints.nonnegative_integer, 2) def __init__(self, initial_logits, transition_logits, observation_logits, validate_args=None): From dd1a383fa17acd6766fe20a131bca8d79b60b09f Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 15:06:44 -0500 Subject: [PATCH 36/91] Git ignore .fasta files.
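The checked-in FASTA fixture is removed below, presumably so tests can write their sequence fixtures at runtime; any generated file is then covered by the new *.fasta ignore rule. A minimal sketch of that pattern, assuming BiosequenceDataset accepts source_type='fasta' and alphabet='dna' and defines __len__ (those options are not shown in this series):

    # Hypothetical test helper, not part of this patch: recreate the deleted
    # fixture on the fly and load it with the MuE dataloader.
    from pyro.contrib.mue.dataloaders import BiosequenceDataset

    def make_test_fasta(path='test_seqs.fasta'):
        # Same records as the deleted tests/contrib/mue/test_seqs.fasta;
        # 'one' spans two lines and reads as AATC.
        with open(path, 'w') as f:
            f.write('>one\nAAT\nC\n>two\nCA\n>three\nT\n')
        return path

    dataset = BiosequenceDataset(make_test_fasta(), source_type='fasta',
                                 alphabet='dna')
    assert len(dataset) == 3  # one dataset entry per FASTA record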
--- .gitignore | 1 + tests/contrib/mue/test_seqs.fasta | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) delete mode 100644 tests/contrib/mue/test_seqs.fasta diff --git a/.gitignore b/.gitignore index f9a717d67b..bbb6f57b40 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ pyro/_version.py processed raw *.pkl +*.fasta # Logs logs diff --git a/tests/contrib/mue/test_seqs.fasta b/tests/contrib/mue/test_seqs.fasta deleted file mode 100644 index 2a90359e45..0000000000 --- a/tests/contrib/mue/test_seqs.fasta +++ /dev/null @@ -1,7 +0,0 @@ ->one -AAT -C ->two -CA ->three -T From d2c513c2034763199983dceb78d1c83db5f2889f Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 15:17:51 -0500 Subject: [PATCH 37/91] Make format's automated changes. --- docs/source/conf.py | 1 - examples/air/air.py | 2 +- examples/air/main.py | 4 +- examples/capture_recapture/cjs.py | 3 +- examples/contrib/autoname/scoping_mixture.py | 7 ++- examples/contrib/funsor/hmm.py | 1 - examples/contrib/gp/sv-dkl.py | 2 +- examples/contrib/oed/ab_test.py | 14 +++-- examples/contrib/oed/gp_bayes_opt.py | 2 +- examples/contrib/timeseries/gp_models.py | 8 +-- examples/cvae/baseline.py | 5 +- examples/cvae/cvae.py | 10 ++-- examples/cvae/main.py | 10 ++-- examples/cvae/mnist.py | 2 +- examples/cvae/util.py | 12 +++-- examples/eight_schools/data.py | 1 - examples/eight_schools/mcmc.py | 2 +- examples/eight_schools/svi.py | 2 +- examples/hmm.py | 2 +- examples/lkj.py | 3 +- examples/minipyro.py | 2 +- examples/mixed_hmm/experiment.py | 9 ++-- examples/mixed_hmm/seal_data.py | 2 - examples/rsa/generics.py | 9 ++-- examples/rsa/hyperbole.py | 9 ++-- examples/rsa/schelling.py | 3 +- examples/rsa/schelling_false.py | 3 +- examples/rsa/search_inference.py | 4 +- examples/rsa/semantic_parsing.py | 7 ++- examples/scanvi/data.py | 6 +-- examples/scanvi/scanvi.py | 14 +++-- examples/sparse_gamma_def.py | 11 ++-- examples/sparse_regression.py | 9 ++-- examples/vae/ss_vae_M2.py | 6 +-- examples/vae/vae.py | 6 +-- examples/vae/vae_comparison.py | 2 +- profiler/hmm.py | 3 +- profiler/profiling_utils.py | 2 +- pyro/contrib/__init__.py | 1 + pyro/contrib/autoname/__init__.py | 3 +- pyro/contrib/bnn/hidden_layer.py | 2 +- pyro/contrib/bnn/utils.py | 3 +- pyro/contrib/conjugate/infer.py | 2 +- pyro/contrib/easyguide/__init__.py | 1 - pyro/contrib/easyguide/easyguide.py | 2 +- pyro/contrib/examples/bart.py | 2 +- pyro/contrib/examples/finance.py | 2 +- .../examples/polyphonic_data_loader.py | 3 +- pyro/contrib/funsor/__init__.py | 12 ++--- pyro/contrib/funsor/handlers/__init__.py | 8 +-- .../contrib/funsor/handlers/enum_messenger.py | 9 ++-- .../funsor/handlers/named_messenger.py | 3 +- pyro/contrib/funsor/handlers/primitives.py | 1 - .../funsor/handlers/replay_messenger.py | 2 +- pyro/contrib/funsor/infer/__init__.py | 2 +- pyro/contrib/funsor/infer/trace_elbo.py | 7 ++- pyro/contrib/funsor/infer/traceenum_elbo.py | 5 +- pyro/contrib/funsor/infer/tracetmc_elbo.py | 6 +-- pyro/contrib/gp/kernels/__init__.py | 7 ++- pyro/contrib/gp/likelihoods/binary.py | 1 - pyro/contrib/gp/likelihoods/gaussian.py | 1 - pyro/contrib/gp/likelihoods/multi_class.py | 1 - pyro/contrib/gp/likelihoods/poisson.py | 1 - pyro/contrib/oed/__init__.py | 2 +- pyro/contrib/oed/eig.py | 9 ++-- pyro/contrib/oed/glmm/__init__.py | 2 +- pyro/contrib/oed/glmm/glmm.py | 6 +-- pyro/contrib/oed/glmm/guides.py | 8 ++- pyro/contrib/oed/search.py | 3 +- pyro/contrib/oed/util.py | 3 +- .../contrib/randomvariable/random_variable.py | 13 ++--- 
pyro/contrib/timeseries/__init__.py | 2 +- pyro/contrib/tracking/distributions.py | 2 +- pyro/contrib/tracking/dynamic_models.py | 1 + pyro/contrib/tracking/measurements.py | 1 + pyro/contrib/util.py | 2 + pyro/distributions/__init__.py | 51 ++++--------------- pyro/distributions/ordered_logistic.py | 1 + pyro/distributions/projected_normal.py | 3 +- pyro/distributions/spanning_tree.py | 1 + pyro/distributions/testing/gof.py | 2 +- pyro/distributions/transforms/__init__.py | 8 +-- pyro/distributions/transforms/ordered.py | 3 +- pyro/infer/__init__.py | 2 +- pyro/infer/abstract_infer.py | 2 +- pyro/infer/mcmc/__init__.py | 2 +- pyro/infer/mcmc/adaptation.py | 2 +- pyro/infer/mcmc/hmc.py | 3 +- pyro/infer/mcmc/util.py | 4 +- pyro/infer/predictive.py | 2 +- pyro/infer/reparam/neutra.py | 1 + pyro/infer/svgd.py | 6 +-- pyro/infer/trace_mean_field_elbo.py | 2 +- pyro/infer/trace_mmd.py | 2 +- pyro/infer/trace_tail_adaptive_elbo.py | 2 +- pyro/infer/traceenum_elbo.py | 2 +- pyro/infer/tracegraph_elbo.py | 3 +- pyro/infer/tracetmc_elbo.py | 1 - pyro/logger.py | 1 - pyro/ops/arrowhead.py | 1 - pyro/ops/einsum/torch_map.py | 1 - pyro/ops/einsum/torch_sample.py | 1 - pyro/ops/newton.py | 2 +- pyro/ops/ssm_gp.py | 2 +- pyro/poutine/broadcast_messenger.py | 1 + pyro/poutine/indep_messenger.py | 1 + pyro/poutine/trace_struct.py | 2 +- tests/__init__.py | 1 - tests/conftest.py | 1 - tests/contrib/autoguide/test_inference.py | 4 +- .../autoguide/test_mean_field_entropy.py | 4 +- tests/contrib/autoname/test_scoping.py | 2 +- tests/contrib/bnn/test_hidden_layer.py | 2 +- tests/contrib/epidemiology/test_quant.py | 1 - tests/contrib/funsor/test_enum_funsor.py | 4 +- tests/contrib/funsor/test_named_handlers.py | 3 +- tests/contrib/funsor/test_pyroapi_funsor.py | 2 +- tests/contrib/funsor/test_tmc.py | 3 +- .../contrib/funsor/test_valid_models_enum.py | 5 +- .../contrib/funsor/test_valid_models_plate.py | 3 +- .../test_valid_models_sequential_plate.py | 3 +- .../contrib/funsor/test_vectorized_markov.py | 7 +-- tests/contrib/gp/test_kernels.py | 5 +- tests/contrib/gp/test_likelihoods.py | 1 - tests/contrib/gp/test_models.py | 7 ++- tests/contrib/oed/test_ewma.py | 2 +- tests/contrib/oed/test_finite_spaces_eig.py | 8 ++- tests/contrib/oed/test_glmm.py | 6 +-- tests/contrib/oed/test_linear_models_eig.py | 11 ++-- .../randomvariable/test_random_variable.py | 1 + tests/contrib/test_util.py | 5 +- tests/contrib/timeseries/test_gp.py | 9 ++-- tests/contrib/timeseries/test_lgssm.py | 4 +- tests/contrib/tracking/test_assignment.py | 4 +- tests/contrib/tracking/test_distributions.py | 3 +- tests/contrib/tracking/test_dynamic_models.py | 3 +- tests/contrib/tracking/test_ekf.py | 3 +- tests/contrib/tracking/test_em.py | 1 - tests/contrib/tracking/test_measurements.py | 1 + tests/distributions/test_empirical.py | 2 +- tests/distributions/test_gaussian_mixtures.py | 6 +-- tests/distributions/test_haar.py | 2 +- tests/distributions/test_ig.py | 2 +- tests/distributions/test_mask.py | 3 +- tests/distributions/test_mvt.py | 1 - tests/distributions/test_omt_mvn.py | 2 +- tests/distributions/test_ordered_logistic.py | 3 +- tests/distributions/test_pickle.py | 2 +- tests/distributions/test_spanning_tree.py | 3 +- tests/distributions/test_transforms.py | 5 +- tests/doctest_fixtures.py | 5 +- tests/infer/mcmc/test_adaptation.py | 7 +-- tests/infer/mcmc/test_hmc.py | 4 +- tests/infer/mcmc/test_mcmc_api.py | 2 +- tests/infer/mcmc/test_nuts.py | 8 +-- tests/infer/test_abstract_infer.py | 3 +- tests/infer/test_predictive.py | 2 +- 
tests/infer/test_svgd.py | 6 +-- tests/infer/test_tmc.py | 3 +- tests/infer/test_util.py | 2 +- tests/ops/test_arrowhead.py | 1 - tests/ops/test_gamma_gaussian.py | 8 +-- tests/ops/test_newton.py | 1 - tests/perf/test_benchmark.py | 2 +- tests/poutine/test_nesting.py | 3 +- tests/poutine/test_poutines.py | 4 +- tests/poutine/test_trace_struct.py | 1 - tests/pyroapi/test_pyroapi.py | 1 - tests/test_generic.py | 4 +- tests/test_primitives.py | 3 +- tests/test_util.py | 3 +- 171 files changed, 294 insertions(+), 387 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index d39a267cb9..597d1c2ac6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,7 +6,6 @@ import sphinx_rtd_theme - # import pkg_resources # -*- coding: utf-8 -*- diff --git a/examples/air/air.py b/examples/air/air.py index a9b8958d0f..985e2853bf 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -14,10 +14,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import MLP, Decoder, Encoder, Identity, Predict import pyro import pyro.distributions as dist -from modules import MLP, Decoder, Encoder, Identity, Predict # Default prior success probability for z_pres. diff --git a/examples/air/main.py b/examples/air/main.py index 43d30d3a02..0506a32b9e 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -19,15 +19,15 @@ import numpy as np import torch import visdom +from air import AIR, latents_to_tensor +from viz import draw_many, tensor_to_objs import pyro import pyro.contrib.examples.multi_mnist as multi_mnist import pyro.optim as optim import pyro.poutine as poutine -from air import AIR, latents_to_tensor from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO -from viz import draw_many, tensor_to_objs def count_accuracy(X, true_counts, air, batch_size): diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index 65b709afca..fa868899d5 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -39,11 +39,10 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.infer.autoguide import AutoDiagonalNormal from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO +from pyro.infer.autoguide import AutoDiagonalNormal from pyro.optim import Adam - """ Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival probability phi; diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 363d3ace53..d39b65eb6d 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -2,16 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import argparse + import torch from torch.distributions import constraints import pyro -import pyro.optim import pyro.distributions as dist - -from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO - +import pyro.optim from pyro.contrib.autoname import scope +from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate def model(K, data): diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 017a9aa2e8..98fca8eca8 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -64,7 +64,6 @@ from pyroapi import distributions as dist from pyroapi import handlers, infer, optim, pyro, pyro_backend - logging.basicConfig(format='%(relativeCreated) 9d %(message)s', 
level=logging.DEBUG) # Add another handler for logging debugging events (e.g. for profiling) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index f07c4052e4..19e8245c03 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -39,7 +39,7 @@ import pyro import pyro.contrib.gp as gp import pyro.infer as infer -from pyro.contrib.examples.util import get_data_loader, get_data_directory +from pyro.contrib.examples.util import get_data_directory, get_data_loader class CNN(nn.Module): diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index 522dc44ad3..f417d78ca8 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -3,20 +3,18 @@ import argparse from functools import partial + +import numpy as np import torch +from gp_bayes_opt import GPBayesOptimizer from torch.distributions import constraints -import numpy as np import pyro +import pyro.contrib.gp as gp from pyro import optim -from pyro.infer import TraceEnum_ELBO from pyro.contrib.oed.eig import vi_eig -import pyro.contrib.gp as gp -from pyro.contrib.oed.glmm import ( - zero_mean_unit_obs_sd_lm, group_assignment_matrix, analytic_posterior_cov -) - -from gp_bayes_opt import GPBayesOptimizer +from pyro.contrib.oed.glmm import analytic_posterior_cov, group_assignment_matrix, zero_mean_unit_obs_sd_lm +from pyro.infer import TraceEnum_ELBO """ Example builds on the Bayesian regression tutorial [1]. It demonstrates how diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index 3c114c9bcf..6132dee48a 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -7,8 +7,8 @@ from torch.distributions import transform_to import pyro.contrib.gp as gp -from pyro.infer import TraceEnum_ELBO import pyro.optim +from pyro.infer import TraceEnum_ELBO class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer): diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index 259ef99180..e2ab952b90 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -1,16 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import argparse +from os.path import exists +from urllib.request import urlopen + import numpy as np import torch import pyro from pyro.contrib.timeseries import IndependentMaternGP, LinearlyCoupledMaternGP -import argparse -from os.path import exists -from urllib.request import urlopen - # download dataset from UCI archive def download_data(): diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index cb5d279445..23e1591016 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import numpy as np from pathlib import Path -from tqdm import tqdm + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from tqdm import tqdm class BaselineNet(nn.Module): diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py index fb792cf4d3..f499aaf452 100644 --- a/examples/cvae/cvae.py +++ b/examples/cvae/cvae.py @@ -1,15 +1,17 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import numpy as np from pathlib import Path -import pyro -import pyro.distributions as dist -from pyro.infer import SVI, Trace_ELBO + +import numpy as np import torch import torch.nn as nn from tqdm import tqdm +import pyro +import pyro.distributions as dist +from pyro.infer import SVI, Trace_ELBO + class Encoder(nn.Module): def __init__(self, z_dim, hidden_1, hidden_2): diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 2056b4d3a6..73b85ac92d 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import pandas as pd -import pyro -import torch + import baseline import cvae -from util import get_data, visualize, generate_table +import pandas as pd +import torch +from util import generate_table, get_data, visualize + +import pyro def main(args): diff --git a/examples/cvae/mnist.py b/examples/cvae/mnist.py index a98c667081..12dd7409f2 100644 --- a/examples/cvae/mnist.py +++ b/examples/cvae/mnist.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset from torchvision.datasets import MNIST from torchvision.transforms import Compose, functional diff --git a/examples/cvae/util.py b/examples/cvae/util.py index e578085946..87650298ef 100644 --- a/examples/cvae/util.py +++ b/examples/cvae/util.py @@ -1,17 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import numpy as np +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np import pandas as pd -from pathlib import Path -from pyro.infer import Predictive, Trace_ELBO import torch +from baseline import MaskedBCELoss +from mnist import get_data from torch.utils.data import DataLoader from torchvision.utils import make_grid from tqdm import tqdm -from baseline import MaskedBCELoss -from mnist import get_data + +from pyro.infer import Predictive, Trace_ELBO def imshow(inp, image_path=None): diff --git a/examples/eight_schools/data.py b/examples/eight_schools/data.py index 39529e798a..56158fa36e 100644 --- a/examples/eight_schools/data.py +++ b/examples/eight_schools/data.py @@ -3,7 +3,6 @@ import torch - J = 8 y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index 62b184ea85..d34d901e53 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -4,9 +4,9 @@ import argparse import logging +import data import torch -import data import pyro import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index a06a801768..b6f35f254b 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -5,11 +5,11 @@ import logging import torch +from data import J, sigma, y from torch.distributions import constraints, transforms import pyro import pyro.distributions as dist -from data import J, sigma, y from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam diff --git a/examples/hmm.py b/examples/hmm.py index 7c706ed558..68395da7c1 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -47,8 +47,8 @@ import pyro.contrib.examples.polyphonic_data_loader as poly import pyro.distributions as dist from pyro import poutine -from pyro.infer.autoguide import AutoDelta 
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO +from pyro.infer.autoguide import AutoDelta from pyro.ops.indexing import Vindex from pyro.optim import Adam from pyro.util import ignore_jit_warnings diff --git a/examples/lkj.py b/examples/lkj.py index fb437358bd..dd54c5ed43 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import argparse + import torch import pyro import pyro.distributions as dist -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc import NUTS +from pyro.infer.mcmc.api import MCMC """ This simple example is intended to demonstrate how to use an LKJ prior with diff --git a/examples/minipyro.py b/examples/minipyro.py index 33ab6eba01..084ddfdca5 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -11,8 +11,8 @@ import torch -from pyro.generic import distributions as dist # We use the pyro.generic interface to support dynamic choice of backend. +from pyro.generic import distributions as dist from pyro.generic import infer, ops, optim, pyro, pyro_backend diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index 65584c6769..bd4bd02e6a 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -2,20 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import os +import functools import json +import os import uuid -import functools import torch +from model import guide_generic, model_generic +from seal_data import prepare_seal import pyro import pyro.poutine as poutine from pyro.infer import TraceEnum_ELBO -from model import model_generic, guide_generic -from seal_data import prepare_seal - def aic_num_parameters(model, guide=None): """ diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py index 390201a8d9..609fc69da3 100644 --- a/examples/mixed_hmm/seal_data.py +++ b/examples/mixed_hmm/seal_data.py @@ -5,10 +5,8 @@ from urllib.request import urlopen import pandas as pd - import torch - MISSING = 1e-6 diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 9a13a4e54a..4ca04bc349 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -9,18 +9,17 @@ [1] https://gscontras.github.io/probLang/chapters/07-generics.html """ -import torch - import argparse -import numbers import collections +import numbers + +import torch +from search_inference import HashingMarginal, Search, memoize import pyro import pyro.distributions as dist import pyro.poutine as poutine -from search_inference import HashingMarginal, memoize, Search - torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index a77b01870a..93dd409579 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -7,17 +7,16 @@ Taken from: https://gscontras.github.io/probLang/chapters/03-nonliteral.html """ -import torch - -import collections import argparse +import collections + +import torch +from search_inference import HashingMarginal, Search, memoize import pyro import pyro.distributions as dist import pyro.poutine as poutine -from search_inference import HashingMarginal, memoize, Search - torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index 886eb405b6..a111b5d2bc 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -11,12 +11,13 @@ Taken from: 
http://forestdb.org/models/schelling.html """ import argparse + import torch +from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli -from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index 998e3b70cb..f855a010e2 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -12,12 +12,13 @@ Taken from: http://forestdb.org/models/schelling-falsebelief.html """ import argparse + import torch +from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli -from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/search_inference.py b/examples/rsa/search_inference.py index 14a49766f1..7e2cb8e142 100644 --- a/examples/rsa/search_inference.py +++ b/examples/rsa/search_inference.py @@ -8,10 +8,10 @@ """ import collections +import functools +import queue import torch -import queue -import functools import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index 15ffe901aa..eb188b0a54 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -7,16 +7,15 @@ Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html """ -import torch - import argparse import collections +import torch +from search_inference import BestFirstSearch, HashingMarginal, memoize + import pyro import pyro.distributions as dist -from search_inference import HashingMarginal, BestFirstSearch, memoize - torch.set_default_dtype(torch.float64) diff --git a/examples/scanvi/data.py b/examples/scanvi/data.py index 690eab0717..429883d1a3 100644 --- a/examples/scanvi/data.py +++ b/examples/scanvi/data.py @@ -8,11 +8,11 @@ """ import math -import numpy as np -from scipy import sparse +import numpy as np import torch import torch.nn as nn +from scipy import sparse class BatchDataLoader(object): @@ -122,8 +122,8 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): return BatchDataLoader(X, Y, batch_size), num_genes, 2.0, 1.0, None - import scvi import scanpy as sc + import scvi adata = scvi.data.purified_pbmc_dataset(subset_datasets=["regulatory_t", "naive_t", "memory_t", "naive_cytotoxic"]) diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 928785827a..d4928ddf27 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -19,25 +19,22 @@ import argparse +import matplotlib.pyplot as plt import numpy as np - import torch import torch.nn as nn +from data import get_data +from matplotlib.patches import Patch from torch.distributions import constraints -from torch.nn.functional import softplus, softmax +from torch.nn.functional import softmax, softplus from torch.optim import Adam import pyro import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions.util import broadcast_shape +from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate from pyro.optim import MultiStepLR -from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO - -import matplotlib.pyplot as plt -from matplotlib.patches import Patch - -from data import get_data # Helper for making fully-connected neural networks @@ -300,6 +297,7 @@ def main(args): # Now that we're done training we'll inspect the latent representations we've learned if 
args.plot and args.dataset == 'pbmc': import scanpy as sc + # Compute latent representation (z2_loc) for each cell in the dataset latent_rep = scanvi.z2l_encoder(dataloader.data_x)[0] diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index 5de2bf6f1d..fddb1b1bc7 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -20,19 +20,16 @@ import numpy as np import torch +import wget from torch.nn.functional import softplus import pyro import pyro.optim as optim -import wget - +from pyro.contrib.easyguide import EasyGuide from pyro.contrib.examples.util import get_data_directory -from pyro.distributions import Gamma, Poisson, Normal +from pyro.distributions import Gamma, Normal, Poisson from pyro.infer import SVI, TraceMeanField_ELBO -from pyro.infer.autoguide import AutoDiagonalNormal -from pyro.infer.autoguide import init_to_feasible -from pyro.contrib.easyguide import EasyGuide - +from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible torch.set_default_tensor_type('torch.FloatTensor') pyro.util.set_rng_seed(0) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index da807cfcb7..63cffbc7c6 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -2,20 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import math import numpy as np import torch -import math +from torch.optim import Adam import pyro import pyro.distributions as dist from pyro import poutine -from pyro.infer.autoguide import AutoDelta from pyro.infer import Trace_ELBO -from pyro.infer.autoguide import init_to_median - -from torch.optim import Adam - +from pyro.infer.autoguide import AutoDelta, init_to_median """ We demonstrate how to do sparse linear regression using a variant of the diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index 265097efc4..a064140ab5 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn +from utils.custom_mlp import MLP, Exp +from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders +from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae from visdom import Visdom import pyro @@ -12,9 +15,6 @@ from pyro.contrib.examples.util import print_and_log from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate from pyro.optim import Adam -from utils.custom_mlp import MLP, Exp -from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders -from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae class SSVAE(nn.Module): diff --git a/examples/vae/vae.py b/examples/vae/vae.py index d4f54e515e..95cee4e66b 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -7,14 +7,14 @@ import torch import torch.nn as nn import visdom +from utils.mnist_cached import MNISTCached as MNIST +from utils.mnist_cached import setup_data_loaders +from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples import pyro import pyro.distributions as dist from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam -from utils.mnist_cached import MNISTCached as MNIST -from utils.mnist_cached import setup_data_loaders -from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples # define the PyTorch module that parameterizes the diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index abefcb0b03..9ee1b52704 100644 --- 
a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -10,13 +10,13 @@ import torch.nn as nn from torch.nn import functional from torchvision.utils import save_image +from utils.mnist_cached import DATA_DIR, RESULTS_DIR import pyro from pyro.contrib.examples import util from pyro.distributions import Bernoulli, Normal from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam -from utils.mnist_cached import DATA_DIR, RESULTS_DIR """ Comparison of VAE implementation in PyTorch and Pyro. This example can be diff --git a/profiler/hmm.py b/profiler/hmm.py index 4308c3df56..1825c82c20 100644 --- a/profiler/hmm.py +++ b/profiler/hmm.py @@ -8,13 +8,12 @@ import subprocess import sys from collections import defaultdict -from os.path import join, abspath +from os.path import abspath, join from numpy import median from pyro.util import timed - EXAMPLES_DIR = join(abspath(__file__), os.pardir, os.pardir, "examples") diff --git a/profiler/profiling_utils.py b/profiler/profiling_utils.py index 8375132eb2..aee4f9b564 100644 --- a/profiler/profiling_utils.py +++ b/profiler/profiling_utils.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import cProfile -from io import StringIO import functools import os import pstats import timeit from contextlib import contextmanager +from io import StringIO from prettytable import ALL, PrettyTable diff --git a/pyro/contrib/__init__.py b/pyro/contrib/__init__.py index 3f14bd1862..045ec5435f 100644 --- a/pyro/contrib/__init__.py +++ b/pyro/contrib/__init__.py @@ -25,6 +25,7 @@ try: import funsor as funsor_ # noqa: F401 + from pyro.contrib import funsor __all__ += ["funsor"] except ImportError: diff --git a/pyro/contrib/autoname/__init__.py b/pyro/contrib/autoname/__init__.py index b93f5d7db7..9a379fb2ff 100644 --- a/pyro/contrib/autoname/__init__.py +++ b/pyro/contrib/autoname/__init__.py @@ -6,9 +6,8 @@ generating unique, semantically meaningful names for sample sites. """ from pyro.contrib.autoname import named -from pyro.contrib.autoname.scoping import scope, name_count from pyro.contrib.autoname.autoname import autoname, sample - +from pyro.contrib.autoname.scoping import name_count, scope __all__ = [ "named", diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index 6a4f679e29..cc97b051fa 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from torch.distributions.utils import lazy_property import torch.nn.functional as F +from torch.distributions.utils import lazy_property from pyro.contrib.bnn.utils import adjoin_ones_vector from pyro.distributions.torch_distribution import TorchDistribution diff --git a/pyro/contrib/bnn/utils.py b/pyro/contrib/bnn/utils.py index ec2f33623a..794f66f984 100644 --- a/pyro/contrib/bnn/utils.py +++ b/pyro/contrib/bnn/utils.py @@ -1,9 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import torch import math +import torch + def xavier_uniform(D_in, D_out): scale = math.sqrt(6.0 / float(D_in + D_out)) diff --git a/pyro/contrib/conjugate/infer.py b/pyro/contrib/conjugate/infer.py index 23a3fe791e..0c815c0126 100644 --- a/pyro/contrib/conjugate/infer.py +++ b/pyro/contrib/conjugate/infer.py @@ -6,8 +6,8 @@ import torch import pyro.distributions as dist -from pyro.distributions.util import sum_leftmost from pyro import poutine +from pyro.distributions.util import sum_leftmost from pyro.poutine.messenger import Messenger from pyro.poutine.util import site_is_subsample diff --git a/pyro/contrib/easyguide/__init__.py b/pyro/contrib/easyguide/__init__.py index 9e2577841f..d26c63c9cf 100644 --- a/pyro/contrib/easyguide/__init__.py +++ b/pyro/contrib/easyguide/__init__.py @@ -3,7 +3,6 @@ from pyro.contrib.easyguide.easyguide import EasyGuide, easy_guide - __all__ = [ "EasyGuide", "easy_guide", diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py index fbc204466b..55535ae72d 100644 --- a/pyro/contrib/easyguide/easyguide.py +++ b/pyro/contrib/easyguide/easyguide.py @@ -14,8 +14,8 @@ import pyro.poutine as poutine import pyro.poutine.runtime as runtime from pyro.distributions.util import broadcast_shape, sum_rightmost -from pyro.infer.autoguide.initialization import InitMessenger from pyro.infer.autoguide.guides import prototype_hide_fn +from pyro.infer.autoguide.initialization import InitMessenger from pyro.nn.module import PyroModule, PyroParam diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py index 0d89fee5fc..0398ad137d 100644 --- a/pyro/contrib/examples/bart.py +++ b/pyro/contrib/examples/bart.py @@ -14,7 +14,7 @@ import torch -from pyro.contrib.examples.util import get_data_directory, _mkdir_p +from pyro.contrib.examples.util import _mkdir_p, get_data_directory DATA = get_data_directory(__file__) diff --git a/pyro/contrib/examples/finance.py b/pyro/contrib/examples/finance.py index 03572c0289..c40a0b55e8 100644 --- a/pyro/contrib/examples/finance.py +++ b/pyro/contrib/examples/finance.py @@ -6,7 +6,7 @@ import pandas as pd -from pyro.contrib.examples.util import get_data_directory, _mkdir_p +from pyro.contrib.examples.util import _mkdir_p, get_data_directory DATA = get_data_directory(__file__) diff --git a/pyro/contrib/examples/polyphonic_data_loader.py b/pyro/contrib/examples/polyphonic_data_loader.py index 132c6d953d..491ae0517f 100644 --- a/pyro/contrib/examples/polyphonic_data_loader.py +++ b/pyro/contrib/examples/polyphonic_data_loader.py @@ -17,9 +17,9 @@ """ import os +import pickle from collections import namedtuple from urllib.request import urlopen -import pickle import torch import torch.nn as nn @@ -27,7 +27,6 @@ from pyro.contrib.examples.util import get_data_directory - dset = namedtuple("dset", ["name", "url", "filename"]) JSB_CHORALES = dset("jsb_chorales", diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py index 30a23d5ca2..dcb9355e5e 100644 --- a/pyro/contrib/funsor/__init__.py +++ b/pyro/contrib/funsor/__init__.py @@ -3,14 +3,12 @@ import pyroapi -from pyro.primitives import ( # noqa: F401 - clear_param_store, deterministic, enable_validation, factor, get_param_store, - module, param, random_module, sample, set_rng_seed, subsample, -) - -from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor # noqa: F401 -from pyro.contrib.funsor.handlers import condition, do, markov, vectorized_markov # noqa: F401 +from 
pyro.contrib.funsor.handlers import condition, do, markov # noqa: F401 from pyro.contrib.funsor.handlers import plate as _plate +from pyro.contrib.funsor.handlers import vectorized_markov # noqa: F401 +from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor # noqa: F401 +from pyro.primitives import (clear_param_store, deterministic, enable_validation, factor, get_param_store, # noqa: F401 + module, param, random_module, sample, set_rng_seed, subsample) def plate(*args, **kwargs): diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py index a98be1de94..724ec29c83 100644 --- a/pyro/contrib/funsor/handlers/__init__.py +++ b/pyro/contrib/funsor/handlers/__init__.py @@ -1,20 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from pyro.poutine import (block, condition, do, escape, infer_config, mask, reparam, scale, seed, # noqa: F401 + uncondition) from pyro.poutine.handlers import _make_handler -from pyro.poutine import ( # noqa: F401 - block, condition, do, escape, infer_config, - mask, reparam, scale, seed, uncondition, -) - from .enum_messenger import EnumMessenger, queue # noqa: F401 from .named_messenger import MarkovMessenger, NamedMessenger from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger from .replay_messenger import ReplayMessenger from .trace_messenger import TraceMessenger - _msngrs = [ EnumMessenger, MarkovMessenger, diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 15caf49078..befbeb2014 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -9,18 +9,17 @@ import math from collections import OrderedDict -import torch import funsor +import torch import pyro.poutine.runtime import pyro.poutine.util -from pyro.poutine.escape_messenger import EscapeMessenger -from pyro.poutine.subsample_messenger import _Subsample - -from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger +from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger +from pyro.poutine.escape_messenger import EscapeMessenger +from pyro.poutine.subsample_messenger import _Subsample funsor.set_backend("torch") diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py index fb7667fab3..e7cf18ffbc 100644 --- a/pyro/contrib/funsor/handlers/named_messenger.py +++ b/pyro/contrib/funsor/handlers/named_messenger.py @@ -4,9 +4,8 @@ from collections import OrderedDict from contextlib import ExitStack -from pyro.poutine.reentrant_messenger import ReentrantMessenger - from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame +from pyro.poutine.reentrant_messenger import ReentrantMessenger class NamedMessenger(ReentrantMessenger): diff --git a/pyro/contrib/funsor/handlers/primitives.py b/pyro/contrib/funsor/handlers/primitives.py index 3d8815eff0..0b7a4c4edb 100644 --- a/pyro/contrib/funsor/handlers/primitives.py +++ b/pyro/contrib/funsor/handlers/primitives.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.poutine.runtime - from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType diff --git 
a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py index 2389941049..ae672d2dd4 100644 --- a/pyro/contrib/funsor/handlers/replay_messenger.py +++ b/pyro/contrib/funsor/handlers/replay_messenger.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger from pyro.contrib.funsor.handlers.primitives import to_data +from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger class ReplayMessenger(OrigReplayMessenger): diff --git a/pyro/contrib/funsor/infer/__init__.py b/pyro/contrib/funsor/infer/__init__.py index 4525e2cef5..55f260e6a9 100644 --- a/pyro/contrib/funsor/infer/__init__.py +++ b/pyro/contrib/funsor/infer/__init__.py @@ -5,5 +5,5 @@ from .elbo import ELBO # noqa: F401 from .trace_elbo import JitTrace_ELBO, Trace_ELBO # noqa: F401 -from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO # noqa: F401 from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO # noqa: F401 +from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO # noqa: F401 diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 1912edca11..686926b772 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -5,14 +5,13 @@ import funsor -from pyro.distributions.util import copy_docs_from -from pyro.infer import Trace_ELBO as _OrigTrace_ELBO - from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace from pyro.contrib.funsor.infer import config_enumerate +from pyro.distributions.util import copy_docs_from +from pyro.infer import Trace_ELBO as _OrigTrace_ELBO -from .elbo import Jit_ELBO, ELBO +from .elbo import ELBO, Jit_ELBO from .traceenum_elbo import terms_from_trace diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 14386c69c0..a63a8cabc6 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -5,12 +5,11 @@ import funsor -from pyro.distributions.util import copy_docs_from -from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO - from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO +from pyro.distributions.util import copy_docs_from +from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO def terms_from_trace(tr): diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py index 8a714d8c03..aae66ce5c0 100644 --- a/pyro/contrib/funsor/infer/tracetmc_elbo.py +++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py @@ -5,14 +5,12 @@ import funsor -from pyro.distributions.util import copy_docs_from -from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO - from pyro.contrib.funsor import to_data from pyro.contrib.funsor.handlers import enum, plate, replay, trace - from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace +from pyro.distributions.util import copy_docs_from +from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO @copy_docs_from(_OrigTraceTMC_ELBO) diff --git a/pyro/contrib/gp/kernels/__init__.py b/pyro/contrib/gp/kernels/__init__.py index 9874e73c17..c36ddd37fb 100644 --- 
a/pyro/contrib/gp/kernels/__init__.py +++ b/pyro/contrib/gp/kernels/__init__.py @@ -4,10 +4,9 @@ from pyro.contrib.gp.kernels.brownian import Brownian from pyro.contrib.gp.kernels.coregionalize import Coregionalize from pyro.contrib.gp.kernels.dot_product import DotProduct, Linear, Polynomial -from pyro.contrib.gp.kernels.isotropic import (RBF, Exponential, Isotropy, Matern32, Matern52, - RationalQuadratic) -from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, - Transforming, VerticalScaling, Warping) +from pyro.contrib.gp.kernels.isotropic import RBF, Exponential, Isotropy, Matern32, Matern52, RationalQuadratic +from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, Transforming, VerticalScaling, + Warping) from pyro.contrib.gp.kernels.periodic import Cosine, Periodic from pyro.contrib.gp.kernels.static import Constant, WhiteNoise diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py index 3041f5e92e..ef417f9e22 100644 --- a/pyro/contrib/gp/likelihoods/binary.py +++ b/pyro/contrib/gp/likelihoods/binary.py @@ -5,7 +5,6 @@ import pyro import pyro.distributions as dist - from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/gp/likelihoods/gaussian.py b/pyro/contrib/gp/likelihoods/gaussian.py index cb5a15d8c7..b1b65ff95c 100644 --- a/pyro/contrib/gp/likelihoods/gaussian.py +++ b/pyro/contrib/gp/likelihoods/gaussian.py @@ -6,7 +6,6 @@ import pyro import pyro.distributions as dist - from pyro.contrib.gp.likelihoods.likelihood import Likelihood from pyro.nn.module import PyroParam diff --git a/pyro/contrib/gp/likelihoods/multi_class.py b/pyro/contrib/gp/likelihoods/multi_class.py index 9ff69e81f1..ed8463f8bd 100644 --- a/pyro/contrib/gp/likelihoods/multi_class.py +++ b/pyro/contrib/gp/likelihoods/multi_class.py @@ -5,7 +5,6 @@ import pyro import pyro.distributions as dist - from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/gp/likelihoods/poisson.py b/pyro/contrib/gp/likelihoods/poisson.py index 48916e0634..8abed6fd2a 100644 --- a/pyro/contrib/gp/likelihoods/poisson.py +++ b/pyro/contrib/gp/likelihoods/poisson.py @@ -5,7 +5,6 @@ import pyro import pyro.distributions as dist - from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/oed/__init__.py b/pyro/contrib/oed/__init__.py index 006c57a7a1..3afd3a440d 100644 --- a/pyro/contrib/oed/__init__.py +++ b/pyro/contrib/oed/__init__.py @@ -67,7 +67,7 @@ def model(design): """ -from pyro.contrib.oed import search, eig +from pyro.contrib.oed import eig, search __all__ = [ "search", diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py index 7faec28aa6..8d6c7ae22b 100644 --- a/pyro/contrib/oed/eig.py +++ b/pyro/contrib/oed/eig.py @@ -1,17 +1,18 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import torch import math import warnings +import torch + import pyro from pyro import poutine -from pyro.infer.autoguide.utils import mean_field_entropy from pyro.contrib.oed.search import Search -from pyro.infer import EmpiricalMarginal, Importance, SVI -from pyro.util import torch_isnan, torch_isinf from pyro.contrib.util import lexpand +from pyro.infer import SVI, EmpiricalMarginal, Importance +from pyro.infer.autoguide.utils import mean_field_entropy +from pyro.util import torch_isinf, torch_isnan __all__ = [ "laplace_eig", diff --git a/pyro/contrib/oed/glmm/__init__.py b/pyro/contrib/oed/glmm/__init__.py index f9d75643ca..c17c221213 100644 --- a/pyro/contrib/oed/glmm/__init__.py +++ b/pyro/contrib/oed/glmm/__init__.py @@ -36,5 +36,5 @@ For random effects with a shared covariance matrix, see :meth:`pyro.contrib.oed.glmm.lmer_model`. """ -from pyro.contrib.oed.glmm.glmm import * # noqa: F403,F401 from pyro.contrib.oed.glmm import guides # noqa: F401 +from pyro.contrib.oed.glmm.glmm import * # noqa: F403,F401 diff --git a/pyro/contrib/oed/glmm/glmm.py b/pyro/contrib/oed/glmm/glmm.py index 2c418391e3..68507be53f 100644 --- a/pyro/contrib/oed/glmm/glmm.py +++ b/pyro/contrib/oed/glmm/glmm.py @@ -3,17 +3,17 @@ import warnings from collections import OrderedDict -from functools import partial from contextlib import ExitStack +from functools import partial import torch -from torch.nn.functional import softplus from torch.distributions import constraints from torch.distributions.transforms import AffineTransform, SigmoidTransform +from torch.nn.functional import softplus import pyro import pyro.distributions as dist -from pyro.contrib.util import rmv, iter_plates_to_shape +from pyro.contrib.util import iter_plates_to_shape, rmv # TODO read from torch float spec epsilon = torch.tensor(2**-24) diff --git a/pyro/contrib/oed/glmm/guides.py b/pyro/contrib/oed/glmm/guides.py index c71b06415c..d2425adff2 100644 --- a/pyro/contrib/oed/glmm/guides.py +++ b/pyro/contrib/oed/glmm/guides.py @@ -1,17 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +from contextlib import ExitStack + import torch from torch import nn -from contextlib import ExitStack - import pyro import pyro.distributions as dist from pyro import poutine -from pyro.contrib.util import ( - tensor_to_dict, rmv, rvv, rtril, lexpand, iter_plates_to_shape -) +from pyro.contrib.util import iter_plates_to_shape, lexpand, rmv, rtril, rvv, tensor_to_dict from pyro.ops.linalg import rinverse diff --git a/pyro/contrib/oed/search.py b/pyro/contrib/oed/search.py index 4bf8eb1816..721f6305c3 100644 --- a/pyro/contrib/oed/search.py +++ b/pyro/contrib/oed/search.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import queue -from pyro.infer.abstract_infer import TracePosterior + import pyro.poutine as poutine +from pyro.infer.abstract_infer import TracePosterior ################################### # Search borrowed from RSA example diff --git a/pyro/contrib/oed/util.py b/pyro/contrib/oed/util.py index d5c315f85a..50774ff0bd 100644 --- a/pyro/contrib/oed/util.py +++ b/pyro/contrib/oed/util.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import math + import torch -from pyro.contrib.util import get_indices from pyro.contrib.oed.glmm import analytic_posterior_cov +from pyro.contrib.util import get_indices from pyro.infer.autoguide.utils import mean_field_entropy diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index 1bc28f1caa..39c06748ae 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -4,17 +4,10 @@ from typing import Union from torch import Tensor + from pyro.distributions import TransformedDistribution -from pyro.distributions.transforms import ( - Transform, - AffineTransform, - AbsTransform, - PowerTransform, - ExpTransform, - TanhTransform, - SoftmaxTransform, - SigmoidTransform -) +from pyro.distributions.transforms import (AbsTransform, AffineTransform, ExpTransform, PowerTransform, + SigmoidTransform, SoftmaxTransform, TanhTransform, Transform) class RVMagicOps: diff --git a/pyro/contrib/timeseries/__init__.py b/pyro/contrib/timeseries/__init__.py index f119203f04..517c3f9550 100644 --- a/pyro/contrib/timeseries/__init__.py +++ b/pyro/contrib/timeseries/__init__.py @@ -6,7 +6,7 @@ models useful for forecasting applications. 
""" from pyro.contrib.timeseries.base import TimeSeriesModel -from pyro.contrib.timeseries.gp import IndependentMaternGP, LinearlyCoupledMaternGP, DependentMaternGP +from pyro.contrib.timeseries.gp import DependentMaternGP, IndependentMaternGP, LinearlyCoupledMaternGP from pyro.contrib.timeseries.lgssm import GenericLGSSM from pyro.contrib.timeseries.lgssmgp import GenericLGSSMWithGPNoiseModel diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index 641c764d19..fc3e61c6c1 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -5,9 +5,9 @@ from torch.distributions import constraints import pyro.distributions as dist -from pyro.distributions.torch_distribution import TorchDistribution from pyro.contrib.tracking.extended_kalman_filter import EKFState from pyro.contrib.tracking.measurements import PositionMeasurement +from pyro.distributions.torch_distribution import TorchDistribution class EKFDistribution(TorchDistribution): diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py index bd26b344e2..7ea41ad3a7 100644 --- a/pyro/contrib/tracking/dynamic_models.py +++ b/pyro/contrib/tracking/dynamic_models.py @@ -6,6 +6,7 @@ import torch from torch import nn from torch.nn import Parameter + import pyro.distributions as dist from pyro.distributions.util import eye_like diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py index cb98c49ccd..8f24ea4360 100644 --- a/pyro/contrib/tracking/measurements.py +++ b/pyro/contrib/tracking/measurements.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod import torch + from pyro.distributions.util import eye_like diff --git a/pyro/contrib/util.py b/pyro/contrib/util.py index 44ff34832b..e250639ca7 100644 --- a/pyro/contrib/util.py +++ b/pyro/contrib/util.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict + import torch + import pyro diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 18b804bf58..0a84d5b530 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -4,40 +4,19 @@ import pyro.distributions.torch_patch # noqa F403 from pyro.distributions.affine_beta import AffineBeta from pyro.distributions.avf_mvn import AVFMultivariateNormal -from pyro.distributions.coalescent import ( - CoalescentRateLikelihood, - CoalescentTimes, - CoalescentTimesWithRate, -) -from pyro.distributions.conditional import ( - ConditionalDistribution, - ConditionalTransform, - ConditionalTransformedDistribution, - ConditionalTransformModule, -) -from pyro.distributions.conjugate import ( - BetaBinomial, - DirichletMultinomial, - GammaPoisson, -) +from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimes, CoalescentTimesWithRate +from pyro.distributions.conditional import (ConditionalDistribution, ConditionalTransform, + ConditionalTransformedDistribution, ConditionalTransformModule) +from pyro.distributions.conjugate import BetaBinomial, DirichletMultinomial, GammaPoisson from pyro.distributions.delta import Delta from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals -from pyro.distributions.diag_normal_mixture_shared_cov import ( - MixtureOfDiagNormalsSharedCovariance, -) +from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance from pyro.distributions.distribution import Distribution from pyro.distributions.empirical import 
Empirical from pyro.distributions.extended import ExtendedBetaBinomial, ExtendedBinomial from pyro.distributions.folded import FoldedDistribution from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture -from pyro.distributions.hmm import ( - DiscreteHMM, - GammaGaussianHMM, - GaussianHMM, - GaussianMRF, - IndependentHMM, - LinearHMM, -) +from pyro.distributions.hmm import DiscreteHMM, GammaGaussianHMM, GaussianHMM, GaussianMRF, IndependentHMM, LinearHMM from pyro.distributions.improper_uniform import ImproperUniform from pyro.distributions.inverse_gamma import InverseGamma from pyro.distributions.lkj import LKJ, LKJCorrCholesky @@ -50,10 +29,8 @@ from pyro.distributions.polya_gamma import TruncatedPolyaGamma from pyro.distributions.projected_normal import ProjectedNormal from pyro.distributions.rejector import Rejector -from pyro.distributions.relaxed_straight_through import ( - RelaxedBernoulliStraightThrough, - RelaxedOneHotCategoricalStraightThrough, -) +from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, + RelaxedOneHotCategoricalStraightThrough) from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable from pyro.distributions.torch import * # noqa F403 @@ -61,17 +38,9 @@ from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution, TorchDistribution from pyro.distributions.torch_transform import ComposeTransformModule, TransformModule from pyro.distributions.unit import Unit -from pyro.distributions.util import ( - enable_validation, - is_validation_enabled, - validation_enabled, -) +from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled from pyro.distributions.von_mises_3d import VonMises3D -from pyro.distributions.zero_inflated import ( - ZeroInflatedDistribution, - ZeroInflatedNegativeBinomial, - ZeroInflatedPoisson, -) +from pyro.distributions.zero_inflated import ZeroInflatedDistribution, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson from . import constraints, kl, transforms diff --git a/pyro/distributions/ordered_logistic.py b/pyro/distributions/ordered_logistic.py index d6d288fafb..c8d4ef459d 100644 --- a/pyro/distributions/ordered_logistic.py +++ b/pyro/distributions/ordered_logistic.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch + from pyro.distributions import constraints from pyro.distributions.torch import Categorical diff --git a/pyro/distributions/projected_normal.py b/pyro/distributions/projected_normal.py index 31e7aa909c..1ac3ed28f6 100644 --- a/pyro/distributions/projected_normal.py +++ b/pyro/distributions/projected_normal.py @@ -5,9 +5,10 @@ import torch +from pyro.ops.tensor_utils import safe_normalize + from . 
import constraints from .torch_distribution import TorchDistribution -from pyro.ops.tensor_utils import safe_normalize class ProjectedNormal(TorchDistribution): diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py index 7add44583c..a3b4fb40f4 100644 --- a/pyro/distributions/spanning_tree.py +++ b/pyro/distributions/spanning_tree.py @@ -218,6 +218,7 @@ def _get_cpp_module(): global _cpp_module if _cpp_module is None: import os + from torch.utils.cpp_extension import load path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp") _cpp_module = load(name="cpp_spanning_tree", diff --git a/pyro/distributions/testing/gof.py b/pyro/distributions/testing/gof.py index 7178874a6f..4d544b923c 100644 --- a/pyro/distributions/testing/gof.py +++ b/pyro/distributions/testing/gof.py @@ -55,8 +55,8 @@ def test_my_distribution(): `goftests `_ library. """ -import warnings import math +import warnings import torch diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 0b1b2c4ec5..0632723a26 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -3,11 +3,11 @@ from torch.distributions import biject_to, transform_to from torch.distributions.transforms import * # noqa F403 -from torch.distributions.transforms import __all__ as torch_transforms from torch.distributions.transforms import ComposeTransform, ExpTransform, LowerCholeskyTransform +from torch.distributions.transforms import __all__ as torch_transforms -from ..constraints import (IndependentConstraint, corr_cholesky_constraint, corr_matrix, - ordered_vector, positive_definite, positive_ordered_vector, sphere) +from ..constraints import (IndependentConstraint, corr_cholesky_constraint, corr_matrix, ordered_vector, + positive_definite, positive_ordered_vector, sphere) from ..torch_transform import ComposeTransformModule from .affine_autoregressive import (AffineAutoregressive, ConditionalAffineAutoregressive, affine_autoregressive, conditional_affine_autoregressive) @@ -26,12 +26,12 @@ matrix_exponential) from .neural_autoregressive import (ConditionalNeuralAutoregressive, NeuralAutoregressive, conditional_neural_autoregressive, neural_autoregressive) +from .normalize import Normalize from .ordered import OrderedTransform from .permute import Permute, permute from .planar import ConditionalPlanar, Planar, conditional_planar, planar from .polynomial import Polynomial, polynomial from .radial import ConditionalRadial, Radial, conditional_radial, radial -from .normalize import Normalize from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import (ConditionalSplineAutoregressive, SplineAutoregressive, conditional_spline_autoregressive, spline_autoregressive) diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py index 79aea6a261..95497e45fa 100644 --- a/pyro/distributions/transforms/ordered.py +++ b/pyro/distributions/transforms/ordered.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from pyro.distributions.transforms import Transform + from pyro.distributions import constraints +from pyro.distributions.transforms import Transform class OrderedTransform(Transform): diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 421e5a16fa..5485b7476e 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -17,13 +17,13 @@ from pyro.infer.smcfilter import SMCFilter from 
pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel from pyro.infer.svi import SVI -from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO from pyro.infer.trace_mmd import Trace_MMD from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO +from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.util import enable_validation, is_validation_enabled __all__ = [ diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py index 0424949e7d..8891c9cc64 100644 --- a/pyro/infer/abstract_infer.py +++ b/pyro/infer/abstract_infer.py @@ -10,8 +10,8 @@ import pyro.poutine as poutine from pyro.distributions import Categorical, Empirical -from pyro.poutine.util import site_is_subsample from pyro.ops.stats import waic +from pyro.poutine.util import site_is_subsample class EmpiricalMarginal(Empirical): diff --git a/pyro/infer/mcmc/__init__.py b/pyro/infer/mcmc/__init__.py index e33cbec518..99d241b162 100644 --- a/pyro/infer/mcmc/__init__.py +++ b/pyro/infer/mcmc/__init__.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix -from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS __all__ = [ diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py index 46497a53b5..c8d41924f6 100644 --- a/pyro/infer/mcmc/adaptation.py +++ b/pyro/infer/mcmc/adaptation.py @@ -9,7 +9,7 @@ import pyro from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul from pyro.ops.dual_averaging import DualAveraging -from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance +from pyro.ops.welford import WelfordArrowheadCovariance, WelfordCovariance adapt_window = namedtuple("adapt_window", ["start", "end"]) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 857004212f..139d7f4a0d 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -8,9 +8,8 @@ import pyro import pyro.distributions as dist -from pyro.distributions.util import scalar_like from pyro.distributions.testing.fakes import NonreparameterizedNormal - +from pyro.distributions.util import scalar_like from pyro.infer.autoguide import init_to_uniform from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py index 372104c53f..89831a93f8 100644 --- a/pyro/infer/mcmc/util.py +++ b/pyro/infer/mcmc/util.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import traceback as tb import warnings from collections import OrderedDict, defaultdict from functools import partial, reduce from itertools import product -import traceback as tb import torch -from torch.distributions import biject_to from opt_einsum import shared_intermediates +from torch.distributions import biject_to import pyro import pyro.poutine as poutine diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 82617b2a98..dbeee300a3 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -from functools import reduce import warnings +from functools import reduce import torch diff --git a/pyro/infer/reparam/neutra.py b/pyro/infer/reparam/neutra.py index 6e444344e0..89f39d062d 100644 --- a/pyro/infer/reparam/neutra.py +++ b/pyro/infer/reparam/neutra.py @@ -8,6 +8,7 @@ from pyro import poutine from pyro.distributions.util import sum_rightmost from pyro.infer.autoguide.guides import AutoContinuous + from .reparam import Reparam diff --git a/pyro/infer/svgd.py b/pyro/infer/svgd.py index ebc526a86d..9e722a745f 100644 --- a/pyro/infer/svgd.py +++ b/pyro/infer/svgd.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from abc import ABCMeta, abstractmethod import math +from abc import ABCMeta, abstractmethod import torch from torch.distributions import biject_to @@ -10,10 +10,10 @@ import pyro from pyro import poutine from pyro.distributions import Delta -from pyro.infer.trace_elbo import Trace_ELBO +from pyro.distributions.util import copy_docs_from from pyro.infer.autoguide.guides import AutoContinuous from pyro.infer.autoguide.initialization import init_to_sample -from pyro.distributions.util import copy_docs_from +from pyro.infer.trace_elbo import Trace_ELBO def vectorize(fn, num_particles, max_plate_nesting): diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py index d04210ee56..6eab28596d 100644 --- a/pyro/infer/trace_mean_field_elbo.py +++ b/pyro/infer/trace_mean_field_elbo.py @@ -10,7 +10,7 @@ import pyro.ops.jit from pyro.distributions.util import scale_and_mask from pyro.infer.trace_elbo import Trace_ELBO -from pyro.infer.util import is_validation_enabled, torch_item, check_fully_reparametrized +from pyro.infer.util import check_fully_reparametrized, is_validation_enabled, torch_item from pyro.util import warn_if_nan diff --git a/pyro/infer/trace_mmd.py b/pyro/infer/trace_mmd.py index 1cc71992b9..661ff727c2 100644 --- a/pyro/infer/trace_mmd.py +++ b/pyro/infer/trace_mmd.py @@ -9,8 +9,8 @@ import pyro.ops.jit from pyro import poutine from pyro.infer.elbo import ELBO -from pyro.infer.util import torch_item, is_validation_enabled from pyro.infer.enum import get_importance_trace +from pyro.infer.util import is_validation_enabled, torch_item from pyro.util import check_if_enumerated, warn_if_nan diff --git a/pyro/infer/trace_tail_adaptive_elbo.py b/pyro/infer/trace_tail_adaptive_elbo.py index a69ea6d191..b05251a300 100644 --- a/pyro/infer/trace_tail_adaptive_elbo.py +++ b/pyro/infer/trace_tail_adaptive_elbo.py @@ -6,7 +6,7 @@ import torch from pyro.infer.trace_elbo import Trace_ELBO -from pyro.infer.util import is_validation_enabled, check_fully_reparametrized +from pyro.infer.util import check_fully_reparametrized, is_validation_enabled class TraceTailAdaptive_ELBO(Trace_ELBO): diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index e2e483eec7..504060963f 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -1,10 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +import queue import warnings import weakref from collections import OrderedDict -import queue import torch from opt_einsum import shared_intermediates diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index 852ea6e658..8e7bc7ed6f 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -11,8 +11,7 @@ from pyro.distributions.util import detach, is_identically_zero from pyro.infer import ELBO from pyro.infer.enum import get_importance_trace -from pyro.infer.util import (MultiFrameTensor, get_plate_stacks, - is_validation_enabled, torch_backward, torch_item) +from pyro.infer.util import MultiFrameTensor, get_plate_stacks, is_validation_enabled, torch_backward, torch_item from pyro.util import check_if_enumerated, warn_if_nan diff --git a/pyro/infer/tracetmc_elbo.py b/pyro/infer/tracetmc_elbo.py index 51c3ba3b78..f78b277080 100644 --- a/pyro/infer/tracetmc_elbo.py +++ b/pyro/infer/tracetmc_elbo.py @@ -7,7 +7,6 @@ import torch import pyro.poutine as poutine - from pyro.distributions.util import is_identically_zero from pyro.infer.elbo import ELBO from pyro.infer.enum import get_importance_trace, iter_discrete_escape, iter_discrete_extend diff --git a/pyro/logger.py b/pyro/logger.py index 5bee771e4a..64a8c70c46 100644 --- a/pyro/logger.py +++ b/pyro/logger.py @@ -3,7 +3,6 @@ import logging - default_format = '%(levelname)s \t %(message)s' log = logging.getLogger("pyro") log.setLevel(logging.INFO) diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py index 9641ab4ec2..4714d3095e 100644 --- a/pyro/ops/arrowhead.py +++ b/pyro/ops/arrowhead.py @@ -5,7 +5,6 @@ import torch - SymmArrowhead = namedtuple("SymmArrowhead", ["top", "bottom_diag"]) TriuArrowhead = namedtuple("TriuArrowhead", ["top", "bottom_diag"]) diff --git a/pyro/ops/einsum/torch_map.py b/pyro/ops/einsum/torch_map.py index 6e2832bcff..e4293c1140 100644 --- a/pyro/ops/einsum/torch_map.py +++ b/pyro/ops/einsum/torch_map.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import operator - from functools import reduce from pyro.ops import packed diff --git a/pyro/ops/einsum/torch_sample.py b/pyro/ops/einsum/torch_sample.py index 06c8108886..5420c328ba 100644 --- a/pyro/ops/einsum/torch_sample.py +++ b/pyro/ops/einsum/torch_sample.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import operator - from functools import reduce import pyro.distributions as dist diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py index 6611d3a93c..e651b4284e 100644 --- a/pyro/ops/newton.py +++ b/pyro/ops/newton.py @@ -4,8 +4,8 @@ import torch from torch.autograd import grad +from pyro.ops.linalg import eig_3d, rinverse from pyro.util import warn_if_nan -from pyro.ops.linalg import rinverse, eig_3d def newton_step(loss, x, trust_radius=None): diff --git a/pyro/ops/ssm_gp.py b/pyro/ops/ssm_gp.py index eb88ba9d70..89abcb2912 100644 --- a/pyro/ops/ssm_gp.py +++ b/pyro/ops/ssm_gp.py @@ -6,7 +6,7 @@ import torch from torch.distributions import constraints -from pyro.nn import PyroModule, pyro_method, PyroParam +from pyro.nn import PyroModule, PyroParam, pyro_method root_three = math.sqrt(3.0) root_five = math.sqrt(5.0) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 20a5b1d5e1..2ff2a57ae9 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pyro.util import ignore_jit_warnings + from .messenger import Messenger diff --git 
a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 9ed1575857..0b65987932 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -7,6 +7,7 @@ import torch from pyro.util import ignore_jit_warnings + from .messenger import Messenger from .runtime import _DIM_ALLOCATOR diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 39cd234bb9..176c7a772f 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict import sys +from collections import OrderedDict import opt_einsum diff --git a/tests/__init__.py b/tests/__init__.py index 200bfc2d65..4056718ce7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging - import os # create log handler for tests diff --git a/tests/conftest.py b/tests/conftest.py index 2cfcba39d4..699cca55c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,6 @@ import pyro - torch.set_default_tensor_type(os.environ.get('PYRO_TENSOR_TYPE', 'torch.DoubleTensor')) diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py index 9e9002d507..767b763ff4 100644 --- a/tests/contrib/autoguide/test_inference.py +++ b/tests/contrib/autoguide/test_inference.py @@ -12,10 +12,10 @@ import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.distributions.transforms import iterated, block_autoregressive +from pyro.distributions.transforms import block_autoregressive, iterated +from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO from pyro.infer.autoguide import (AutoDiagonalNormal, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal) -from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO from pyro.infer.autoguide.guides import AutoNormalizingFlow from tests.common import assert_equal from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/contrib/autoguide/test_mean_field_entropy.py index 9f8c301b32..2f5cd163db 100644 --- a/tests/contrib/autoguide/test_mean_field_entropy.py +++ b/tests/contrib/autoguide/test_mean_field_entropy.py @@ -1,9 +1,9 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import torch -import scipy.special as sc import pytest +import scipy.special as sc +import torch import pyro import pyro.distributions as dist diff --git a/tests/contrib/autoname/test_scoping.py b/tests/contrib/autoname/test_scoping.py index d10d6f2d7f..aa7e44bae6 100644 --- a/tests/contrib/autoname/test_scoping.py +++ b/tests/contrib/autoname/test_scoping.py @@ -8,7 +8,7 @@ import pyro import pyro.distributions.torch as dist import pyro.poutine as poutine -from pyro.contrib.autoname import scope, name_count +from pyro.contrib.autoname import name_count, scope logger = logging.getLogger(__name__) diff --git a/tests/contrib/bnn/test_hidden_layer.py b/tests/contrib/bnn/test_hidden_layer.py index 1067cc03f2..c688572d0f 100644 --- a/tests/contrib/bnn/test_hidden_layer.py +++ b/tests/contrib/bnn/test_hidden_layer.py @@ -1,11 +1,11 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +import pytest import torch import torch.nn.functional as F from torch.distributions import Normal -import pytest from pyro.contrib.bnn import HiddenLayer from tests.common import assert_equal diff --git a/tests/contrib/epidemiology/test_quant.py b/tests/contrib/epidemiology/test_quant.py index f9dc53bb64..d2a0edbc1d 100644 --- a/tests/contrib/epidemiology/test_quant.py +++ b/tests/contrib/epidemiology/test_quant.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import pytest - import torch from pyro.contrib.epidemiology.util import compute_bin_probs diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index 75fff55463..9b273e5e2e 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -7,7 +7,6 @@ import pyroapi import pytest import torch - from torch.autograd import grad from torch.distributions import constraints @@ -17,9 +16,10 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor - import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro + + import pyro.contrib.funsor funsor.set_backend("torch") except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_named_handlers.py b/tests/contrib/funsor/test_named_handlers.py index 48c464daa3..c4c57b7bd5 100644 --- a/tests/contrib/funsor/test_named_handlers.py +++ b/tests/contrib/funsor/test_named_handlers.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict import logging +from collections import OrderedDict import pytest import torch @@ -11,6 +11,7 @@ try: import funsor from funsor.tensor import Tensor + import pyro.contrib.funsor from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger funsor.set_backend("torch") diff --git a/tests/contrib/funsor/test_pyroapi_funsor.py b/tests/contrib/funsor/test_pyroapi_funsor.py index 74dbf972e3..9e050462e9 100644 --- a/tests/contrib/funsor/test_pyroapi_funsor.py +++ b/tests/contrib/funsor/test_pyroapi_funsor.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import pytest - from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 try: # triggers backend registration import funsor + import pyro.contrib.funsor # noqa: F401 funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py index cc1ab52178..54d4eedaae 100644 --- a/tests/contrib/funsor/test_tmc.py +++ b/tests/contrib/funsor/test_tmc.py @@ -14,9 +14,10 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor - import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro, pyro_backend + + import pyro.contrib.funsor funsor.set_backend("torch") except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_valid_models_enum.py b/tests/contrib/funsor/test_valid_models_enum.py index 3ef3241ad2..7df1b23a90 100644 --- a/tests/contrib/funsor/test_valid_models_enum.py +++ b/tests/contrib/funsor/test_valid_models_enum.py @@ -1,10 +1,10 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict import contextlib import logging import os +from collections import defaultdict from queue import LifoQueue import pytest @@ -19,11 +19,12 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + import pyro.contrib.funsor from pyro.contrib.funsor.handlers.runtime import _DIM_STACK funsor.set_backend("torch") from pyroapi import distributions as dist - from pyroapi import infer, handlers, pyro, pyro_backend + from pyroapi import handlers, infer, pyro, pyro_backend except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_valid_models_plate.py b/tests/contrib/funsor/test_valid_models_plate.py index ed20ee4be4..f5d30fc1b7 100644 --- a/tests/contrib/funsor/test_valid_models_plate.py +++ b/tests/contrib/funsor/test_valid_models_plate.py @@ -12,9 +12,10 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor - import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro + + import pyro.contrib.funsor from tests.contrib.funsor.test_valid_models_enum import assert_ok funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_valid_models_sequential_plate.py b/tests/contrib/funsor/test_valid_models_sequential_plate.py index 1de6af5b08..40eeb79cb6 100644 --- a/tests/contrib/funsor/test_valid_models_sequential_plate.py +++ b/tests/contrib/funsor/test_valid_models_sequential_plate.py @@ -11,9 +11,10 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor - import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro + + import pyro.contrib.funsor from tests.contrib.funsor.test_valid_models_enum import assert_ok funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 5048d9a1c4..99e548a3bd 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -3,7 +3,6 @@ import pytest import torch - from pyroapi import pyro_backend from torch.distributions import constraints @@ -13,10 +12,12 @@ try: import funsor from funsor.testing import assert_close - import pyro.contrib.funsor from pyroapi import distributions as dist + + import pyro.contrib.funsor funsor.set_backend("torch") - from pyroapi import handlers, pyro, infer + from pyroapi import handlers, infer, pyro + from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/gp/test_kernels.py b/tests/contrib/gp/test_kernels.py index cc9797ff2c..db1c803786 100644 --- a/tests/contrib/gp/test_kernels.py +++ b/tests/contrib/gp/test_kernels.py @@ -6,9 +6,8 @@ import pytest import torch -from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, - Exponential, Linear, Matern32, Matern52, Periodic, - Polynomial, Product, RationalQuadratic, Sum, +from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, Exponential, Linear, + Matern32, Matern52, Periodic, Polynomial, Product, RationalQuadratic, Sum, VerticalScaling, Warping, WhiteNoise) from tests.common import assert_equal diff --git a/tests/contrib/gp/test_likelihoods.py 
b/tests/contrib/gp/test_likelihoods.py index 71fbe663ad..c63c1ce9a5 100644 --- a/tests/contrib/gp/test_likelihoods.py +++ b/tests/contrib/gp/test_likelihoods.py @@ -11,7 +11,6 @@ from pyro.contrib.gp.models import VariationalGP, VariationalSparseGP from pyro.contrib.gp.util import train - T = namedtuple("TestGPLikelihood", ["model_class", "X", "y", "kernel", "likelihood"]) X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0], [3.0, 4.0, 6.0]]) diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index d5afa24ec2..711089025d 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -8,13 +8,12 @@ import torch import pyro.distributions as dist -from pyro.contrib.gp.kernels import Cosine, Matern32, RBF, WhiteNoise +from pyro.contrib.gp.kernels import RBF, Cosine, Matern32, WhiteNoise from pyro.contrib.gp.likelihoods import Gaussian -from pyro.contrib.gp.models import (GPLVM, GPRegression, SparseGPRegression, - VariationalGP, VariationalSparseGP) +from pyro.contrib.gp.models import GPLVM, GPRegression, SparseGPRegression, VariationalGP, VariationalSparseGP from pyro.contrib.gp.util import train -from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc.hmc import HMC from pyro.nn.module import PyroSample from tests.common import assert_equal diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py index 57e8e4e7da..aa8df92188 100644 --- a/tests/contrib/oed/test_ewma.py +++ b/tests/contrib/oed/test_ewma.py @@ -3,9 +3,9 @@ import math +import pytest import torch -import pytest from pyro.contrib.oed.eig import EwmaLog from tests.common import assert_equal diff --git a/tests/contrib/oed/test_finite_spaces_eig.py b/tests/contrib/oed/test_finite_spaces_eig.py index 49fc02493a..b6f69234d4 100644 --- a/tests/contrib/oed/test_finite_spaces_eig.py +++ b/tests/contrib/oed/test_finite_spaces_eig.py @@ -1,17 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import torch import pytest +import torch import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.contrib.oed.eig import ( - nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, lfire_eig, - donsker_varadhan_eig) +from pyro.contrib.oed.eig import (donsker_varadhan_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, nmc_eig, + posterior_eig, vnmc_eig) from pyro.contrib.util import iter_plates_to_shape - from tests.common import assert_equal try: diff --git a/tests/contrib/oed/test_glmm.py b/tests/contrib/oed/test_glmm.py index 6e855525dd..cb3e95d169 100644 --- a/tests/contrib/oed/test_glmm.py +++ b/tests/contrib/oed/test_glmm.py @@ -8,10 +8,8 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.contrib.oed.glmm import ( - known_covariance_linear_model, group_linear_model, zero_mean_unit_obs_sd_lm, - normal_inverse_gamma_linear_model, logistic_regression_model, sigmoid_model -) +from pyro.contrib.oed.glmm import (group_linear_model, known_covariance_linear_model, logistic_regression_model, + normal_inverse_gamma_linear_model, sigmoid_model, zero_mean_unit_obs_sd_lm) from tests.common import assert_equal diff --git a/tests/contrib/oed/test_linear_models_eig.py b/tests/contrib/oed/test_linear_models_eig.py index 30280cb602..f84ba916e5 100644 --- a/tests/contrib/oed/test_linear_models_eig.py +++ b/tests/contrib/oed/test_linear_models_eig.py @@ -1,20 +1,19 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import torch import pytest +import torch import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.infer import Trace_ELBO +from pyro.contrib.oed.eig import (donsker_varadhan_eig, laplace_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, + nmc_eig, posterior_eig, vnmc_eig) from pyro.contrib.oed.glmm import known_covariance_linear_model +from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide from pyro.contrib.oed.util import linear_model_ground_truth -from pyro.contrib.oed.eig import ( - nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, laplace_eig, lfire_eig, - donsker_varadhan_eig) from pyro.contrib.util import rmv, rvv -from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide +from pyro.infer import Trace_ELBO from tests.common import assert_equal diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py index 5a1a43c194..4c392bc997 100644 --- a/tests/contrib/randomvariable/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -4,6 +4,7 @@ import math import torch.tensor as tt + from pyro.distributions import Uniform N_SAMPLES = 100 diff --git a/tests/contrib/test_util.py b/tests/contrib/test_util.py index 442ca61bec..60a3115dad 100644 --- a/tests/contrib/test_util.py +++ b/tests/contrib/test_util.py @@ -2,12 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict + import pytest import torch -from pyro.contrib.util import ( - get_indices, tensor_to_dict, rmv, rvv, lexpand, rexpand, rdiag, rtril -) +from pyro.contrib.util import get_indices, lexpand, rdiag, rexpand, rmv, rtril, rvv, tensor_to_dict from tests.common import assert_equal diff --git a/tests/contrib/timeseries/test_gp.py b/tests/contrib/timeseries/test_gp.py index 2698faa01b..e2e39a0aba 100644 --- a/tests/contrib/timeseries/test_gp.py +++ b/tests/contrib/timeseries/test_gp.py @@ -2,14 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import math + +import pytest import torch -from tests.common import assert_equal import pyro -from pyro.contrib.timeseries import (IndependentMaternGP, LinearlyCoupledMaternGP, GenericLGSSM, - GenericLGSSMWithGPNoiseModel, DependentMaternGP) +from pyro.contrib.timeseries import (DependentMaternGP, GenericLGSSM, GenericLGSSMWithGPNoiseModel, IndependentMaternGP, + LinearlyCoupledMaternGP) from pyro.ops.tensor_utils import block_diag_embed -import pytest +from tests.common import assert_equal @pytest.mark.parametrize('model,obs_dim,nu_statedim', [('ssmgp', 3, 1.5), ('ssmgp', 2, 2.5), diff --git a/tests/contrib/timeseries/test_lgssm.py b/tests/contrib/timeseries/test_lgssm.py index f5c2dac137..5b5ed9d339 100644 --- a/tests/contrib/timeseries/test_lgssm.py +++ b/tests/contrib/timeseries/test_lgssm.py @@ -1,11 +1,11 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import pytest import torch -from tests.common import assert_equal from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel -import pytest +from tests.common import assert_equal @pytest.mark.parametrize('model_class', ['lgssm', 'lgssmgp']) diff --git a/tests/contrib/tracking/test_assignment.py b/tests/contrib/tracking/test_assignment.py index 554a373eb3..9c425dd502 100644 --- a/tests/contrib/tracking/test_assignment.py +++ b/tests/contrib/tracking/test_assignment.py @@ -1,12 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +import logging + import pytest import torch from torch.autograd import grad -import logging - import pyro import pyro.distributions as dist from pyro.contrib.tracking.assignment import MarginalAssignment, MarginalAssignmentPersistent, MarginalAssignmentSparse diff --git a/tests/contrib/tracking/test_distributions.py b/tests/contrib/tracking/test_distributions.py index 4c589ac221..fe4c149b49 100644 --- a/tests/contrib/tracking/test_distributions.py +++ b/tests/contrib/tracking/test_distributions.py @@ -1,13 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import pytest import torch from pyro.contrib.tracking.distributions import EKFDistribution from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous -import pytest - @pytest.mark.parametrize('Model', [NcpContinuous, NcvContinuous]) @pytest.mark.parametrize('dim', [2, 3]) diff --git a/tests/contrib/tracking/test_dynamic_models.py b/tests/contrib/tracking/test_dynamic_models.py index 51df52e75d..4f93afe523 100644 --- a/tests/contrib/tracking/test_dynamic_models.py +++ b/tests/contrib/tracking/test_dynamic_models.py @@ -3,8 +3,7 @@ import torch -from pyro.contrib.tracking.dynamic_models import (NcpContinuous, NcvContinuous, - NcvDiscrete, NcpDiscrete) +from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcpDiscrete, NcvContinuous, NcvDiscrete from tests.common import assert_equal, assert_not_equal diff --git a/tests/contrib/tracking/test_ekf.py b/tests/contrib/tracking/test_ekf.py index 99cec4488c..35db1544d1 100644 --- a/tests/contrib/tracking/test_ekf.py +++ b/tests/contrib/tracking/test_ekf.py @@ -3,10 +3,9 @@ import torch -from pyro.contrib.tracking.extended_kalman_filter import EKFState from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous +from pyro.contrib.tracking.extended_kalman_filter import EKFState from pyro.contrib.tracking.measurements import PositionMeasurement - from tests.common import assert_equal, assert_not_equal diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py index 1d0fca7147..c3401f4114 100644 --- a/tests/contrib/tracking/test_em.py +++ b/tests/contrib/tracking/test_em.py @@ -16,7 +16,6 @@ from pyro.optim import Adam from pyro.optim.multi import MixedMultiOptimizer, Newton - logger = logging.getLogger(__name__) diff --git a/tests/contrib/tracking/test_measurements.py b/tests/contrib/tracking/test_measurements.py index 38f2afcd3d..373cad0e79 100644 --- a/tests/contrib/tracking/test_measurements.py +++ b/tests/contrib/tracking/test_measurements.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch + from pyro.contrib.tracking.measurements import PositionMeasurement diff --git a/tests/distributions/test_empirical.py b/tests/distributions/test_empirical.py index 7d220aa95e..3f2d4435dd 100644 --- a/tests/distributions/test_empirical.py +++ b/tests/distributions/test_empirical.py @@ -5,7 +5,7 @@ import torch from pyro.distributions.empirical import Empirical -from tests.common import assert_equal, assert_close +from tests.common import assert_close, assert_equal @pytest.mark.parametrize("size", [[], [1], [2, 3]]) diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py index f02426696e..03737ecf63 100644 --- a/tests/distributions/test_gaussian_mixtures.py +++ b/tests/distributions/test_gaussian_mixtures.py @@ -4,14 +4,12 @@ import logging import math +import pytest import torch -import 
pytest -from pyro.distributions import MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture -from pyro.distributions import MixtureOfDiagNormals +from pyro.distributions import GaussianScaleMixture, MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance from tests.common import assert_equal - logger = logging.getLogger(__name__) diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py index 63f3daea57..53857c2791 100644 --- a/tests/distributions/test_haar.py +++ b/tests/distributions/test_haar.py @@ -1,9 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import pytest import torch -import pytest from pyro.distributions.transforms import HaarTransform from tests.common import assert_equal diff --git a/tests/distributions/test_ig.py b/tests/distributions/test_ig.py index 215d00ed36..5091e02ad7 100644 --- a/tests/distributions/test_ig.py +++ b/tests/distributions/test_ig.py @@ -3,9 +3,9 @@ import math +import pytest import torch -import pytest from pyro.distributions import Gamma, InverseGamma from tests.common import assert_equal diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py index e71336b2af..27cfdc4910 100644 --- a/tests/distributions/test_mask.py +++ b/tests/distributions/test_mask.py @@ -6,9 +6,8 @@ from torch import tensor from torch.distributions import kl_divergence -from pyro.distributions.util import broadcast_shape from pyro.distributions.torch import Bernoulli, Normal -from pyro.distributions.util import scale_and_mask +from pyro.distributions.util import broadcast_shape, scale_and_mask from tests.common import assert_equal diff --git a/tests/distributions/test_mvt.py b/tests/distributions/test_mvt.py index a61cb1b3f8..ab2dec09ad 100644 --- a/tests/distributions/test_mvt.py +++ b/tests/distributions/test_mvt.py @@ -4,7 +4,6 @@ import math import pytest - import torch from torch.distributions import Gamma, MultivariateNormal, StudentT diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py index f1d92bbb0b..eb04d455fb 100644 --- a/tests/distributions/test_omt_mvn.py +++ b/tests/distributions/test_omt_mvn.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np +import pytest import torch -import pytest from pyro.distributions import AVFMultivariateNormal, MultivariateNormal, OMTMultivariateNormal from tests.common import assert_equal diff --git a/tests/distributions/test_ordered_logistic.py b/tests/distributions/test_ordered_logistic.py index 6c6c3ae409..715db994fb 100644 --- a/tests/distributions/test_ordered_logistic.py +++ b/tests/distributions/test_ordered_logistic.py @@ -6,10 +6,9 @@ import torch.tensor as tt from torch.autograd.functional import jacobian -from pyro.distributions import OrderedLogistic, Normal +from pyro.distributions import Normal, OrderedLogistic from pyro.distributions.transforms import OrderedTransform - # Tests for the OrderedLogistic distribution diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py index fabb71b451..66f881bbc5 100644 --- a/tests/distributions/test_pickle.py +++ b/tests/distributions/test_pickle.py @@ -3,9 +3,9 @@ import inspect import io +import pickle import pytest -import pickle import torch import pyro.distributions as dist diff --git a/tests/distributions/test_spanning_tree.py b/tests/distributions/test_spanning_tree.py index 3336aee03b..5cdf85eae0 100644 --- a/tests/distributions/test_spanning_tree.py +++ 
b/tests/distributions/test_spanning_tree.py @@ -5,9 +5,10 @@ import os from collections import Counter -import pyro import pytest import torch + +import pyro from pyro.distributions.spanning_tree import (NUM_SPANNING_TREES, SpanningTree, find_best_tree, make_complete_graph, sample_tree) from tests.common import assert_equal, xfail_if_not_implemented diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 58f69b641b..6c31b2caeb 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -1,6 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import operator +from functools import partial, reduce from unittest import TestCase import pytest @@ -10,9 +12,6 @@ import pyro.distributions.transforms as T from tests.common import assert_close -from functools import partial, reduce -import operator - pytestmark = pytest.mark.init(rng_seed=123) diff --git a/tests/doctest_fixtures.py b/tests/doctest_fixtures.py index 8be64b2948..0d4e785d84 100644 --- a/tests/doctest_fixtures.py +++ b/tests/doctest_fixtures.py @@ -6,16 +6,15 @@ import torch import pyro -import pyro.contrib.gp as gp import pyro.contrib.autoname.named as named +import pyro.contrib.gp as gp import pyro.distributions as dist import pyro.poutine as poutine from pyro.infer import EmpiricalMarginal -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc import HMC, NUTS +from pyro.infer.mcmc.api import MCMC from pyro.params import param_with_module_name - # Fix seed for all doctest runs. pyro.set_rng_seed(0) diff --git a/tests/infer/mcmc/test_adaptation.py b/tests/infer/mcmc/test_adaptation.py index 2fad237d90..675e43525d 100644 --- a/tests/infer/mcmc/test_adaptation.py +++ b/tests/infer/mcmc/test_adaptation.py @@ -4,12 +4,7 @@ import pytest import torch -from pyro.infer.mcmc.adaptation import ( - ArrowheadMassMatrix, - BlockMassMatrix, - WarmupAdapter, - adapt_window, -) +from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix, WarmupAdapter, adapt_window from tests.common import assert_close, assert_equal diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 58bbf0a76e..2f2f9f967d 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -11,9 +11,9 @@ import pyro import pyro.distributions as dist from pyro.infer.mcmc import NUTS -from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.api import MCMC -from tests.common import assert_equal, assert_close +from pyro.infer.mcmc.hmc import HMC +from tests.common import assert_close, assert_equal logger = logging.getLogger(__name__) diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py index a577da9d40..cb203fcea8 100644 --- a/tests/infer/mcmc/test_mcmc_api.py +++ b/tests/infer/mcmc/test_mcmc_api.py @@ -11,7 +11,7 @@ import pyro.distributions as dist from pyro import poutine from pyro.infer.mcmc import HMC, NUTS -from pyro.infer.mcmc.api import MCMC, _UnarySampler, _MultiSampler +from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model from pyro.util import optional diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 43630bcb16..106a4510e3 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -10,12 +10,12 @@ import pyro import pyro.distributions as dist -from pyro.infer.autoguide import AutoDelta -from 
pyro.contrib.conjugate.infer import BetaBinomialPair, collapse_conjugate, GammaPoissonPair, posterior_replay -from pyro.infer import TraceEnum_ELBO, SVI -from pyro.infer.mcmc import ArrowheadMassMatrix, MCMC, NUTS import pyro.optim as optim import pyro.poutine as poutine +from pyro.contrib.conjugate.infer import BetaBinomialPair, GammaPoissonPair, collapse_conjugate, posterior_replay +from pyro.infer import SVI, TraceEnum_ELBO +from pyro.infer.autoguide import AutoDelta +from pyro.infer.mcmc import MCMC, NUTS, ArrowheadMassMatrix from pyro.util import ignore_jit_warnings from tests.common import assert_close, assert_equal diff --git a/tests/infer/test_abstract_infer.py b/tests/infer/test_abstract_infer.py index 483bc4e854..bfacd142a2 100644 --- a/tests/infer/test_abstract_infer.py +++ b/tests/infer/test_abstract_infer.py @@ -8,12 +8,11 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer.autoguide import AutoLaplaceApproximation from pyro.infer import SVI, Trace_ELBO +from pyro.infer.autoguide import AutoLaplaceApproximation from pyro.infer.mcmc import MCMC, NUTS from tests.common import assert_equal - pytestmark = pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 836879d4b6..cf7d342e36 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -8,8 +8,8 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine +from pyro.infer import SVI, Predictive, Trace_ELBO from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal -from pyro.infer import Predictive, SVI, Trace_ELBO from tests.common import assert_close diff --git a/tests/infer/test_svgd.py b/tests/infer/test_svgd.py index 2d10b53b55..c6944dedc2 100644 --- a/tests/infer/test_svgd.py +++ b/tests/infer/test_svgd.py @@ -6,11 +6,9 @@ import pyro import pyro.distributions as dist - -from pyro.infer import SVGD, RBFSteinKernel, IMQSteinKernel -from pyro.optim import Adam +from pyro.infer import SVGD, IMQSteinKernel, RBFSteinKernel from pyro.infer.autoguide.utils import _product - +from pyro.optim import Adam from tests.common import assert_equal diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py index cf55ed02ce..35667e56ef 100644 --- a/tests/infer/test_tmc.py +++ b/tests/infer/test_tmc.py @@ -15,11 +15,10 @@ from pyro.distributions.testing import fakes from pyro.infer import config_enumerate from pyro.infer.importance import vectorized_importance_weights -from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.traceenum_elbo import TraceEnum_ELBO +from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from tests.common import assert_equal - logger = logging.getLogger(__name__) diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py index 236b460050..7fda7a399f 100644 --- a/tests/infer/test_util.py +++ b/tests/infer/test_util.py @@ -3,12 +3,12 @@ import math +import pytest import torch import pyro import pyro.distributions as dist import pyro.poutine as poutine -import pytest from pyro.infer.importance import psis_diagnostic from pyro.infer.util import MultiFrameTensor from tests.common import assert_equal diff --git a/tests/ops/test_arrowhead.py b/tests/ops/test_arrowhead.py index 13feae5697..2ffa76bf78 100644 --- a/tests/ops/test_arrowhead.py +++ b/tests/ops/test_arrowhead.py @@ -5,7 +5,6 @@ import torch from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, 
triu_matvecmul
-
 from tests.common import assert_close
diff --git a/tests/ops/test_gamma_gaussian.py b/tests/ops/test_gamma_gaussian.py
index 872a42e531..74c018bcc5 100644
--- a/tests/ops/test_gamma_gaussian.py
+++ b/tests/ops/test_gamma_gaussian.py
@@ -9,12 +9,8 @@
 import pyro.distributions as dist
 from pyro.distributions.util import broadcast_shape
-from pyro.ops.gamma_gaussian import (
-    GammaGaussian,
-    gamma_gaussian_tensordot,
-    matrix_and_mvn_to_gamma_gaussian,
-    gamma_and_mvn_to_gamma_gaussian,
-)
+from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot,
+                                     matrix_and_mvn_to_gamma_gaussian)
 from tests.common import assert_close
 from tests.ops.gamma_gaussian import assert_close_gamma_gaussian, random_gamma, random_gamma_gaussian
 from tests.ops.gaussian import random_mvn
diff --git a/tests/ops/test_newton.py b/tests/ops/test_newton.py
index d502cde5d7..d264b3ae35 100644
--- a/tests/ops/test_newton.py
+++ b/tests/ops/test_newton.py
@@ -11,7 +11,6 @@
 from pyro.ops.newton import newton_step
 from tests.common import assert_equal
-
 logger = logging.getLogger(__name__)
diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py
index 53fb0213fb..32026d77a2 100644
--- a/tests/perf/test_benchmark.py
+++ b/tests/perf/test_benchmark.py
@@ -16,8 +16,8 @@
 import pyro.optim as optim
 from pyro.distributions.testing import fakes
 from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO
-from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.api import MCMC
+from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.nuts import NUTS

 Model = namedtuple('TestModel', ['model', 'model_args', 'model_id'])
diff --git a/tests/poutine/test_nesting.py b/tests/poutine/test_nesting.py
index 6fd6f3614d..ede0456c32 100644
--- a/tests/poutine/test_nesting.py
+++ b/tests/poutine/test_nesting.py
@@ -4,11 +4,10 @@
 import logging

 import pyro
-import pyro.poutine as poutine
 import pyro.distributions as dist
+import pyro.poutine as poutine
 import pyro.poutine.runtime
-
 logger = logging.getLogger(__name__)
diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py
index f2f4eee025..99fbfb6336 100644
--- a/tests/poutine/test_poutines.py
+++ b/tests/poutine/test_poutines.py
@@ -6,12 +6,12 @@
 import logging
 import pickle
 import warnings
+from queue import Queue
 from unittest import TestCase

 import pytest
 import torch
 import torch.nn as nn
-from queue import Queue

 import pyro
 import pyro.distributions as dist
@@ -19,7 +19,7 @@
 from pyro.distributions import Bernoulli, Categorical, Normal
 from pyro.poutine.runtime import _DIM_ALLOCATOR, NonlocalExit
 from pyro.poutine.util import all_escape, discrete_escape
-from tests.common import assert_equal, assert_not_equal, assert_close
+from tests.common import assert_close, assert_equal, assert_not_equal

 logger = logging.getLogger(__name__)
diff --git a/tests/poutine/test_trace_struct.py b/tests/poutine/test_trace_struct.py
index 9ad7d351a6..4511ccbdf3 100644
--- a/tests/poutine/test_trace_struct.py
+++ b/tests/poutine/test_trace_struct.py
@@ -8,7 +8,6 @@
 from pyro.poutine import Trace
 from tests.common import assert_equal
-
 EDGE_SETS = [
     # 1
     #  / \
diff --git a/tests/pyroapi/test_pyroapi.py b/tests/pyroapi/test_pyroapi.py
index 1fa1673b9f..271c38efab 100644
--- a/tests/pyroapi/test_pyroapi.py
+++ b/tests/pyroapi/test_pyroapi.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
-
 from pyroapi import pyro_backend
 from pyroapi.tests import *  # noqa F401
diff --git a/tests/test_generic.py b/tests/test_generic.py
index a3324b27c0..1ca5c77588 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
-
-from pyro.generic import handlers, infer, pyro, pyro_backend, ops
 from pyroapi.testing import MODELS
+
+from pyro.generic import handlers, infer, ops, pyro, pyro_backend
 from tests.common import xfail_if_not_implemented

 pytestmark = pytest.mark.stage('unit')
diff --git a/tests/test_primitives.py b/tests/test_primitives.py
index d285ad69b5..22f331a450 100644
--- a/tests/test_primitives.py
+++ b/tests/test_primitives.py
@@ -2,9 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
+import torch
+
 import pyro
 import pyro.distributions as dist
-import torch

 pytestmark = pytest.mark.stage('unit')
diff --git a/tests/test_util.py b/tests/test_util.py
index f8b382b4ec..09ec92f4f7 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -2,9 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import warnings
-import pytest

+import pytest
 import torch
+
 from pyro import util

 pytestmark = pytest.mark.stage('unit')

From 5c77f1642db7af5ea4375cd96583e04c3ea311bd Mon Sep 17 00:00:00 2001
From: Eli Weinstein
Date: Sat, 20 Feb 2021 15:18:28 -0500
Subject: [PATCH 38/91] Revert "Make format's automated changes."

This reverts commit d2c513c2034763199983dceb78d1c83db5f2889f.

---
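Note: this patch is a mechanical revert of the automated import re-sorting in
d2c513c; it changes import layout only, not runtime behavior. The two spellings
it toggles between are interchangeable at import time. A minimal illustration,
using names that appear verbatim in the pyro/distributions/__init__.py hunk
below (assumes pyro is installed; this snippet is not itself part of the patch):

# Compact grid form, as produced by the automated format pass:
from pyro.distributions.conjugate import BetaBinomial, DirichletMultinomial, GammaPoisson

# Vertical hanging-indent form, as restored by this revert -- identical
# bindings, but cleaner one-line diffs when a name is added or removed:
from pyro.distributions.conjugate import (
    BetaBinomial,
    DirichletMultinomial,
    GammaPoisson,
)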
 docs/source/conf.py | 1 +
 examples/air/air.py | 2 +-
 examples/air/main.py | 4 +-
 examples/capture_recapture/cjs.py | 3 +-
 examples/contrib/autoname/scoping_mixture.py | 7 +--
 examples/contrib/funsor/hmm.py | 1 +
 examples/contrib/gp/sv-dkl.py | 2 +-
 examples/contrib/oed/ab_test.py | 14 ++---
 examples/contrib/oed/gp_bayes_opt.py | 2 +-
 examples/contrib/timeseries/gp_models.py | 8 +--
 examples/cvae/baseline.py | 5 +-
 examples/cvae/cvae.py | 10 ++--
 examples/cvae/main.py | 10 ++--
 examples/cvae/mnist.py | 2 +-
 examples/cvae/util.py | 12 ++---
 examples/eight_schools/data.py | 1 +
 examples/eight_schools/mcmc.py | 2 +-
 examples/eight_schools/svi.py | 2 +-
 examples/hmm.py | 2 +-
 examples/lkj.py | 3 +-
 examples/minipyro.py | 2 +-
 examples/mixed_hmm/experiment.py | 9 ++--
 examples/mixed_hmm/seal_data.py | 2 +
 examples/rsa/generics.py | 9 ++--
 examples/rsa/hyperbole.py | 9 ++--
 examples/rsa/schelling.py | 3 +-
 examples/rsa/schelling_false.py | 3 +-
 examples/rsa/search_inference.py | 4 +-
 examples/rsa/semantic_parsing.py | 7 +--
 examples/scanvi/data.py | 6 +--
 examples/scanvi/scanvi.py | 14 ++---
 examples/sparse_gamma_def.py | 11 ++--
 examples/sparse_regression.py | 9 ++--
 examples/vae/ss_vae_M2.py | 6 +--
 examples/vae/vae.py | 6 +--
 examples/vae/vae_comparison.py | 2 +-
 profiler/hmm.py | 3 +-
 profiler/profiling_utils.py | 2 +-
 pyro/contrib/__init__.py | 1 -
 pyro/contrib/autoname/__init__.py | 3 +-
 pyro/contrib/bnn/hidden_layer.py | 2 +-
 pyro/contrib/bnn/utils.py | 3 +-
 pyro/contrib/conjugate/infer.py | 2 +-
 pyro/contrib/easyguide/__init__.py | 1 +
 pyro/contrib/easyguide/easyguide.py | 2 +-
 pyro/contrib/examples/bart.py | 2 +-
 pyro/contrib/examples/finance.py | 2 +-
 .../examples/polyphonic_data_loader.py | 3 +-
 pyro/contrib/funsor/__init__.py | 12 +++--
 pyro/contrib/funsor/handlers/__init__.py | 8 ++-
 .../contrib/funsor/handlers/enum_messenger.py | 9 ++--
 .../funsor/handlers/named_messenger.py | 3 +-
 pyro/contrib/funsor/handlers/primitives.py | 1 +
 .../funsor/handlers/replay_messenger.py | 2 +-
 pyro/contrib/funsor/infer/__init__.py | 2 +-
 pyro/contrib/funsor/infer/trace_elbo.py | 7 +--
 pyro/contrib/funsor/infer/traceenum_elbo.py | 5 +-
 pyro/contrib/funsor/infer/tracetmc_elbo.py | 6 ++-
 pyro/contrib/gp/kernels/__init__.py | 7 +--
 pyro/contrib/gp/likelihoods/binary.py | 1 +
 pyro/contrib/gp/likelihoods/gaussian.py | 1 +
 pyro/contrib/gp/likelihoods/multi_class.py | 1 +
 pyro/contrib/gp/likelihoods/poisson.py | 1 +
 pyro/contrib/oed/__init__.py | 2 +-
 pyro/contrib/oed/eig.py | 9 ++--
 pyro/contrib/oed/glmm/__init__.py | 2 +-
 pyro/contrib/oed/glmm/glmm.py | 6 +--
 pyro/contrib/oed/glmm/guides.py | 8 +--
 pyro/contrib/oed/search.py | 3 +-
 pyro/contrib/oed/util.py | 3 +-
 .../contrib/randomvariable/random_variable.py | 13 +++--
 pyro/contrib/timeseries/__init__.py | 2 +-
 pyro/contrib/tracking/distributions.py | 2 +-
 pyro/contrib/tracking/dynamic_models.py | 1 -
 pyro/contrib/tracking/measurements.py | 1 -
 pyro/contrib/util.py | 2 -
 pyro/distributions/__init__.py | 51 +++++++++++++++----
 pyro/distributions/ordered_logistic.py | 1 -
 pyro/distributions/projected_normal.py | 3 +-
 pyro/distributions/spanning_tree.py | 1 -
 pyro/distributions/testing/gof.py | 2 +-
 pyro/distributions/transforms/__init__.py | 8 +--
 pyro/distributions/transforms/ordered.py | 3 +-
 pyro/infer/__init__.py | 2 +-
 pyro/infer/abstract_infer.py | 2 +-
 pyro/infer/mcmc/__init__.py | 2 +-
 pyro/infer/mcmc/adaptation.py | 2 +-
 pyro/infer/mcmc/hmc.py | 3 +-
 pyro/infer/mcmc/util.py | 4 +-
 pyro/infer/predictive.py | 2 +-
 pyro/infer/reparam/neutra.py | 1 -
 pyro/infer/svgd.py | 6 +--
 pyro/infer/trace_mean_field_elbo.py | 2 +-
 pyro/infer/trace_mmd.py | 2 +-
 pyro/infer/trace_tail_adaptive_elbo.py | 2 +-
 pyro/infer/traceenum_elbo.py | 2 +-
 pyro/infer/tracegraph_elbo.py | 3 +-
 pyro/infer/tracetmc_elbo.py | 1 +
 pyro/logger.py | 1 +
 pyro/ops/arrowhead.py | 1 +
 pyro/ops/einsum/torch_map.py | 1 +
 pyro/ops/einsum/torch_sample.py | 1 +
 pyro/ops/newton.py | 2 +-
 pyro/ops/ssm_gp.py | 2 +-
 pyro/poutine/broadcast_messenger.py | 1 -
 pyro/poutine/indep_messenger.py | 1 -
 pyro/poutine/trace_struct.py | 2 +-
 tests/__init__.py | 1 +
 tests/conftest.py | 1 +
 tests/contrib/autoguide/test_inference.py | 4 +-
 .../autoguide/test_mean_field_entropy.py | 4 +-
 tests/contrib/autoname/test_scoping.py | 2 +-
 tests/contrib/bnn/test_hidden_layer.py | 2 +-
 tests/contrib/epidemiology/test_quant.py | 1 +
 tests/contrib/funsor/test_enum_funsor.py | 4 +-
 tests/contrib/funsor/test_named_handlers.py | 3 +-
 tests/contrib/funsor/test_pyroapi_funsor.py | 2 +-
 tests/contrib/funsor/test_tmc.py | 3 +-
 .../contrib/funsor/test_valid_models_enum.py | 5 +-
 .../contrib/funsor/test_valid_models_plate.py | 3 +-
 .../test_valid_models_sequential_plate.py | 3 +-
 .../contrib/funsor/test_vectorized_markov.py | 7 ++-
 tests/contrib/gp/test_kernels.py | 5 +-
 tests/contrib/gp/test_likelihoods.py | 1 +
 tests/contrib/gp/test_models.py | 7 +--
 tests/contrib/oed/test_ewma.py | 2 +-
 tests/contrib/oed/test_finite_spaces_eig.py | 8 +--
 tests/contrib/oed/test_glmm.py | 6 ++-
 tests/contrib/oed/test_linear_models_eig.py | 11 ++--
 .../randomvariable/test_random_variable.py | 1 -
 tests/contrib/test_util.py | 5 +-
 tests/contrib/timeseries/test_gp.py | 9 ++--
 tests/contrib/timeseries/test_lgssm.py | 4 +-
 tests/contrib/tracking/test_assignment.py | 4 +-
 tests/contrib/tracking/test_distributions.py | 3 +-
 tests/contrib/tracking/test_dynamic_models.py | 3 +-
 tests/contrib/tracking/test_ekf.py | 3 +-
 tests/contrib/tracking/test_em.py | 1 +
 tests/contrib/tracking/test_measurements.py | 1 -
 tests/distributions/test_empirical.py | 2 +-
 tests/distributions/test_gaussian_mixtures.py | 6 ++-
 tests/distributions/test_haar.py | 2 +-
 tests/distributions/test_ig.py | 2 +-
 tests/distributions/test_mask.py | 3 +-
 tests/distributions/test_mvt.py | 1 +
 tests/distributions/test_omt_mvn.py | 2 +-
 tests/distributions/test_ordered_logistic.py | 3 +-
 tests/distributions/test_pickle.py | 2 +-
 tests/distributions/test_spanning_tree.py | 3 +-
 tests/distributions/test_transforms.py | 5 +-
 tests/doctest_fixtures.py | 5 +-
 tests/infer/mcmc/test_adaptation.py | 7 ++-
 tests/infer/mcmc/test_hmc.py | 4 +-
 tests/infer/mcmc/test_mcmc_api.py | 2 +-
 tests/infer/mcmc/test_nuts.py | 8 +--
 tests/infer/test_abstract_infer.py | 3 +-
 tests/infer/test_predictive.py | 2 +-
 tests/infer/test_svgd.py | 6 ++-
 tests/infer/test_tmc.py | 3 +-
 tests/infer/test_util.py | 2 +-
 tests/ops/test_arrowhead.py | 1 +
 tests/ops/test_gamma_gaussian.py | 8 ++-
 tests/ops/test_newton.py | 1 +
 tests/perf/test_benchmark.py | 2 +-
 tests/poutine/test_nesting.py | 3 +-
 tests/poutine/test_poutines.py | 4 +-
 tests/poutine/test_trace_struct.py | 1 +
 tests/pyroapi/test_pyroapi.py | 1 +
 tests/test_generic.py | 4 +-
 tests/test_primitives.py | 3 +-
 tests/test_util.py | 3 +-
 171 files changed, 387 insertions(+), 294 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 597d1c2ac6..d39a267cb9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -6,6 +6,7 @@
 import sphinx_rtd_theme

+
 # import pkg_resources

 # -*- coding: utf-8 -*-
diff --git a/examples/air/air.py b/examples/air/air.py
index 985e2853bf..a9b8958d0f 100644
--- a/examples/air/air.py
+++ b/examples/air/air.py
@@ -14,10 +14,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from modules import MLP, Decoder, Encoder, Identity, Predict

 import pyro
 import pyro.distributions as dist
+from modules import MLP, Decoder, Encoder, Identity, Predict

 # Default prior success probability for z_pres.
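The air.py hunk above is typical of the example-script churn being undone: the
automated pass had pulled the example-local `from modules import ...` line up
into the third-party block next to the torch imports, while the original file
(restored here) keeps it at the bottom of the pyro block. A short sketch of how
isort makes that grouping decision (assumes the isort>=5 package is installed;
`messy` is an illustrative string, not repository code):

import isort

messy = (
    "from modules import MLP\n"
    "import torch\n"
    "import argparse\n"
    "import pyro\n"
)
# isort.code() returns the re-sorted source: stdlib imports first (argparse),
# then installed third-party packages (pyro, torch), with `modules` placed
# according to whatever section isort guesses for it.
print(isort.code(messy))

Unless `modules` is declared via known_first_party (or resolvable from the
configured source paths), isort cannot tell that it is a sibling file in
examples/air, which is why the automated ordering and the hand-written one
disagree.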
diff --git a/examples/air/main.py b/examples/air/main.py index 0506a32b9e..43d30d3a02 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -19,15 +19,15 @@ import numpy as np import torch import visdom -from air import AIR, latents_to_tensor -from viz import draw_many, tensor_to_objs import pyro import pyro.contrib.examples.multi_mnist as multi_mnist import pyro.optim as optim import pyro.poutine as poutine +from air import AIR, latents_to_tensor from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO +from viz import draw_many, tensor_to_objs def count_accuracy(X, true_counts, air, batch_size): diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index fa868899d5..65b709afca 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -39,10 +39,11 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO from pyro.infer.autoguide import AutoDiagonalNormal +from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO from pyro.optim import Adam + """ Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival probability phi; diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index d39b65eb6d..363d3ace53 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -2,15 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - import torch from torch.distributions import constraints import pyro -import pyro.distributions as dist import pyro.optim +import pyro.distributions as dist + +from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO + from pyro.contrib.autoname import scope -from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate def model(K, data): diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 98fca8eca8..017a9aa2e8 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -64,6 +64,7 @@ from pyroapi import distributions as dist from pyroapi import handlers, infer, optim, pyro, pyro_backend + logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.DEBUG) # Add another handler for logging debugging events (e.g. 
for profiling) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 19e8245c03..f07c4052e4 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -39,7 +39,7 @@ import pyro import pyro.contrib.gp as gp import pyro.infer as infer -from pyro.contrib.examples.util import get_data_directory, get_data_loader +from pyro.contrib.examples.util import get_data_loader, get_data_directory class CNN(nn.Module): diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index f417d78ca8..522dc44ad3 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -3,18 +3,20 @@ import argparse from functools import partial - -import numpy as np import torch -from gp_bayes_opt import GPBayesOptimizer from torch.distributions import constraints +import numpy as np import pyro -import pyro.contrib.gp as gp from pyro import optim -from pyro.contrib.oed.eig import vi_eig -from pyro.contrib.oed.glmm import analytic_posterior_cov, group_assignment_matrix, zero_mean_unit_obs_sd_lm from pyro.infer import TraceEnum_ELBO +from pyro.contrib.oed.eig import vi_eig +import pyro.contrib.gp as gp +from pyro.contrib.oed.glmm import ( + zero_mean_unit_obs_sd_lm, group_assignment_matrix, analytic_posterior_cov +) + +from gp_bayes_opt import GPBayesOptimizer """ Example builds on the Bayesian regression tutorial [1]. It demonstrates how diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index 6132dee48a..3c114c9bcf 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -7,8 +7,8 @@ from torch.distributions import transform_to import pyro.contrib.gp as gp -import pyro.optim from pyro.infer import TraceEnum_ELBO +import pyro.optim class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer): diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index e2ab952b90..259ef99180 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -1,16 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import argparse -from os.path import exists -from urllib.request import urlopen - import numpy as np import torch import pyro from pyro.contrib.timeseries import IndependentMaternGP, LinearlyCoupledMaternGP +import argparse +from os.path import exists +from urllib.request import urlopen + # download dataset from UCI archive def download_data(): diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index 23e1591016..cb5d279445 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from pathlib import Path - import numpy as np +from pathlib import Path +from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F -from tqdm import tqdm class BaselineNet(nn.Module): diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py index f499aaf452..fb792cf4d3 100644 --- a/examples/cvae/cvae.py +++ b/examples/cvae/cvae.py @@ -1,16 +1,14 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path - import numpy as np -import torch -import torch.nn as nn -from tqdm import tqdm - +from pathlib import Path import pyro import pyro.distributions as dist from pyro.infer import SVI, Trace_ELBO +import torch +import torch.nn as nn +from tqdm import tqdm class Encoder(nn.Module): diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 73b85ac92d..2056b4d3a6 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -2,14 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - -import baseline -import cvae import pandas as pd -import torch -from util import generate_table, get_data, visualize - import pyro +import torch +import baseline +import cvae +from util import get_data, visualize, generate_table def main(args): diff --git a/examples/cvae/mnist.py b/examples/cvae/mnist.py index 12dd7409f2..a98c667081 100644 --- a/examples/cvae/mnist.py +++ b/examples/cvae/mnist.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset, DataLoader from torchvision.datasets import MNIST from torchvision.transforms import Compose, functional diff --git a/examples/cvae/util.py b/examples/cvae/util.py index 87650298ef..e578085946 100644 --- a/examples/cvae/util.py +++ b/examples/cvae/util.py @@ -1,19 +1,17 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path - -import matplotlib.pyplot as plt import numpy as np +import matplotlib.pyplot as plt import pandas as pd +from pathlib import Path +from pyro.infer import Predictive, Trace_ELBO import torch -from baseline import MaskedBCELoss -from mnist import get_data from torch.utils.data import DataLoader from torchvision.utils import make_grid from tqdm import tqdm - -from pyro.infer import Predictive, Trace_ELBO +from baseline import MaskedBCELoss +from mnist import get_data def imshow(inp, image_path=None): diff --git a/examples/eight_schools/data.py b/examples/eight_schools/data.py index 56158fa36e..39529e798a 100644 --- a/examples/eight_schools/data.py +++ b/examples/eight_schools/data.py @@ -3,6 +3,7 @@ import torch + J = 8 y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index d34d901e53..62b184ea85 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -4,9 +4,9 @@ import argparse import logging -import data import torch +import data import pyro import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index b6f35f254b..a06a801768 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -5,11 +5,11 @@ import logging import torch -from data import J, sigma, y from torch.distributions import constraints, transforms import pyro import pyro.distributions as dist +from data import J, sigma, y from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam diff --git a/examples/hmm.py b/examples/hmm.py index 68395da7c1..7c706ed558 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -47,8 +47,8 @@ import pyro.contrib.examples.polyphonic_data_loader as poly import pyro.distributions as dist from pyro import poutine -from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO from pyro.infer.autoguide 
import AutoDelta +from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO from pyro.ops.indexing import Vindex from pyro.optim import Adam from pyro.util import ignore_jit_warnings diff --git a/examples/lkj.py b/examples/lkj.py index dd54c5ed43..fb437358bd 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - import torch import pyro import pyro.distributions as dist -from pyro.infer.mcmc import NUTS from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc import NUTS """ This simple example is intended to demonstrate how to use an LKJ prior with diff --git a/examples/minipyro.py b/examples/minipyro.py index 084ddfdca5..33ab6eba01 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -11,8 +11,8 @@ import torch -# We use the pyro.generic interface to support dynamic choice of backend. from pyro.generic import distributions as dist +# We use the pyro.generic interface to support dynamic choice of backend. from pyro.generic import infer, ops, optim, pyro, pyro_backend diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index bd4bd02e6a..65584c6769 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -2,19 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import functools -import json import os +import json import uuid +import functools import torch -from model import guide_generic, model_generic -from seal_data import prepare_seal import pyro import pyro.poutine as poutine from pyro.infer import TraceEnum_ELBO +from model import model_generic, guide_generic +from seal_data import prepare_seal + def aic_num_parameters(model, guide=None): """ diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py index 609fc69da3..390201a8d9 100644 --- a/examples/mixed_hmm/seal_data.py +++ b/examples/mixed_hmm/seal_data.py @@ -5,8 +5,10 @@ from urllib.request import urlopen import pandas as pd + import torch + MISSING = 1e-6 diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 4ca04bc349..9a13a4e54a 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -9,17 +9,18 @@ [1] https://gscontras.github.io/probLang/chapters/07-generics.html """ +import torch + import argparse -import collections import numbers - -import torch -from search_inference import HashingMarginal, Search, memoize +import collections import pyro import pyro.distributions as dist import pyro.poutine as poutine +from search_inference import HashingMarginal, memoize, Search + torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index 93dd409579..a77b01870a 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -7,16 +7,17 @@ Taken from: https://gscontras.github.io/probLang/chapters/03-nonliteral.html """ -import argparse -import collections - import torch -from search_inference import HashingMarginal, Search, memoize + +import collections +import argparse import pyro import pyro.distributions as dist import pyro.poutine as poutine +from search_inference import HashingMarginal, memoize, Search + torch.set_default_dtype(torch.float64) # double precision for numerical stability diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index a111b5d2bc..886eb405b6 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -11,13 +11,12 @@ Taken from: 
http://forestdb.org/models/schelling.html """ import argparse - import torch -from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli +from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index f855a010e2..998e3b70cb 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -12,13 +12,12 @@ Taken from: http://forestdb.org/models/schelling-falsebelief.html """ import argparse - import torch -from search_inference import HashingMarginal, Search import pyro import pyro.poutine as poutine from pyro.distributions import Bernoulli +from search_inference import HashingMarginal, Search def location(preference): diff --git a/examples/rsa/search_inference.py b/examples/rsa/search_inference.py index 7e2cb8e142..14a49766f1 100644 --- a/examples/rsa/search_inference.py +++ b/examples/rsa/search_inference.py @@ -8,10 +8,10 @@ """ import collections -import functools -import queue import torch +import queue +import functools import pyro.distributions as dist import pyro.poutine as poutine diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index eb188b0a54..15ffe901aa 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -7,15 +7,16 @@ Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html """ +import torch + import argparse import collections -import torch -from search_inference import BestFirstSearch, HashingMarginal, memoize - import pyro import pyro.distributions as dist +from search_inference import HashingMarginal, BestFirstSearch, memoize + torch.set_default_dtype(torch.float64) diff --git a/examples/scanvi/data.py b/examples/scanvi/data.py index 429883d1a3..690eab0717 100644 --- a/examples/scanvi/data.py +++ b/examples/scanvi/data.py @@ -8,11 +8,11 @@ """ import math - import numpy as np +from scipy import sparse + import torch import torch.nn as nn -from scipy import sparse class BatchDataLoader(object): @@ -122,8 +122,8 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): return BatchDataLoader(X, Y, batch_size), num_genes, 2.0, 1.0, None - import scanpy as sc import scvi + import scanpy as sc adata = scvi.data.purified_pbmc_dataset(subset_datasets=["regulatory_t", "naive_t", "memory_t", "naive_cytotoxic"]) diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index d4928ddf27..928785827a 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -19,22 +19,25 @@ import argparse -import matplotlib.pyplot as plt import numpy as np + import torch import torch.nn as nn -from data import get_data -from matplotlib.patches import Patch from torch.distributions import constraints -from torch.nn.functional import softmax, softplus +from torch.nn.functional import softplus, softmax from torch.optim import Adam import pyro import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions.util import broadcast_shape -from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate from pyro.optim import MultiStepLR +from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO + +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +from data import get_data # Helper for making fully-connected neural networks @@ -297,7 +300,6 @@ def main(args): # Now that we're done training we'll inspect the latent representations we've learned if args.plot and 
args.dataset == 'pbmc': import scanpy as sc - # Compute latent representation (z2_loc) for each cell in the dataset latent_rep = scanvi.z2l_encoder(dataloader.data_x)[0] diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index fddb1b1bc7..5de2bf6f1d 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -20,16 +20,19 @@ import numpy as np import torch -import wget from torch.nn.functional import softplus import pyro import pyro.optim as optim -from pyro.contrib.easyguide import EasyGuide +import wget + from pyro.contrib.examples.util import get_data_directory -from pyro.distributions import Gamma, Normal, Poisson +from pyro.distributions import Gamma, Poisson, Normal from pyro.infer import SVI, TraceMeanField_ELBO -from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible +from pyro.infer.autoguide import AutoDiagonalNormal +from pyro.infer.autoguide import init_to_feasible +from pyro.contrib.easyguide import EasyGuide + torch.set_default_tensor_type('torch.FloatTensor') pyro.util.set_rng_seed(0) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 63cffbc7c6..da807cfcb7 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -2,17 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import math import numpy as np import torch -from torch.optim import Adam +import math import pyro import pyro.distributions as dist from pyro import poutine +from pyro.infer.autoguide import AutoDelta from pyro.infer import Trace_ELBO -from pyro.infer.autoguide import AutoDelta, init_to_median +from pyro.infer.autoguide import init_to_median + +from torch.optim import Adam + """ We demonstrate how to do sparse linear regression using a variant of the diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index a064140ab5..265097efc4 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -5,9 +5,6 @@ import torch import torch.nn as nn -from utils.custom_mlp import MLP, Exp -from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders -from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae from visdom import Visdom import pyro @@ -15,6 +12,9 @@ from pyro.contrib.examples.util import print_and_log from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate from pyro.optim import Adam +from utils.custom_mlp import MLP, Exp +from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders +from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae class SSVAE(nn.Module): diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 95cee4e66b..d4f54e515e 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -7,14 +7,14 @@ import torch import torch.nn as nn import visdom -from utils.mnist_cached import MNISTCached as MNIST -from utils.mnist_cached import setup_data_loaders -from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples import pyro import pyro.distributions as dist from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam +from utils.mnist_cached import MNISTCached as MNIST +from utils.mnist_cached import setup_data_loaders +from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples # define the PyTorch module that parameterizes the diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index 9ee1b52704..abefcb0b03 100644 --- a/examples/vae/vae_comparison.py 
+++ b/examples/vae/vae_comparison.py @@ -10,13 +10,13 @@ import torch.nn as nn from torch.nn import functional from torchvision.utils import save_image -from utils.mnist_cached import DATA_DIR, RESULTS_DIR import pyro from pyro.contrib.examples import util from pyro.distributions import Bernoulli, Normal from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam +from utils.mnist_cached import DATA_DIR, RESULTS_DIR """ Comparison of VAE implementation in PyTorch and Pyro. This example can be diff --git a/profiler/hmm.py b/profiler/hmm.py index 1825c82c20..4308c3df56 100644 --- a/profiler/hmm.py +++ b/profiler/hmm.py @@ -8,12 +8,13 @@ import subprocess import sys from collections import defaultdict -from os.path import abspath, join +from os.path import join, abspath from numpy import median from pyro.util import timed + EXAMPLES_DIR = join(abspath(__file__), os.pardir, os.pardir, "examples") diff --git a/profiler/profiling_utils.py b/profiler/profiling_utils.py index aee4f9b564..8375132eb2 100644 --- a/profiler/profiling_utils.py +++ b/profiler/profiling_utils.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import cProfile +from io import StringIO import functools import os import pstats import timeit from contextlib import contextmanager -from io import StringIO from prettytable import ALL, PrettyTable diff --git a/pyro/contrib/__init__.py b/pyro/contrib/__init__.py index 045ec5435f..3f14bd1862 100644 --- a/pyro/contrib/__init__.py +++ b/pyro/contrib/__init__.py @@ -25,7 +25,6 @@ try: import funsor as funsor_ # noqa: F401 - from pyro.contrib import funsor __all__ += ["funsor"] except ImportError: diff --git a/pyro/contrib/autoname/__init__.py b/pyro/contrib/autoname/__init__.py index 9a379fb2ff..b93f5d7db7 100644 --- a/pyro/contrib/autoname/__init__.py +++ b/pyro/contrib/autoname/__init__.py @@ -6,8 +6,9 @@ generating unique, semantically meaningful names for sample sites. """ from pyro.contrib.autoname import named +from pyro.contrib.autoname.scoping import scope, name_count from pyro.contrib.autoname.autoname import autoname, sample -from pyro.contrib.autoname.scoping import name_count, scope + __all__ = [ "named", diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index cc97b051fa..6a4f679e29 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import torch.nn.functional as F from torch.distributions.utils import lazy_property +import torch.nn.functional as F from pyro.contrib.bnn.utils import adjoin_ones_vector from pyro.distributions.torch_distribution import TorchDistribution diff --git a/pyro/contrib/bnn/utils.py b/pyro/contrib/bnn/utils.py index 794f66f984..ec2f33623a 100644 --- a/pyro/contrib/bnn/utils.py +++ b/pyro/contrib/bnn/utils.py @@ -1,9 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import math - import torch +import math def xavier_uniform(D_in, D_out): diff --git a/pyro/contrib/conjugate/infer.py b/pyro/contrib/conjugate/infer.py index 0c815c0126..23a3fe791e 100644 --- a/pyro/contrib/conjugate/infer.py +++ b/pyro/contrib/conjugate/infer.py @@ -6,8 +6,8 @@ import torch import pyro.distributions as dist -from pyro import poutine from pyro.distributions.util import sum_leftmost +from pyro import poutine from pyro.poutine.messenger import Messenger from pyro.poutine.util import site_is_subsample diff --git a/pyro/contrib/easyguide/__init__.py b/pyro/contrib/easyguide/__init__.py index d26c63c9cf..9e2577841f 100644 --- a/pyro/contrib/easyguide/__init__.py +++ b/pyro/contrib/easyguide/__init__.py @@ -3,6 +3,7 @@ from pyro.contrib.easyguide.easyguide import EasyGuide, easy_guide + __all__ = [ "EasyGuide", "easy_guide", diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py index 55535ae72d..fbc204466b 100644 --- a/pyro/contrib/easyguide/easyguide.py +++ b/pyro/contrib/easyguide/easyguide.py @@ -14,8 +14,8 @@ import pyro.poutine as poutine import pyro.poutine.runtime as runtime from pyro.distributions.util import broadcast_shape, sum_rightmost -from pyro.infer.autoguide.guides import prototype_hide_fn from pyro.infer.autoguide.initialization import InitMessenger +from pyro.infer.autoguide.guides import prototype_hide_fn from pyro.nn.module import PyroModule, PyroParam diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py index 0398ad137d..0d89fee5fc 100644 --- a/pyro/contrib/examples/bart.py +++ b/pyro/contrib/examples/bart.py @@ -14,7 +14,7 @@ import torch -from pyro.contrib.examples.util import _mkdir_p, get_data_directory +from pyro.contrib.examples.util import get_data_directory, _mkdir_p DATA = get_data_directory(__file__) diff --git a/pyro/contrib/examples/finance.py b/pyro/contrib/examples/finance.py index c40a0b55e8..03572c0289 100644 --- a/pyro/contrib/examples/finance.py +++ b/pyro/contrib/examples/finance.py @@ -6,7 +6,7 @@ import pandas as pd -from pyro.contrib.examples.util import _mkdir_p, get_data_directory +from pyro.contrib.examples.util import get_data_directory, _mkdir_p DATA = get_data_directory(__file__) diff --git a/pyro/contrib/examples/polyphonic_data_loader.py b/pyro/contrib/examples/polyphonic_data_loader.py index 491ae0517f..132c6d953d 100644 --- a/pyro/contrib/examples/polyphonic_data_loader.py +++ b/pyro/contrib/examples/polyphonic_data_loader.py @@ -17,9 +17,9 @@ """ import os -import pickle from collections import namedtuple from urllib.request import urlopen +import pickle import torch import torch.nn as nn @@ -27,6 +27,7 @@ from pyro.contrib.examples.util import get_data_directory + dset = namedtuple("dset", ["name", "url", "filename"]) JSB_CHORALES = dset("jsb_chorales", diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py index dcb9355e5e..30a23d5ca2 100644 --- a/pyro/contrib/funsor/__init__.py +++ b/pyro/contrib/funsor/__init__.py @@ -3,12 +3,14 @@ import pyroapi -from pyro.contrib.funsor.handlers import condition, do, markov # noqa: F401 -from pyro.contrib.funsor.handlers import plate as _plate -from pyro.contrib.funsor.handlers import vectorized_markov # noqa: F401 +from pyro.primitives import ( # noqa: F401 + clear_param_store, deterministic, enable_validation, factor, get_param_store, + module, param, random_module, sample, set_rng_seed, subsample, +) + from pyro.contrib.funsor.handlers.primitives import to_data, 
to_funsor # noqa: F401 -from pyro.primitives import (clear_param_store, deterministic, enable_validation, factor, get_param_store, # noqa: F401 - module, param, random_module, sample, set_rng_seed, subsample) +from pyro.contrib.funsor.handlers import condition, do, markov, vectorized_markov # noqa: F401 +from pyro.contrib.funsor.handlers import plate as _plate def plate(*args, **kwargs): diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py index 724ec29c83..a98be1de94 100644 --- a/pyro/contrib/funsor/handlers/__init__.py +++ b/pyro/contrib/funsor/handlers/__init__.py @@ -1,16 +1,20 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.poutine import (block, condition, do, escape, infer_config, mask, reparam, scale, seed, # noqa: F401 - uncondition) from pyro.poutine.handlers import _make_handler +from pyro.poutine import ( # noqa: F401 + block, condition, do, escape, infer_config, + mask, reparam, scale, seed, uncondition, +) + from .enum_messenger import EnumMessenger, queue # noqa: F401 from .named_messenger import MarkovMessenger, NamedMessenger from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger from .replay_messenger import ReplayMessenger from .trace_messenger import TraceMessenger + _msngrs = [ EnumMessenger, MarkovMessenger, diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index befbeb2014..15caf49078 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -9,17 +9,18 @@ import math from collections import OrderedDict -import funsor import torch +import funsor import pyro.poutine.runtime import pyro.poutine.util -from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger +from pyro.poutine.escape_messenger import EscapeMessenger +from pyro.poutine.subsample_messenger import _Subsample + from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor +from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger -from pyro.poutine.escape_messenger import EscapeMessenger -from pyro.poutine.subsample_messenger import _Subsample funsor.set_backend("torch") diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py index e7cf18ffbc..fb7667fab3 100644 --- a/pyro/contrib/funsor/handlers/named_messenger.py +++ b/pyro/contrib/funsor/handlers/named_messenger.py @@ -4,9 +4,10 @@ from collections import OrderedDict from contextlib import ExitStack -from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame from pyro.poutine.reentrant_messenger import ReentrantMessenger +from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType, StackFrame + class NamedMessenger(ReentrantMessenger): """ diff --git a/pyro/contrib/funsor/handlers/primitives.py b/pyro/contrib/funsor/handlers/primitives.py index 0b7a4c4edb..3d8815eff0 100644 --- a/pyro/contrib/funsor/handlers/primitives.py +++ b/pyro/contrib/funsor/handlers/primitives.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.poutine.runtime + from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimRequest, DimType diff --git a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py 
index ae672d2dd4..2389941049 100644 --- a/pyro/contrib/funsor/handlers/replay_messenger.py +++ b/pyro/contrib/funsor/handlers/replay_messenger.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from pyro.contrib.funsor.handlers.primitives import to_data from pyro.poutine.replay_messenger import ReplayMessenger as OrigReplayMessenger +from pyro.contrib.funsor.handlers.primitives import to_data class ReplayMessenger(OrigReplayMessenger): diff --git a/pyro/contrib/funsor/infer/__init__.py b/pyro/contrib/funsor/infer/__init__.py index 55f260e6a9..4525e2cef5 100644 --- a/pyro/contrib/funsor/infer/__init__.py +++ b/pyro/contrib/funsor/infer/__init__.py @@ -5,5 +5,5 @@ from .elbo import ELBO # noqa: F401 from .trace_elbo import JitTrace_ELBO, Trace_ELBO # noqa: F401 -from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO # noqa: F401 from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO # noqa: F401 +from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO # noqa: F401 diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 686926b772..1912edca11 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -5,13 +5,14 @@ import funsor +from pyro.distributions.util import copy_docs_from +from pyro.infer import Trace_ELBO as _OrigTrace_ELBO + from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace from pyro.contrib.funsor.infer import config_enumerate -from pyro.distributions.util import copy_docs_from -from pyro.infer import Trace_ELBO as _OrigTrace_ELBO -from .elbo import ELBO, Jit_ELBO +from .elbo import Jit_ELBO, ELBO from .traceenum_elbo import terms_from_trace diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index a63a8cabc6..14386c69c0 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -5,11 +5,12 @@ import funsor +from pyro.distributions.util import copy_docs_from +from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO + from pyro.contrib.funsor import to_data, to_funsor from pyro.contrib.funsor.handlers import enum, plate, replay, trace from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO -from pyro.distributions.util import copy_docs_from -from pyro.infer import TraceEnum_ELBO as _OrigTraceEnum_ELBO def terms_from_trace(tr): diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py index aae66ce5c0..8a714d8c03 100644 --- a/pyro/contrib/funsor/infer/tracetmc_elbo.py +++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py @@ -5,12 +5,14 @@ import funsor +from pyro.distributions.util import copy_docs_from +from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO + from pyro.contrib.funsor import to_data from pyro.contrib.funsor.handlers import enum, plate, replay, trace + from pyro.contrib.funsor.infer.elbo import ELBO, Jit_ELBO from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace -from pyro.distributions.util import copy_docs_from -from pyro.infer import TraceTMC_ELBO as _OrigTraceTMC_ELBO @copy_docs_from(_OrigTraceTMC_ELBO) diff --git a/pyro/contrib/gp/kernels/__init__.py b/pyro/contrib/gp/kernels/__init__.py index c36ddd37fb..9874e73c17 100644 --- a/pyro/contrib/gp/kernels/__init__.py +++ b/pyro/contrib/gp/kernels/__init__.py @@ -4,9 +4,10 @@ from 
pyro.contrib.gp.kernels.brownian import Brownian from pyro.contrib.gp.kernels.coregionalize import Coregionalize from pyro.contrib.gp.kernels.dot_product import DotProduct, Linear, Polynomial -from pyro.contrib.gp.kernels.isotropic import RBF, Exponential, Isotropy, Matern32, Matern52, RationalQuadratic -from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, Transforming, VerticalScaling, - Warping) +from pyro.contrib.gp.kernels.isotropic import (RBF, Exponential, Isotropy, Matern32, Matern52, + RationalQuadratic) +from pyro.contrib.gp.kernels.kernel import (Combination, Exponent, Kernel, Product, Sum, + Transforming, VerticalScaling, Warping) from pyro.contrib.gp.kernels.periodic import Cosine, Periodic from pyro.contrib.gp.kernels.static import Constant, WhiteNoise diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py index ef417f9e22..3041f5e92e 100644 --- a/pyro/contrib/gp/likelihoods/binary.py +++ b/pyro/contrib/gp/likelihoods/binary.py @@ -5,6 +5,7 @@ import pyro import pyro.distributions as dist + from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/gp/likelihoods/gaussian.py b/pyro/contrib/gp/likelihoods/gaussian.py index b1b65ff95c..cb5a15d8c7 100644 --- a/pyro/contrib/gp/likelihoods/gaussian.py +++ b/pyro/contrib/gp/likelihoods/gaussian.py @@ -6,6 +6,7 @@ import pyro import pyro.distributions as dist + from pyro.contrib.gp.likelihoods.likelihood import Likelihood from pyro.nn.module import PyroParam diff --git a/pyro/contrib/gp/likelihoods/multi_class.py b/pyro/contrib/gp/likelihoods/multi_class.py index ed8463f8bd..9ff69e81f1 100644 --- a/pyro/contrib/gp/likelihoods/multi_class.py +++ b/pyro/contrib/gp/likelihoods/multi_class.py @@ -5,6 +5,7 @@ import pyro import pyro.distributions as dist + from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/gp/likelihoods/poisson.py b/pyro/contrib/gp/likelihoods/poisson.py index 8abed6fd2a..48916e0634 100644 --- a/pyro/contrib/gp/likelihoods/poisson.py +++ b/pyro/contrib/gp/likelihoods/poisson.py @@ -5,6 +5,7 @@ import pyro import pyro.distributions as dist + from pyro.contrib.gp.likelihoods.likelihood import Likelihood diff --git a/pyro/contrib/oed/__init__.py b/pyro/contrib/oed/__init__.py index 3afd3a440d..006c57a7a1 100644 --- a/pyro/contrib/oed/__init__.py +++ b/pyro/contrib/oed/__init__.py @@ -67,7 +67,7 @@ def model(design): """ -from pyro.contrib.oed import eig, search +from pyro.contrib.oed import search, eig __all__ = [ "search", diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py index 8d6c7ae22b..7faec28aa6 100644 --- a/pyro/contrib/oed/eig.py +++ b/pyro/contrib/oed/eig.py @@ -1,18 +1,17 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +import torch import math import warnings -import torch - import pyro from pyro import poutine +from pyro.infer.autoguide.utils import mean_field_entropy from pyro.contrib.oed.search import Search +from pyro.infer import EmpiricalMarginal, Importance, SVI +from pyro.util import torch_isnan, torch_isinf from pyro.contrib.util import lexpand -from pyro.infer import SVI, EmpiricalMarginal, Importance -from pyro.infer.autoguide.utils import mean_field_entropy -from pyro.util import torch_isinf, torch_isnan __all__ = [ "laplace_eig", diff --git a/pyro/contrib/oed/glmm/__init__.py b/pyro/contrib/oed/glmm/__init__.py index c17c221213..f9d75643ca 100644 --- a/pyro/contrib/oed/glmm/__init__.py +++ b/pyro/contrib/oed/glmm/__init__.py @@ -36,5 +36,5 @@ For random effects with a shared covariance matrix, see :meth:`pyro.contrib.oed.glmm.lmer_model`. """ -from pyro.contrib.oed.glmm import guides # noqa: F401 from pyro.contrib.oed.glmm.glmm import * # noqa: F403,F401 +from pyro.contrib.oed.glmm import guides # noqa: F401 diff --git a/pyro/contrib/oed/glmm/glmm.py b/pyro/contrib/oed/glmm/glmm.py index 68507be53f..2c418391e3 100644 --- a/pyro/contrib/oed/glmm/glmm.py +++ b/pyro/contrib/oed/glmm/glmm.py @@ -3,17 +3,17 @@ import warnings from collections import OrderedDict -from contextlib import ExitStack from functools import partial +from contextlib import ExitStack import torch +from torch.nn.functional import softplus from torch.distributions import constraints from torch.distributions.transforms import AffineTransform, SigmoidTransform -from torch.nn.functional import softplus import pyro import pyro.distributions as dist -from pyro.contrib.util import iter_plates_to_shape, rmv +from pyro.contrib.util import rmv, iter_plates_to_shape # TODO read from torch float spec epsilon = torch.tensor(2**-24) diff --git a/pyro/contrib/oed/glmm/guides.py b/pyro/contrib/oed/glmm/guides.py index d2425adff2..c71b06415c 100644 --- a/pyro/contrib/oed/glmm/guides.py +++ b/pyro/contrib/oed/glmm/guides.py @@ -1,15 +1,17 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -from contextlib import ExitStack - import torch from torch import nn +from contextlib import ExitStack + import pyro import pyro.distributions as dist from pyro import poutine -from pyro.contrib.util import iter_plates_to_shape, lexpand, rmv, rtril, rvv, tensor_to_dict +from pyro.contrib.util import ( + tensor_to_dict, rmv, rvv, rtril, lexpand, iter_plates_to_shape +) from pyro.ops.linalg import rinverse diff --git a/pyro/contrib/oed/search.py b/pyro/contrib/oed/search.py index 721f6305c3..4bf8eb1816 100644 --- a/pyro/contrib/oed/search.py +++ b/pyro/contrib/oed/search.py @@ -2,9 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import queue - -import pyro.poutine as poutine from pyro.infer.abstract_infer import TracePosterior +import pyro.poutine as poutine ################################### # Search borrowed from RSA example diff --git a/pyro/contrib/oed/util.py b/pyro/contrib/oed/util.py index 50774ff0bd..d5c315f85a 100644 --- a/pyro/contrib/oed/util.py +++ b/pyro/contrib/oed/util.py @@ -2,11 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import math - import torch -from pyro.contrib.oed.glmm import analytic_posterior_cov from pyro.contrib.util import get_indices +from pyro.contrib.oed.glmm import analytic_posterior_cov from pyro.infer.autoguide.utils import mean_field_entropy diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index 39c06748ae..1bc28f1caa 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -4,10 +4,17 @@ from typing import Union from torch import Tensor - from pyro.distributions import TransformedDistribution -from pyro.distributions.transforms import (AbsTransform, AffineTransform, ExpTransform, PowerTransform, - SigmoidTransform, SoftmaxTransform, TanhTransform, Transform) +from pyro.distributions.transforms import ( + Transform, + AffineTransform, + AbsTransform, + PowerTransform, + ExpTransform, + TanhTransform, + SoftmaxTransform, + SigmoidTransform +) class RVMagicOps: diff --git a/pyro/contrib/timeseries/__init__.py b/pyro/contrib/timeseries/__init__.py index 517c3f9550..f119203f04 100644 --- a/pyro/contrib/timeseries/__init__.py +++ b/pyro/contrib/timeseries/__init__.py @@ -6,7 +6,7 @@ models useful for forecasting applications. 
""" from pyro.contrib.timeseries.base import TimeSeriesModel -from pyro.contrib.timeseries.gp import DependentMaternGP, IndependentMaternGP, LinearlyCoupledMaternGP +from pyro.contrib.timeseries.gp import IndependentMaternGP, LinearlyCoupledMaternGP, DependentMaternGP from pyro.contrib.timeseries.lgssm import GenericLGSSM from pyro.contrib.timeseries.lgssmgp import GenericLGSSMWithGPNoiseModel diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index fc3e61c6c1..641c764d19 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -5,9 +5,9 @@ from torch.distributions import constraints import pyro.distributions as dist +from pyro.distributions.torch_distribution import TorchDistribution from pyro.contrib.tracking.extended_kalman_filter import EKFState from pyro.contrib.tracking.measurements import PositionMeasurement -from pyro.distributions.torch_distribution import TorchDistribution class EKFDistribution(TorchDistribution): diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py index 7ea41ad3a7..bd26b344e2 100644 --- a/pyro/contrib/tracking/dynamic_models.py +++ b/pyro/contrib/tracking/dynamic_models.py @@ -6,7 +6,6 @@ import torch from torch import nn from torch.nn import Parameter - import pyro.distributions as dist from pyro.distributions.util import eye_like diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py index 8f24ea4360..cb98c49ccd 100644 --- a/pyro/contrib/tracking/measurements.py +++ b/pyro/contrib/tracking/measurements.py @@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod import torch - from pyro.distributions.util import eye_like diff --git a/pyro/contrib/util.py b/pyro/contrib/util.py index e250639ca7..44ff34832b 100644 --- a/pyro/contrib/util.py +++ b/pyro/contrib/util.py @@ -2,9 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict - import torch - import pyro diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 0a84d5b530..18b804bf58 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -4,19 +4,40 @@ import pyro.distributions.torch_patch # noqa F403 from pyro.distributions.affine_beta import AffineBeta from pyro.distributions.avf_mvn import AVFMultivariateNormal -from pyro.distributions.coalescent import CoalescentRateLikelihood, CoalescentTimes, CoalescentTimesWithRate -from pyro.distributions.conditional import (ConditionalDistribution, ConditionalTransform, - ConditionalTransformedDistribution, ConditionalTransformModule) -from pyro.distributions.conjugate import BetaBinomial, DirichletMultinomial, GammaPoisson +from pyro.distributions.coalescent import ( + CoalescentRateLikelihood, + CoalescentTimes, + CoalescentTimesWithRate, +) +from pyro.distributions.conditional import ( + ConditionalDistribution, + ConditionalTransform, + ConditionalTransformedDistribution, + ConditionalTransformModule, +) +from pyro.distributions.conjugate import ( + BetaBinomial, + DirichletMultinomial, + GammaPoisson, +) from pyro.distributions.delta import Delta from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals -from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance +from pyro.distributions.diag_normal_mixture_shared_cov import ( + MixtureOfDiagNormalsSharedCovariance, +) from pyro.distributions.distribution import Distribution from pyro.distributions.empirical import 
Empirical from pyro.distributions.extended import ExtendedBetaBinomial, ExtendedBinomial from pyro.distributions.folded import FoldedDistribution from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture -from pyro.distributions.hmm import DiscreteHMM, GammaGaussianHMM, GaussianHMM, GaussianMRF, IndependentHMM, LinearHMM +from pyro.distributions.hmm import ( + DiscreteHMM, + GammaGaussianHMM, + GaussianHMM, + GaussianMRF, + IndependentHMM, + LinearHMM, +) from pyro.distributions.improper_uniform import ImproperUniform from pyro.distributions.inverse_gamma import InverseGamma from pyro.distributions.lkj import LKJ, LKJCorrCholesky @@ -29,8 +50,10 @@ from pyro.distributions.polya_gamma import TruncatedPolyaGamma from pyro.distributions.projected_normal import ProjectedNormal from pyro.distributions.rejector import Rejector -from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, - RelaxedOneHotCategoricalStraightThrough) +from pyro.distributions.relaxed_straight_through import ( + RelaxedBernoulliStraightThrough, + RelaxedOneHotCategoricalStraightThrough, +) from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable from pyro.distributions.torch import * # noqa F403 @@ -38,9 +61,17 @@ from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution, TorchDistribution from pyro.distributions.torch_transform import ComposeTransformModule, TransformModule from pyro.distributions.unit import Unit -from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled +from pyro.distributions.util import ( + enable_validation, + is_validation_enabled, + validation_enabled, +) from pyro.distributions.von_mises_3d import VonMises3D -from pyro.distributions.zero_inflated import ZeroInflatedDistribution, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson +from pyro.distributions.zero_inflated import ( + ZeroInflatedDistribution, + ZeroInflatedNegativeBinomial, + ZeroInflatedPoisson, +) from . import constraints, kl, transforms diff --git a/pyro/distributions/ordered_logistic.py b/pyro/distributions/ordered_logistic.py index c8d4ef459d..d6d288fafb 100644 --- a/pyro/distributions/ordered_logistic.py +++ b/pyro/distributions/ordered_logistic.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch - from pyro.distributions import constraints from pyro.distributions.torch import Categorical diff --git a/pyro/distributions/projected_normal.py b/pyro/distributions/projected_normal.py index 1ac3ed28f6..31e7aa909c 100644 --- a/pyro/distributions/projected_normal.py +++ b/pyro/distributions/projected_normal.py @@ -5,10 +5,9 @@ import torch -from pyro.ops.tensor_utils import safe_normalize - from . 
import constraints from .torch_distribution import TorchDistribution +from pyro.ops.tensor_utils import safe_normalize class ProjectedNormal(TorchDistribution): diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py index a3b4fb40f4..7add44583c 100644 --- a/pyro/distributions/spanning_tree.py +++ b/pyro/distributions/spanning_tree.py @@ -218,7 +218,6 @@ def _get_cpp_module(): global _cpp_module if _cpp_module is None: import os - from torch.utils.cpp_extension import load path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp") _cpp_module = load(name="cpp_spanning_tree", diff --git a/pyro/distributions/testing/gof.py b/pyro/distributions/testing/gof.py index 4d544b923c..7178874a6f 100644 --- a/pyro/distributions/testing/gof.py +++ b/pyro/distributions/testing/gof.py @@ -55,8 +55,8 @@ def test_my_distribution(): `goftests `_ library. """ -import math import warnings +import math import torch diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 0632723a26..0b1b2c4ec5 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -3,11 +3,11 @@ from torch.distributions import biject_to, transform_to from torch.distributions.transforms import * # noqa F403 -from torch.distributions.transforms import ComposeTransform, ExpTransform, LowerCholeskyTransform from torch.distributions.transforms import __all__ as torch_transforms +from torch.distributions.transforms import ComposeTransform, ExpTransform, LowerCholeskyTransform -from ..constraints import (IndependentConstraint, corr_cholesky_constraint, corr_matrix, ordered_vector, - positive_definite, positive_ordered_vector, sphere) +from ..constraints import (IndependentConstraint, corr_cholesky_constraint, corr_matrix, + ordered_vector, positive_definite, positive_ordered_vector, sphere) from ..torch_transform import ComposeTransformModule from .affine_autoregressive import (AffineAutoregressive, ConditionalAffineAutoregressive, affine_autoregressive, conditional_affine_autoregressive) @@ -26,12 +26,12 @@ matrix_exponential) from .neural_autoregressive import (ConditionalNeuralAutoregressive, NeuralAutoregressive, conditional_neural_autoregressive, neural_autoregressive) -from .normalize import Normalize from .ordered import OrderedTransform from .permute import Permute, permute from .planar import ConditionalPlanar, Planar, conditional_planar, planar from .polynomial import Polynomial, polynomial from .radial import ConditionalRadial, Radial, conditional_radial, radial +from .normalize import Normalize from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import (ConditionalSplineAutoregressive, SplineAutoregressive, conditional_spline_autoregressive, spline_autoregressive) diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py index 95497e45fa..79aea6a261 100644 --- a/pyro/distributions/transforms/ordered.py +++ b/pyro/distributions/transforms/ordered.py @@ -2,9 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch - -from pyro.distributions import constraints from pyro.distributions.transforms import Transform +from pyro.distributions import constraints class OrderedTransform(Transform): diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 5485b7476e..421e5a16fa 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -17,13 +17,13 @@ from pyro.infer.smcfilter import 
SMCFilter from pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel from pyro.infer.svi import SVI +from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO from pyro.infer.trace_mmd import Trace_MMD from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO -from pyro.infer.tracetmc_elbo import TraceTMC_ELBO from pyro.infer.util import enable_validation, is_validation_enabled __all__ = [ diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py index 8891c9cc64..0424949e7d 100644 --- a/pyro/infer/abstract_infer.py +++ b/pyro/infer/abstract_infer.py @@ -10,8 +10,8 @@ import pyro.poutine as poutine from pyro.distributions import Categorical, Empirical -from pyro.ops.stats import waic from pyro.poutine.util import site_is_subsample +from pyro.ops.stats import waic class EmpiricalMarginal(Empirical): diff --git a/pyro/infer/mcmc/__init__.py b/pyro/infer/mcmc/__init__.py index 99d241b162..e33cbec518 100644 --- a/pyro/infer/mcmc/__init__.py +++ b/pyro/infer/mcmc/__init__.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.hmc import HMC +from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.nuts import NUTS __all__ = [ diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py index c8d41924f6..46497a53b5 100644 --- a/pyro/infer/mcmc/adaptation.py +++ b/pyro/infer/mcmc/adaptation.py @@ -9,7 +9,7 @@ import pyro from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul from pyro.ops.dual_averaging import DualAveraging -from pyro.ops.welford import WelfordArrowheadCovariance, WelfordCovariance +from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance adapt_window = namedtuple("adapt_window", ["start", "end"]) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 139d7f4a0d..857004212f 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -8,8 +8,9 @@ import pyro import pyro.distributions as dist -from pyro.distributions.testing.fakes import NonreparameterizedNormal from pyro.distributions.util import scalar_like +from pyro.distributions.testing.fakes import NonreparameterizedNormal + from pyro.infer.autoguide import init_to_uniform from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py index 89831a93f8..372104c53f 100644 --- a/pyro/infer/mcmc/util.py +++ b/pyro/infer/mcmc/util.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import functools -import traceback as tb import warnings from collections import OrderedDict, defaultdict from functools import partial, reduce from itertools import product +import traceback as tb import torch -from opt_einsum import shared_intermediates from torch.distributions import biject_to +from opt_einsum import shared_intermediates import pyro import pyro.poutine as poutine diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index dbeee300a3..82617b2a98 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, 
Inc. # SPDX-License-Identifier: Apache-2.0 -import warnings from functools import reduce +import warnings import torch diff --git a/pyro/infer/reparam/neutra.py b/pyro/infer/reparam/neutra.py index 89f39d062d..6e444344e0 100644 --- a/pyro/infer/reparam/neutra.py +++ b/pyro/infer/reparam/neutra.py @@ -8,7 +8,6 @@ from pyro import poutine from pyro.distributions.util import sum_rightmost from pyro.infer.autoguide.guides import AutoContinuous - from .reparam import Reparam diff --git a/pyro/infer/svgd.py b/pyro/infer/svgd.py index 9e722a745f..ebc526a86d 100644 --- a/pyro/infer/svgd.py +++ b/pyro/infer/svgd.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import math from abc import ABCMeta, abstractmethod +import math import torch from torch.distributions import biject_to @@ -10,10 +10,10 @@ import pyro from pyro import poutine from pyro.distributions import Delta -from pyro.distributions.util import copy_docs_from +from pyro.infer.trace_elbo import Trace_ELBO from pyro.infer.autoguide.guides import AutoContinuous from pyro.infer.autoguide.initialization import init_to_sample -from pyro.infer.trace_elbo import Trace_ELBO +from pyro.distributions.util import copy_docs_from def vectorize(fn, num_particles, max_plate_nesting): diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py index 6eab28596d..d04210ee56 100644 --- a/pyro/infer/trace_mean_field_elbo.py +++ b/pyro/infer/trace_mean_field_elbo.py @@ -10,7 +10,7 @@ import pyro.ops.jit from pyro.distributions.util import scale_and_mask from pyro.infer.trace_elbo import Trace_ELBO -from pyro.infer.util import check_fully_reparametrized, is_validation_enabled, torch_item +from pyro.infer.util import is_validation_enabled, torch_item, check_fully_reparametrized from pyro.util import warn_if_nan diff --git a/pyro/infer/trace_mmd.py b/pyro/infer/trace_mmd.py index 661ff727c2..1cc71992b9 100644 --- a/pyro/infer/trace_mmd.py +++ b/pyro/infer/trace_mmd.py @@ -9,8 +9,8 @@ import pyro.ops.jit from pyro import poutine from pyro.infer.elbo import ELBO +from pyro.infer.util import torch_item, is_validation_enabled from pyro.infer.enum import get_importance_trace -from pyro.infer.util import is_validation_enabled, torch_item from pyro.util import check_if_enumerated, warn_if_nan diff --git a/pyro/infer/trace_tail_adaptive_elbo.py b/pyro/infer/trace_tail_adaptive_elbo.py index b05251a300..a69ea6d191 100644 --- a/pyro/infer/trace_tail_adaptive_elbo.py +++ b/pyro/infer/trace_tail_adaptive_elbo.py @@ -6,7 +6,7 @@ import torch from pyro.infer.trace_elbo import Trace_ELBO -from pyro.infer.util import check_fully_reparametrized, is_validation_enabled +from pyro.infer.util import is_validation_enabled, check_fully_reparametrized class TraceTailAdaptive_ELBO(Trace_ELBO): diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 504060963f..e2e483eec7 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -1,10 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import queue import warnings import weakref from collections import OrderedDict +import queue import torch from opt_einsum import shared_intermediates diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index 8e7bc7ed6f..852ea6e658 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -11,7 +11,8 @@ from pyro.distributions.util import detach, is_identically_zero from pyro.infer import ELBO from pyro.infer.enum import get_importance_trace -from pyro.infer.util import MultiFrameTensor, get_plate_stacks, is_validation_enabled, torch_backward, torch_item +from pyro.infer.util import (MultiFrameTensor, get_plate_stacks, + is_validation_enabled, torch_backward, torch_item) from pyro.util import check_if_enumerated, warn_if_nan diff --git a/pyro/infer/tracetmc_elbo.py b/pyro/infer/tracetmc_elbo.py index f78b277080..51c3ba3b78 100644 --- a/pyro/infer/tracetmc_elbo.py +++ b/pyro/infer/tracetmc_elbo.py @@ -7,6 +7,7 @@ import torch import pyro.poutine as poutine + from pyro.distributions.util import is_identically_zero from pyro.infer.elbo import ELBO from pyro.infer.enum import get_importance_trace, iter_discrete_escape, iter_discrete_extend diff --git a/pyro/logger.py b/pyro/logger.py index 64a8c70c46..5bee771e4a 100644 --- a/pyro/logger.py +++ b/pyro/logger.py @@ -3,6 +3,7 @@ import logging + default_format = '%(levelname)s \t %(message)s' log = logging.getLogger("pyro") log.setLevel(logging.INFO) diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py index 4714d3095e..9641ab4ec2 100644 --- a/pyro/ops/arrowhead.py +++ b/pyro/ops/arrowhead.py @@ -5,6 +5,7 @@ import torch + SymmArrowhead = namedtuple("SymmArrowhead", ["top", "bottom_diag"]) TriuArrowhead = namedtuple("TriuArrowhead", ["top", "bottom_diag"]) diff --git a/pyro/ops/einsum/torch_map.py b/pyro/ops/einsum/torch_map.py index e4293c1140..6e2832bcff 100644 --- a/pyro/ops/einsum/torch_map.py +++ b/pyro/ops/einsum/torch_map.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import operator + from functools import reduce from pyro.ops import packed diff --git a/pyro/ops/einsum/torch_sample.py b/pyro/ops/einsum/torch_sample.py index 5420c328ba..06c8108886 100644 --- a/pyro/ops/einsum/torch_sample.py +++ b/pyro/ops/einsum/torch_sample.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import operator + from functools import reduce import pyro.distributions as dist diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py index e651b4284e..6611d3a93c 100644 --- a/pyro/ops/newton.py +++ b/pyro/ops/newton.py @@ -4,8 +4,8 @@ import torch from torch.autograd import grad -from pyro.ops.linalg import eig_3d, rinverse from pyro.util import warn_if_nan +from pyro.ops.linalg import rinverse, eig_3d def newton_step(loss, x, trust_radius=None): diff --git a/pyro/ops/ssm_gp.py b/pyro/ops/ssm_gp.py index 89abcb2912..eb88ba9d70 100644 --- a/pyro/ops/ssm_gp.py +++ b/pyro/ops/ssm_gp.py @@ -6,7 +6,7 @@ import torch from torch.distributions import constraints -from pyro.nn import PyroModule, PyroParam, pyro_method +from pyro.nn import PyroModule, pyro_method, PyroParam root_three = math.sqrt(3.0) root_five = math.sqrt(5.0) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 2ff2a57ae9..20a5b1d5e1 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from pyro.util import ignore_jit_warnings - from .messenger import Messenger diff --git 
a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 0b65987932..9ed1575857 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -7,7 +7,6 @@ import torch from pyro.util import ignore_jit_warnings - from .messenger import Messenger from .runtime import _DIM_ALLOCATOR diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 176c7a772f..39cd234bb9 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -1,8 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import sys from collections import OrderedDict +import sys import opt_einsum diff --git a/tests/__init__.py b/tests/__init__.py index 4056718ce7..200bfc2d65 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging + import os # create log handler for tests diff --git a/tests/conftest.py b/tests/conftest.py index 699cca55c2..2cfcba39d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import pyro + torch.set_default_tensor_type(os.environ.get('PYRO_TENSOR_TYPE', 'torch.DoubleTensor')) diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py index 767b763ff4..9e9002d507 100644 --- a/tests/contrib/autoguide/test_inference.py +++ b/tests/contrib/autoguide/test_inference.py @@ -12,10 +12,10 @@ import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.distributions.transforms import block_autoregressive, iterated -from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO +from pyro.distributions.transforms import iterated, block_autoregressive from pyro.infer.autoguide import (AutoDiagonalNormal, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal) +from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO from pyro.infer.autoguide.guides import AutoNormalizingFlow from tests.common import assert_equal from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/contrib/autoguide/test_mean_field_entropy.py index 2f5cd163db..9f8c301b32 100644 --- a/tests/contrib/autoguide/test_mean_field_entropy.py +++ b/tests/contrib/autoguide/test_mean_field_entropy.py @@ -1,9 +1,9 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import pytest -import scipy.special as sc import torch +import scipy.special as sc +import pytest import pyro import pyro.distributions as dist diff --git a/tests/contrib/autoname/test_scoping.py b/tests/contrib/autoname/test_scoping.py index aa7e44bae6..d10d6f2d7f 100644 --- a/tests/contrib/autoname/test_scoping.py +++ b/tests/contrib/autoname/test_scoping.py @@ -8,7 +8,7 @@ import pyro import pyro.distributions.torch as dist import pyro.poutine as poutine -from pyro.contrib.autoname import name_count, scope +from pyro.contrib.autoname import scope, name_count logger = logging.getLogger(__name__) diff --git a/tests/contrib/bnn/test_hidden_layer.py b/tests/contrib/bnn/test_hidden_layer.py index c688572d0f..1067cc03f2 100644 --- a/tests/contrib/bnn/test_hidden_layer.py +++ b/tests/contrib/bnn/test_hidden_layer.py @@ -1,11 +1,11 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import pytest import torch import torch.nn.functional as F from torch.distributions import Normal +import pytest from pyro.contrib.bnn import HiddenLayer from tests.common import assert_equal diff --git a/tests/contrib/epidemiology/test_quant.py b/tests/contrib/epidemiology/test_quant.py index d2a0edbc1d..f9dc53bb64 100644 --- a/tests/contrib/epidemiology/test_quant.py +++ b/tests/contrib/epidemiology/test_quant.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest + import torch from pyro.contrib.epidemiology.util import compute_bin_probs diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index 9b273e5e2e..75fff55463 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -7,6 +7,7 @@ import pyroapi import pytest import torch + from torch.autograd import grad from torch.distributions import constraints @@ -16,10 +17,9 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro - - import pyro.contrib.funsor funsor.set_backend("torch") except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_named_handlers.py b/tests/contrib/funsor/test_named_handlers.py index c4c57b7bd5..48c464daa3 100644 --- a/tests/contrib/funsor/test_named_handlers.py +++ b/tests/contrib/funsor/test_named_handlers.py @@ -1,8 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import logging from collections import OrderedDict +import logging import pytest import torch @@ -11,7 +11,6 @@ try: import funsor from funsor.tensor import Tensor - import pyro.contrib.funsor from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger funsor.set_backend("torch") diff --git a/tests/contrib/funsor/test_pyroapi_funsor.py b/tests/contrib/funsor/test_pyroapi_funsor.py index 9e050462e9..74dbf972e3 100644 --- a/tests/contrib/funsor/test_pyroapi_funsor.py +++ b/tests/contrib/funsor/test_pyroapi_funsor.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import pytest + from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 try: # triggers backend registration import funsor - import pyro.contrib.funsor # noqa: F401 funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py index 54d4eedaae..cc1ab52178 100644 --- a/tests/contrib/funsor/test_tmc.py +++ b/tests/contrib/funsor/test_tmc.py @@ -14,10 +14,9 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro, pyro_backend - - import pyro.contrib.funsor funsor.set_backend("torch") except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_valid_models_enum.py b/tests/contrib/funsor/test_valid_models_enum.py index 7df1b23a90..3ef3241ad2 100644 --- a/tests/contrib/funsor/test_valid_models_enum.py +++ b/tests/contrib/funsor/test_valid_models_enum.py @@ -1,10 +1,10 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict import contextlib import logging import os -from collections import defaultdict from queue import LifoQueue import pytest @@ -19,12 +19,11 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor - import pyro.contrib.funsor from pyro.contrib.funsor.handlers.runtime import _DIM_STACK funsor.set_backend("torch") from pyroapi import distributions as dist - from pyroapi import handlers, infer, pyro, pyro_backend + from pyroapi import infer, handlers, pyro, pyro_backend except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/funsor/test_valid_models_plate.py b/tests/contrib/funsor/test_valid_models_plate.py index f5d30fc1b7..ed20ee4be4 100644 --- a/tests/contrib/funsor/test_valid_models_plate.py +++ b/tests/contrib/funsor/test_valid_models_plate.py @@ -12,10 +12,9 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro - - import pyro.contrib.funsor from tests.contrib.funsor.test_valid_models_enum import assert_ok funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_valid_models_sequential_plate.py b/tests/contrib/funsor/test_valid_models_sequential_plate.py index 40eeb79cb6..1de6af5b08 100644 --- a/tests/contrib/funsor/test_valid_models_sequential_plate.py +++ b/tests/contrib/funsor/test_valid_models_sequential_plate.py @@ -11,10 +11,9 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + import pyro.contrib.funsor from pyroapi import distributions as dist from pyroapi import infer, pyro - - import pyro.contrib.funsor from tests.contrib.funsor.test_valid_models_enum import assert_ok funsor.set_backend("torch") except ImportError: diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 99e548a3bd..5048d9a1c4 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -3,6 +3,7 @@ import pytest import torch + from pyroapi import pyro_backend from torch.distributions import constraints @@ -12,12 +13,10 @@ try: import funsor from funsor.testing import assert_close - from pyroapi import distributions as dist - import pyro.contrib.funsor + from pyroapi import distributions as dist funsor.set_backend("torch") - from pyroapi import handlers, infer, pyro - + from pyroapi import handlers, pyro, infer from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") diff --git a/tests/contrib/gp/test_kernels.py b/tests/contrib/gp/test_kernels.py index db1c803786..cc9797ff2c 100644 --- a/tests/contrib/gp/test_kernels.py +++ b/tests/contrib/gp/test_kernels.py @@ -6,8 +6,9 @@ import pytest import torch -from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, Exponential, Linear, - Matern32, Matern52, Periodic, Polynomial, Product, RationalQuadratic, Sum, +from pyro.contrib.gp.kernels import (RBF, Brownian, Constant, Coregionalize, Cosine, Exponent, + Exponential, Linear, Matern32, Matern52, Periodic, + Polynomial, Product, RationalQuadratic, Sum, VerticalScaling, Warping, WhiteNoise) from tests.common import assert_equal diff --git a/tests/contrib/gp/test_likelihoods.py 
b/tests/contrib/gp/test_likelihoods.py index c63c1ce9a5..71fbe663ad 100644 --- a/tests/contrib/gp/test_likelihoods.py +++ b/tests/contrib/gp/test_likelihoods.py @@ -11,6 +11,7 @@ from pyro.contrib.gp.models import VariationalGP, VariationalSparseGP from pyro.contrib.gp.util import train + T = namedtuple("TestGPLikelihood", ["model_class", "X", "y", "kernel", "likelihood"]) X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0], [3.0, 4.0, 6.0]]) diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index 711089025d..d5afa24ec2 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -8,12 +8,13 @@ import torch import pyro.distributions as dist -from pyro.contrib.gp.kernels import RBF, Cosine, Matern32, WhiteNoise +from pyro.contrib.gp.kernels import Cosine, Matern32, RBF, WhiteNoise from pyro.contrib.gp.likelihoods import Gaussian -from pyro.contrib.gp.models import GPLVM, GPRegression, SparseGPRegression, VariationalGP, VariationalSparseGP +from pyro.contrib.gp.models import (GPLVM, GPRegression, SparseGPRegression, + VariationalGP, VariationalSparseGP) from pyro.contrib.gp.util import train -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.hmc import HMC +from pyro.infer.mcmc.api import MCMC from pyro.nn.module import PyroSample from tests.common import assert_equal diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py index aa8df92188..57e8e4e7da 100644 --- a/tests/contrib/oed/test_ewma.py +++ b/tests/contrib/oed/test_ewma.py @@ -3,9 +3,9 @@ import math -import pytest import torch +import pytest from pyro.contrib.oed.eig import EwmaLog from tests.common import assert_equal diff --git a/tests/contrib/oed/test_finite_spaces_eig.py b/tests/contrib/oed/test_finite_spaces_eig.py index b6f69234d4..49fc02493a 100644 --- a/tests/contrib/oed/test_finite_spaces_eig.py +++ b/tests/contrib/oed/test_finite_spaces_eig.py @@ -1,15 +1,17 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import pytest import torch +import pytest import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.contrib.oed.eig import (donsker_varadhan_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, nmc_eig, - posterior_eig, vnmc_eig) +from pyro.contrib.oed.eig import ( + nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, lfire_eig, + donsker_varadhan_eig) from pyro.contrib.util import iter_plates_to_shape + from tests.common import assert_equal try: diff --git a/tests/contrib/oed/test_glmm.py b/tests/contrib/oed/test_glmm.py index cb3e95d169..6e855525dd 100644 --- a/tests/contrib/oed/test_glmm.py +++ b/tests/contrib/oed/test_glmm.py @@ -8,8 +8,10 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.contrib.oed.glmm import (group_linear_model, known_covariance_linear_model, logistic_regression_model, - normal_inverse_gamma_linear_model, sigmoid_model, zero_mean_unit_obs_sd_lm) +from pyro.contrib.oed.glmm import ( + known_covariance_linear_model, group_linear_model, zero_mean_unit_obs_sd_lm, + normal_inverse_gamma_linear_model, logistic_regression_model, sigmoid_model +) from tests.common import assert_equal diff --git a/tests/contrib/oed/test_linear_models_eig.py b/tests/contrib/oed/test_linear_models_eig.py index f84ba916e5..30280cb602 100644 --- a/tests/contrib/oed/test_linear_models_eig.py +++ b/tests/contrib/oed/test_linear_models_eig.py @@ -1,19 +1,20 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import pytest import torch +import pytest import pyro import pyro.distributions as dist import pyro.optim as optim -from pyro.contrib.oed.eig import (donsker_varadhan_eig, laplace_eig, lfire_eig, marginal_eig, marginal_likelihood_eig, - nmc_eig, posterior_eig, vnmc_eig) +from pyro.infer import Trace_ELBO from pyro.contrib.oed.glmm import known_covariance_linear_model -from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide from pyro.contrib.oed.util import linear_model_ground_truth +from pyro.contrib.oed.eig import ( + nmc_eig, posterior_eig, marginal_eig, marginal_likelihood_eig, vnmc_eig, laplace_eig, lfire_eig, + donsker_varadhan_eig) from pyro.contrib.util import rmv, rvv -from pyro.infer import Trace_ELBO +from pyro.contrib.oed.glmm.guides import LinearModelLaplaceGuide from tests.common import assert_equal diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py index 4c392bc997..5a1a43c194 100644 --- a/tests/contrib/randomvariable/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -4,7 +4,6 @@ import math import torch.tensor as tt - from pyro.distributions import Uniform N_SAMPLES = 100 diff --git a/tests/contrib/test_util.py b/tests/contrib/test_util.py index 60a3115dad..442ca61bec 100644 --- a/tests/contrib/test_util.py +++ b/tests/contrib/test_util.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict - import pytest import torch -from pyro.contrib.util import get_indices, lexpand, rdiag, rexpand, rmv, rtril, rvv, tensor_to_dict +from pyro.contrib.util import ( + get_indices, tensor_to_dict, rmv, rvv, lexpand, rexpand, rdiag, rtril +) from tests.common import assert_equal diff --git a/tests/contrib/timeseries/test_gp.py b/tests/contrib/timeseries/test_gp.py index e2e39a0aba..2698faa01b 100644 --- a/tests/contrib/timeseries/test_gp.py +++ b/tests/contrib/timeseries/test_gp.py @@ -2,15 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import math - -import pytest import torch +from tests.common import assert_equal import pyro -from pyro.contrib.timeseries import (DependentMaternGP, GenericLGSSM, GenericLGSSMWithGPNoiseModel, IndependentMaternGP, - LinearlyCoupledMaternGP) +from pyro.contrib.timeseries import (IndependentMaternGP, LinearlyCoupledMaternGP, GenericLGSSM, + GenericLGSSMWithGPNoiseModel, DependentMaternGP) from pyro.ops.tensor_utils import block_diag_embed -from tests.common import assert_equal +import pytest @pytest.mark.parametrize('model,obs_dim,nu_statedim', [('ssmgp', 3, 1.5), ('ssmgp', 2, 2.5), diff --git a/tests/contrib/timeseries/test_lgssm.py b/tests/contrib/timeseries/test_lgssm.py index 5b5ed9d339..f5c2dac137 100644 --- a/tests/contrib/timeseries/test_lgssm.py +++ b/tests/contrib/timeseries/test_lgssm.py @@ -1,11 +1,11 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 -import pytest import torch -from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel from tests.common import assert_equal +from pyro.contrib.timeseries import GenericLGSSM, GenericLGSSMWithGPNoiseModel +import pytest @pytest.mark.parametrize('model_class', ['lgssm', 'lgssmgp']) diff --git a/tests/contrib/tracking/test_assignment.py b/tests/contrib/tracking/test_assignment.py index 9c425dd502..554a373eb3 100644 --- a/tests/contrib/tracking/test_assignment.py +++ b/tests/contrib/tracking/test_assignment.py @@ -1,12 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import logging - import pytest import torch from torch.autograd import grad +import logging + import pyro import pyro.distributions as dist from pyro.contrib.tracking.assignment import MarginalAssignment, MarginalAssignmentPersistent, MarginalAssignmentSparse diff --git a/tests/contrib/tracking/test_distributions.py b/tests/contrib/tracking/test_distributions.py index fe4c149b49..4c589ac221 100644 --- a/tests/contrib/tracking/test_distributions.py +++ b/tests/contrib/tracking/test_distributions.py @@ -1,12 +1,13 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import pytest import torch from pyro.contrib.tracking.distributions import EKFDistribution from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous +import pytest + @pytest.mark.parametrize('Model', [NcpContinuous, NcvContinuous]) @pytest.mark.parametrize('dim', [2, 3]) diff --git a/tests/contrib/tracking/test_dynamic_models.py b/tests/contrib/tracking/test_dynamic_models.py index 4f93afe523..51df52e75d 100644 --- a/tests/contrib/tracking/test_dynamic_models.py +++ b/tests/contrib/tracking/test_dynamic_models.py @@ -3,7 +3,8 @@ import torch -from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcpDiscrete, NcvContinuous, NcvDiscrete +from pyro.contrib.tracking.dynamic_models import (NcpContinuous, NcvContinuous, + NcvDiscrete, NcpDiscrete) from tests.common import assert_equal, assert_not_equal diff --git a/tests/contrib/tracking/test_ekf.py b/tests/contrib/tracking/test_ekf.py index 35db1544d1..99cec4488c 100644 --- a/tests/contrib/tracking/test_ekf.py +++ b/tests/contrib/tracking/test_ekf.py @@ -3,9 +3,10 @@ import torch -from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous from pyro.contrib.tracking.extended_kalman_filter import EKFState +from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous from pyro.contrib.tracking.measurements import PositionMeasurement + from tests.common import assert_equal, assert_not_equal diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py index c3401f4114..1d0fca7147 100644 --- a/tests/contrib/tracking/test_em.py +++ b/tests/contrib/tracking/test_em.py @@ -16,6 +16,7 @@ from pyro.optim import Adam from pyro.optim.multi import MixedMultiOptimizer, Newton + logger = logging.getLogger(__name__) diff --git a/tests/contrib/tracking/test_measurements.py b/tests/contrib/tracking/test_measurements.py index 373cad0e79..38f2afcd3d 100644 --- a/tests/contrib/tracking/test_measurements.py +++ b/tests/contrib/tracking/test_measurements.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch - from pyro.contrib.tracking.measurements import PositionMeasurement diff --git a/tests/distributions/test_empirical.py b/tests/distributions/test_empirical.py index 3f2d4435dd..7d220aa95e 100644 
--- a/tests/distributions/test_empirical.py +++ b/tests/distributions/test_empirical.py @@ -5,7 +5,7 @@ import torch from pyro.distributions.empirical import Empirical -from tests.common import assert_close, assert_equal +from tests.common import assert_equal, assert_close @pytest.mark.parametrize("size", [[], [1], [2, 3]]) diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py index 03737ecf63..f02426696e 100644 --- a/tests/distributions/test_gaussian_mixtures.py +++ b/tests/distributions/test_gaussian_mixtures.py @@ -4,12 +4,14 @@ import logging import math -import pytest import torch -from pyro.distributions import GaussianScaleMixture, MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance +import pytest +from pyro.distributions import MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture +from pyro.distributions import MixtureOfDiagNormals from tests.common import assert_equal + logger = logging.getLogger(__name__) diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py index 53857c2791..63f3daea57 100644 --- a/tests/distributions/test_haar.py +++ b/tests/distributions/test_haar.py @@ -1,9 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import pytest import torch +import pytest from pyro.distributions.transforms import HaarTransform from tests.common import assert_equal diff --git a/tests/distributions/test_ig.py b/tests/distributions/test_ig.py index 5091e02ad7..215d00ed36 100644 --- a/tests/distributions/test_ig.py +++ b/tests/distributions/test_ig.py @@ -3,9 +3,9 @@ import math -import pytest import torch +import pytest from pyro.distributions import Gamma, InverseGamma from tests.common import assert_equal diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py index 27cfdc4910..e71336b2af 100644 --- a/tests/distributions/test_mask.py +++ b/tests/distributions/test_mask.py @@ -6,8 +6,9 @@ from torch import tensor from torch.distributions import kl_divergence +from pyro.distributions.util import broadcast_shape from pyro.distributions.torch import Bernoulli, Normal -from pyro.distributions.util import broadcast_shape, scale_and_mask +from pyro.distributions.util import scale_and_mask from tests.common import assert_equal diff --git a/tests/distributions/test_mvt.py b/tests/distributions/test_mvt.py index ab2dec09ad..a61cb1b3f8 100644 --- a/tests/distributions/test_mvt.py +++ b/tests/distributions/test_mvt.py @@ -4,6 +4,7 @@ import math import pytest + import torch from torch.distributions import Gamma, MultivariateNormal, StudentT diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py index eb04d455fb..f1d92bbb0b 100644 --- a/tests/distributions/test_omt_mvn.py +++ b/tests/distributions/test_omt_mvn.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import pytest import torch +import pytest from pyro.distributions import AVFMultivariateNormal, MultivariateNormal, OMTMultivariateNormal from tests.common import assert_equal diff --git a/tests/distributions/test_ordered_logistic.py b/tests/distributions/test_ordered_logistic.py index 715db994fb..6c6c3ae409 100644 --- a/tests/distributions/test_ordered_logistic.py +++ b/tests/distributions/test_ordered_logistic.py @@ -6,9 +6,10 @@ import torch.tensor as tt from torch.autograd.functional import jacobian -from pyro.distributions import Normal, OrderedLogistic +from pyro.distributions import OrderedLogistic, Normal from 
pyro.distributions.transforms import OrderedTransform + # Tests for the OrderedLogistic distribution diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py index 66f881bbc5..fabb71b451 100644 --- a/tests/distributions/test_pickle.py +++ b/tests/distributions/test_pickle.py @@ -3,9 +3,9 @@ import inspect import io -import pickle import pytest +import pickle import torch import pyro.distributions as dist diff --git a/tests/distributions/test_spanning_tree.py b/tests/distributions/test_spanning_tree.py index 5cdf85eae0..3336aee03b 100644 --- a/tests/distributions/test_spanning_tree.py +++ b/tests/distributions/test_spanning_tree.py @@ -5,10 +5,9 @@ import os from collections import Counter +import pyro import pytest import torch - -import pyro from pyro.distributions.spanning_tree import (NUM_SPANNING_TREES, SpanningTree, find_best_tree, make_complete_graph, sample_tree) from tests.common import assert_equal, xfail_if_not_implemented diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 6c31b2caeb..58f69b641b 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -1,8 +1,6 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import operator -from functools import partial, reduce from unittest import TestCase import pytest @@ -12,6 +10,9 @@ import pyro.distributions.transforms as T from tests.common import assert_close +from functools import partial, reduce +import operator + pytestmark = pytest.mark.init(rng_seed=123) diff --git a/tests/doctest_fixtures.py b/tests/doctest_fixtures.py index 0d4e785d84..8be64b2948 100644 --- a/tests/doctest_fixtures.py +++ b/tests/doctest_fixtures.py @@ -6,15 +6,16 @@ import torch import pyro -import pyro.contrib.autoname.named as named import pyro.contrib.gp as gp +import pyro.contrib.autoname.named as named import pyro.distributions as dist import pyro.poutine as poutine from pyro.infer import EmpiricalMarginal -from pyro.infer.mcmc import HMC, NUTS from pyro.infer.mcmc.api import MCMC +from pyro.infer.mcmc import HMC, NUTS from pyro.params import param_with_module_name + # Fix seed for all doctest runs. 
pyro.set_rng_seed(0) diff --git a/tests/infer/mcmc/test_adaptation.py b/tests/infer/mcmc/test_adaptation.py index 675e43525d..2fad237d90 100644 --- a/tests/infer/mcmc/test_adaptation.py +++ b/tests/infer/mcmc/test_adaptation.py @@ -4,7 +4,12 @@ import pytest import torch -from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix, WarmupAdapter, adapt_window +from pyro.infer.mcmc.adaptation import ( + ArrowheadMassMatrix, + BlockMassMatrix, + WarmupAdapter, + adapt_window, +) from tests.common import assert_close, assert_equal diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 2f2f9f967d..58bbf0a76e 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -11,9 +11,9 @@ import pyro import pyro.distributions as dist from pyro.infer.mcmc import NUTS -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.hmc import HMC -from tests.common import assert_close, assert_equal +from pyro.infer.mcmc.api import MCMC +from tests.common import assert_equal, assert_close logger = logging.getLogger(__name__) diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py index cb203fcea8..a577da9d40 100644 --- a/tests/infer/mcmc/test_mcmc_api.py +++ b/tests/infer/mcmc/test_mcmc_api.py @@ -11,7 +11,7 @@ import pyro.distributions as dist from pyro import poutine from pyro.infer.mcmc import HMC, NUTS -from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler +from pyro.infer.mcmc.api import MCMC, _UnarySampler, _MultiSampler from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model from pyro.util import optional diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 106a4510e3..43630bcb16 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -10,12 +10,12 @@ import pyro import pyro.distributions as dist +from pyro.infer.autoguide import AutoDelta +from pyro.contrib.conjugate.infer import BetaBinomialPair, collapse_conjugate, GammaPoissonPair, posterior_replay +from pyro.infer import TraceEnum_ELBO, SVI +from pyro.infer.mcmc import ArrowheadMassMatrix, MCMC, NUTS import pyro.optim as optim import pyro.poutine as poutine -from pyro.contrib.conjugate.infer import BetaBinomialPair, GammaPoissonPair, collapse_conjugate, posterior_replay -from pyro.infer import SVI, TraceEnum_ELBO -from pyro.infer.autoguide import AutoDelta -from pyro.infer.mcmc import MCMC, NUTS, ArrowheadMassMatrix from pyro.util import ignore_jit_warnings from tests.common import assert_close, assert_equal diff --git a/tests/infer/test_abstract_infer.py b/tests/infer/test_abstract_infer.py index bfacd142a2..483bc4e854 100644 --- a/tests/infer/test_abstract_infer.py +++ b/tests/infer/test_abstract_infer.py @@ -8,11 +8,12 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer import SVI, Trace_ELBO from pyro.infer.autoguide import AutoLaplaceApproximation +from pyro.infer import SVI, Trace_ELBO from pyro.infer.mcmc import MCMC, NUTS from tests.common import assert_equal + pytestmark = pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index cf7d342e36..836879d4b6 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -8,8 +8,8 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer import SVI, Predictive, Trace_ELBO from 
pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal +from pyro.infer import Predictive, SVI, Trace_ELBO from tests.common import assert_close diff --git a/tests/infer/test_svgd.py b/tests/infer/test_svgd.py index c6944dedc2..2d10b53b55 100644 --- a/tests/infer/test_svgd.py +++ b/tests/infer/test_svgd.py @@ -6,9 +6,11 @@ import pyro import pyro.distributions as dist -from pyro.infer import SVGD, IMQSteinKernel, RBFSteinKernel -from pyro.infer.autoguide.utils import _product + +from pyro.infer import SVGD, RBFSteinKernel, IMQSteinKernel from pyro.optim import Adam +from pyro.infer.autoguide.utils import _product + from tests.common import assert_equal diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py index 35667e56ef..cf55ed02ce 100644 --- a/tests/infer/test_tmc.py +++ b/tests/infer/test_tmc.py @@ -15,10 +15,11 @@ from pyro.distributions.testing import fakes from pyro.infer import config_enumerate from pyro.infer.importance import vectorized_importance_weights -from pyro.infer.traceenum_elbo import TraceEnum_ELBO from pyro.infer.tracetmc_elbo import TraceTMC_ELBO +from pyro.infer.traceenum_elbo import TraceEnum_ELBO from tests.common import assert_equal + logger = logging.getLogger(__name__) diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py index 7fda7a399f..236b460050 100644 --- a/tests/infer/test_util.py +++ b/tests/infer/test_util.py @@ -3,12 +3,12 @@ import math -import pytest import torch import pyro import pyro.distributions as dist import pyro.poutine as poutine +import pytest from pyro.infer.importance import psis_diagnostic from pyro.infer.util import MultiFrameTensor from tests.common import assert_equal diff --git a/tests/ops/test_arrowhead.py b/tests/ops/test_arrowhead.py index 2ffa76bf78..13feae5697 100644 --- a/tests/ops/test_arrowhead.py +++ b/tests/ops/test_arrowhead.py @@ -5,6 +5,7 @@ import torch from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul + from tests.common import assert_close diff --git a/tests/ops/test_gamma_gaussian.py b/tests/ops/test_gamma_gaussian.py index 74c018bcc5..872a42e531 100644 --- a/tests/ops/test_gamma_gaussian.py +++ b/tests/ops/test_gamma_gaussian.py @@ -9,8 +9,12 @@ import pyro.distributions as dist from pyro.distributions.util import broadcast_shape -from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot, - matrix_and_mvn_to_gamma_gaussian) +from pyro.ops.gamma_gaussian import ( + GammaGaussian, + gamma_gaussian_tensordot, + matrix_and_mvn_to_gamma_gaussian, + gamma_and_mvn_to_gamma_gaussian, +) from tests.common import assert_close from tests.ops.gamma_gaussian import assert_close_gamma_gaussian, random_gamma, random_gamma_gaussian from tests.ops.gaussian import random_mvn diff --git a/tests/ops/test_newton.py b/tests/ops/test_newton.py index d264b3ae35..d502cde5d7 100644 --- a/tests/ops/test_newton.py +++ b/tests/ops/test_newton.py @@ -11,6 +11,7 @@ from pyro.ops.newton import newton_step from tests.common import assert_equal + logger = logging.getLogger(__name__) diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 32026d77a2..53fb0213fb 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -16,8 +16,8 @@ import pyro.optim as optim from pyro.distributions.testing import fakes from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO -from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.hmc import HMC +from pyro.infer.mcmc.api import MCMC from 
pyro.infer.mcmc.nuts import NUTS Model = namedtuple('TestModel', ['model', 'model_args', 'model_id']) diff --git a/tests/poutine/test_nesting.py b/tests/poutine/test_nesting.py index ede0456c32..6fd6f3614d 100644 --- a/tests/poutine/test_nesting.py +++ b/tests/poutine/test_nesting.py @@ -4,10 +4,11 @@ import logging import pyro -import pyro.distributions as dist import pyro.poutine as poutine +import pyro.distributions as dist import pyro.poutine.runtime + logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 99fbfb6336..f2f4eee025 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -6,12 +6,12 @@ import logging import pickle import warnings -from queue import Queue from unittest import TestCase import pytest import torch import torch.nn as nn +from queue import Queue import pyro import pyro.distributions as dist @@ -19,7 +19,7 @@ from pyro.distributions import Bernoulli, Categorical, Normal from pyro.poutine.runtime import _DIM_ALLOCATOR, NonlocalExit from pyro.poutine.util import all_escape, discrete_escape -from tests.common import assert_close, assert_equal, assert_not_equal +from tests.common import assert_equal, assert_not_equal, assert_close logger = logging.getLogger(__name__) diff --git a/tests/poutine/test_trace_struct.py b/tests/poutine/test_trace_struct.py index 4511ccbdf3..9ad7d351a6 100644 --- a/tests/poutine/test_trace_struct.py +++ b/tests/poutine/test_trace_struct.py @@ -8,6 +8,7 @@ from pyro.poutine import Trace from tests.common import assert_equal + EDGE_SETS = [ # 1 # / \ diff --git a/tests/pyroapi/test_pyroapi.py b/tests/pyroapi/test_pyroapi.py index 271c38efab..1fa1673b9f 100644 --- a/tests/pyroapi/test_pyroapi.py +++ b/tests/pyroapi/test_pyroapi.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest + from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 diff --git a/tests/test_generic.py b/tests/test_generic.py index 1ca5c77588..a3324b27c0 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from pyroapi.testing import MODELS -from pyro.generic import handlers, infer, ops, pyro, pyro_backend +from pyro.generic import handlers, infer, pyro, pyro_backend, ops +from pyroapi.testing import MODELS from tests.common import xfail_if_not_implemented pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 22f331a450..d285ad69b5 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -2,10 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -import torch - import pyro import pyro.distributions as dist +import torch pytestmark = pytest.mark.stage('unit') diff --git a/tests/test_util.py b/tests/test_util.py index 09ec92f4f7..f8b382b4ec 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,10 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import warnings - import pytest -import torch +import torch from pyro import util pytestmark = pytest.mark.stage('unit') From 161f1e141b9a6b54672aeb7f830a384b466f24da Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 15:28:33 -0500 Subject: [PATCH 39/91] Adjust license headers. 
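Empties pyro/contrib/mue/__init__.py and adds a blank line after the SPDX identifier in missingdatahmm.py and statearrangers.py.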
--- pyro/contrib/mue/__init__.py | 2 -- pyro/contrib/mue/missingdatahmm.py | 1 + pyro/contrib/mue/statearrangers.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py index d6960608d6..e69de29bb2 100644 --- a/pyro/contrib/mue/__init__.py +++ b/pyro/contrib/mue/__init__.py @@ -1,2 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index a6396d2044..f858dd68cb 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -1,5 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + import torch from pyro.distributions import constraints diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index c8511b3585..a54d7bf361 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -1,5 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + import torch import torch.nn as nn From 1a86f9a6c2ec1409a4ea5c20f0fa5d01b2ff8c9f Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 20 Feb 2021 18:02:06 -0500 Subject: [PATCH 40/91] Profile and Factor example tests edited and added to main list. --- examples/contrib/mue/FactorMuE.py | 10 +- examples/contrib/mue/ProfileHMM.py | 212 ++++++++++++++++++----------- tests/test_examples.py | 6 +- 3 files changed, 145 insertions(+), 83 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 91d463ef8a..95c7169b4f 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -37,7 +37,7 @@ def main(args): pyro.set_rng_seed(args.rng_seed) - # Construct example dataset. + # Load dataset. if args.test: dataset = generate_data(args.small) else: @@ -73,8 +73,8 @@ def main(args): losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler) # Plot and save. 
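The fit_svi call above drives pyro.infer.SVI with a pyro.optim.MultiStepLR scheduler, constructed the same way in the ProfileHMM.py hunk below. A minimal, self-contained sketch of that wiring with a toy Normal model; the model, guide, data, and hyperparameters here are placeholders, not part of this patch:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import MultiStepLR

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        q_loc = pyro.param("q_loc", torch.tensor(0.))
        pyro.sample("loc", dist.Normal(q_loc, 0.1))

    # Adam wrapped in a multi-step decay: lr is multiplied by gamma at each milestone.
    scheduler = MultiStepLR({'optimizer': torch.optim.Adam,
                             'optim_args': {'lr': 0.05},
                             'milestones': [50],
                             'gamma': 0.5})
    svi = SVI(model, guide, scheduler, loss=Trace_ELBO())
    data = torch.randn(10)
    for epoch in range(100):
        svi.step(data)
        scheduler.step()  # advance the learning-rate schedule once per epoch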
+ time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") if args.plots: - time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") plt.figure(figsize=(6, 6)) plt.plot(losses) plt.xlabel('step') @@ -135,7 +135,7 @@ def main(args): parser.add_argument("--small", action='store_true', default=False, help='Run with small example dataset.') parser.add_argument("-r", "--rng-seed", default=0, type=int) - parser.add_argument("-f", "--file", default=None, + parser.add_argument("-f", "--file", default=None, type=str, help='Input file (fasta format).') parser.add_argument("-a", "--alphabet", default='amino-acid', help='Alphabet (amino-acid OR dna).') @@ -143,7 +143,7 @@ def main(args): help='z space dimension.') parser.add_argument("-b", "--batch-size", default=10, type=int, help='Batch size.') - parser.add_argument("-M", "--latent-seq-length", default=None, + parser.add_argument("-M", "--latent-seq-length", default=None, type=int, help='Latent sequence length.') parser.add_argument("-idfac", "--indel-factor", default=False, type=bool, help='Indel parameters depend on latent variable.') @@ -153,7 +153,7 @@ def main(args): help='Use automatic relevance detection prior.') parser.add_argument("-sub", "--substitution-matrix", default=True, type=bool, help='Use substitution matrix.') - parser.add_argument("-D", "--latent-alphabet", default=None, + parser.add_argument("-D", "--latent-alphabet", default=None, type=int, help='Latent alphabet length.') parser.add_argument("-L", "--length-model", default=False, type=bool, help='Model sequence length.') diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 2b59316ae9..d9e5c47dd1 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -7,99 +7,159 @@ import argparse import datetime +import json +import os import matplotlib.pyplot as plt import torch +from torch.optim import Adam import pyro +from pyro.contrib.mue.dataloaders import BiosequenceDataset from pyro.contrib.mue.models import ProfileHMM -from pyro.infer import SVI, Trace_ELBO -from pyro.optim import Adam +from pyro.optim import MultiStepLR -def main(args): - - torch.manual_seed(0) - torch.set_default_tensor_type('torch.DoubleTensor') - - small_test = args.test - +def generate_data(small_test): + """Generate example dataset.""" if small_test: mult_dat = 1 - mult_step = 1 else: mult_dat = 10 - mult_step = 10 - - data = torch.cat([torch.tensor([[0., 1.], - [1., 0.], - [0., 1.], - [0., 1.], - [1., 0.], - [0., 0.]])[None, :, :] - for j in range(6*mult_dat)] + - [torch.tensor([[0., 1.], - [1., 0.], - [1., 0.], - [0., 1.], - [0., 0.], - [0., 0.]])[None, :, :] - for j in range(4*mult_dat)], dim=0) - # Set up inference. - latent_seq_length, alphabet_length = 6, 2 - adam_params = {"lr": 0.05, "betas": (0.90, 0.999)} - optimizer = Adam(adam_params) - model = ProfileHMM(latent_seq_length, alphabet_length) - - svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO()) - n_steps = 10*mult_step - - # Run inference. - losses = [] - t0 = datetime.datetime.now() - for step in range(n_steps): - loss = svi.step(data) - losses.append(loss) - if step % 10 == 0: - print(loss, ' ', datetime.datetime.now() - t0) + + seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat + dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) + + return dataset + + +def main(args): + + pyro.set_rng_seed(args.rng_seed) + + # Load dataset. 
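generate_data above builds its toy dataset with BiosequenceDataset, the same loader main() uses for fasta files. A small usage sketch (illustrative only; the printed attributes are exactly the ones main() reads below):

    from pyro.contrib.mue.dataloaders import BiosequenceDataset

    # One-hot encodes variable-length sequences over a two-letter alphabet.
    seqs = ['BABBA', 'BAAB', 'BABBB']
    dataset = BiosequenceDataset(seqs, 'list', ['A', 'B'])
    print(dataset.data_size)        # number of sequences: 3
    print(dataset.max_length)       # longest sequence length: 5
    print(dataset.alphabet_length)  # alphabet size: 2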
+ if args.test: + dataset = generate_data(args.small) + else: + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + args.batch_size = min([dataset.data_size, args.batch_size]) + + # Construct model. + latent_seq_length = args.latent_seq_length + if args.latent_seq_length is None: + latent_seq_length = dataset.max_length + model = ProfileHMM(latent_seq_length, dataset.alphabet_length, + length_model=args.length_model, + prior_scale=args.prior_scale, + indel_prior_bias=args.indel_prior_bias) + + # Infer. + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': args.learning_rate}, + 'milestones': json.loads(args.milestones), + 'gamma': args.learning_gamma}) + if args.test and not args.small: + n_epochs = 100 + else: + n_epochs = args.n_epochs + losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler) # Plots. time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - plt.figure(figsize=(6, 6)) - plt.plot(losses) - plt.xlabel('step') - plt.ylabel('loss') - plt.savefig('phmm_plot.loss_{}.pdf'.format(time_stamp)) - - plt.figure(figsize=(6, 6)) - precursor_seq = pyro.param("precursor_seq_q_mn").detach() - precursor_seq_expect = torch.exp(precursor_seq - - precursor_seq.logsumexp(-1, True)) - plt.plot(precursor_seq_expect[:, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of character 1') - plt.savefig('phmm_plot.precursor_seq_prob_{}.pdf'.format(time_stamp)) - - plt.figure(figsize=(6, 6)) - insert = pyro.param("insert_q_mn").detach() - insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) - plt.plot(insert_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of insert') - plt.savefig('phmm_plot.insert_prob_{}.pdf'.format(time_stamp)) - plt.figure(figsize=(6, 6)) - delete = pyro.param("delete_q_mn").detach() - delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) - plt.plot(delete_expect[:, :, 1].numpy()) - plt.xlabel('position') - plt.ylabel('probability of delete') - plt.savefig('phmm_plot.delete_prob_{}.pdf'.format(time_stamp)) + if args.plots: + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.loss_{}.pdf'.format(time_stamp))) + + plt.figure(figsize=(6, 6)) + precursor_seq = pyro.param("precursor_seq_q_mn").detach() + precursor_seq_expect = torch.exp(precursor_seq - + precursor_seq.logsumexp(-1, True)) + plt.plot(precursor_seq_expect[:, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of character 1') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.precursor_seq_prob_{}.pdf'.format( + time_stamp))) + + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + if args.save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.delete_prob_{}.pdf'.format(time_stamp))) + + if args.save: + pyro.get_param_store().save(os.path.join( + args.out_folder, + 
'ProfileHMM_results.params_{}.out'.format(time_stamp))) + with open(os.path.join( + args.out_folder, + 'ProfileHMM_results.input_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('[args]\n') + for elem in list(args.__dict__.keys()): + ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="Basic profile HMM model (constant + MuE).") - parser.add_argument('-t', '--test', action='store_true', default=False, - help='small dataset, a few steps') + parser = argparse.ArgumentParser(description="Profile HMM model (constant + MuE).") + parser.add_argument("--test", action='store_true', default=False, + help='Run with generated example dataset.') + parser.add_argument("--small", action='store_true', default=False, + help='Run with small example dataset.') + parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("-f", "--file", default=None, type=str, + help='Input file (fasta format).') + parser.add_argument("-a", "--alphabet", default='amino-acid', + help='Alphabet (amino-acid OR dna).') + parser.add_argument("-b", "--batch-size", default=10, type=int, + help='Batch size.') + parser.add_argument("-M", "--latent-seq-length", default=None, type=int, + help='Latent sequence length.') + parser.add_argument("-L", "--length-model", default=False, type=bool, + help='Model sequence length.') + parser.add_argument("--prior-scale", default=1., type=float, + help='Prior scale parameter (all parameters).') + parser.add_argument("--indel-prior-bias", default=10., type=float, + help='Indel prior bias parameter.') + parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, + help='Learning rate for Adam optimizer.') + parser.add_argument("--milestones", default='[]', type=str, + help='Milestones for multistage learning rate.') + parser.add_argument("--learning-gamma", default=0.5, type=float, + help='Gamma parameter for multistage learning rate.') + parser.add_argument("-e", "--n-epochs", default=10, type=int, + help='Number of epochs of training.') + parser.add_argument("-p", "--plots", default=True, type=bool, + help='Make plots.') + parser.add_argument("-s", "--save", default=True, type=bool, + help='Save plots and results.') + parser.add_argument("-outf", "--out-folder", default='.', + help='Folder to save plots.') args = parser.parse_args() + + torch.set_default_dtype(torch.float64) + main(args) diff --git a/tests/test_examples.py b/tests/test_examples.py index 050b3ff6ca..4458206002 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,8 +53,10 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/mue/FactorMuE.py --test --small', - 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False', + 'contrib/mue/FactorMuE.py --test --small -p False -s False', + 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False -p False -s False', + 'contrib/mue/ProfileHMM.py --test --small -p False -s False', + 'contrib/mue/ProfileHMM.py --test --small -L True -p False -s False', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', From ae861ccd1fa9c1e98c1b31bc1b2f390e7918adb7 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 13:29:43 -0500
Subject: [PATCH 41/91] adjust prior names --- pyro/contrib/mue/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index a4cded1e2c..f65c21c079 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -17,14 +17,14 @@ import pyro.distributions as dist from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM from pyro.contrib.mue.statearrangers import Profile -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import MultiStepLR class ProfileHMM(nn.Module): """Model: Constant + MuE. """ def __init__(self, latent_seq_length, alphabet_length, - length_model=False, prior_scale=1., indel_prior_strength=10.): + length_model=False, prior_scale=1., indel_prior_bias=10.): super().__init__() assert isinstance(latent_seq_length, int) and latent_seq_length > 0 @@ -40,13 +40,13 @@ def __init__(self, latent_seq_length, alphabet_length, self.length_model = length_model assert isinstance(prior_scale, float) self.prior_scale = prior_scale - assert isinstance(indel_prior_strength, float) - self.indel_prior = torch.tensor([indel_prior_strength, 0.]) + assert isinstance(indel_prior_bias, float) + self.indel_prior = torch.tensor([indel_prior_bias, 0.]) # Initialize state arranger. self.statearrange = Profile(latent_seq_length) - def model(self, data): + def model(self, data=None): # Latent sequence. precursor_seq = pyro.sample("precursor_seq", dist.Normal( @@ -95,7 +95,7 @@ def model(self, data): observation_logits), obs=seq_data_ind) - def guide(self, data): + def guide(self, data=None): # Sequence. precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)) From 08f16cb96b520b47fd980ae5a3f6c9308d502f7b Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 14:44:38 -0500 Subject: [PATCH 42/91] Rearrange profile HMM for jit compilation. --- pyro/contrib/mue/models.py | 52 ++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index f65c21c079..21dc17af11 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -12,14 +12,18 @@ import torch.nn as nn from torch.nn.functional import softplus from torch.optim import Adam +from torch.utils.data import DataLoader import pyro +from pyro import poutine import pyro.distributions as dist from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM from pyro.contrib.mue.statearrangers import Profile from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import MultiStepLR +import pdb + class ProfileHMM(nn.Module): """Model: Constant + MuE. """ @@ -46,7 +50,7 @@ def __init__(self, latent_seq_length, alphabet_length, # Initialize state arranger. self.statearrange = Profile(latent_seq_length) - def model(self, data=None): + def model(self, seq_data, L_data, local_scale, local_length=1): # Latent sequence. precursor_seq = pyro.sample("precursor_seq", dist.Normal( @@ -81,21 +85,19 @@ def model(self, data=None): torch.tensor(200.), torch.tensor(1000.))) L_mean = softplus(length) - # Draw samples. 
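The refactor in this hunk (continuing below) replaces in-model subsampling, pyro.plate with subsample_size, by externally supplied minibatches whose log-likelihood is rescaled with poutine.scale. A toy sketch of the resulting pattern; toy_model and its shapes are illustrative, not this patch's model:

    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def toy_model(batch, dataset_size):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        # Scaling the per-datapoint log-likelihood by N/B makes the
        # minibatch ELBO an unbiased estimate of the full-data ELBO.
        with pyro.plate("batch", len(batch)):
            with poutine.scale(scale=dataset_size / len(batch)):
                pyro.sample("obs", dist.Normal(loc, 1.), obs=batch)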
- with pyro.plate("batch", len(data), - subsample_size=self.batch_size) as ind: + with pyro.plate("batch", local_length): + with poutine.scale(scale=local_scale): - seq_data_ind, L_data_ind = data[ind] - if self.length_model: - pyro.sample("obs_L", dist.Poisson(L_mean), - obs=L_data_ind) - pyro.sample("obs_seq", - MissingDataDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=seq_data_ind) + if self.length_model: + pyro.sample("obs_L", dist.Poisson(L_mean), + obs=L_data) + pyro.sample("obs_seq", + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=seq_data) - def guide(self, data=None): + def guide(self, seq_data, L_data, local_scale, local_length=1): # Sequence. precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)) @@ -133,7 +135,7 @@ def guide(self, data=None): pyro.sample("length", dist.Normal( length_q_mn, softplus(length_q_sd))) - def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): + def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None): """Infer model parameters with stochastic variational inference.""" # Setup. @@ -144,20 +146,20 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): 'optim_args': {'lr': 0.01}, 'milestones': [], 'gamma': 0.5}) - n_steps = int(np.ceil(torch.tensor(len(dataset)/self.batch_size)) - )*epochs - svi = SVI(self.model, self.guide, scheduler, loss=Trace_ELBO()) + svi = SVI(self.model, self.guide, scheduler, loss=JitTrace_ELBO()) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) # Run inference. losses = [] t0 = datetime.datetime.now() - for step in range(n_steps): - loss = svi.step(dataset) - losses.append(loss) - scheduler.step() - if (step + 1) % (n_steps/epochs) == 0: - print(int(epochs*(step+1)/n_steps), loss, ' ', - datetime.datetime.now() - t0) + for epoch in range(epochs): + for seq_data, L_data in dataload: + loss = svi.step(seq_data, L_data, + torch.tensor(dataset.data_size/L_data.shape[0]), + local_length=L_data.shape[0]) + losses.append(loss) + scheduler.step() + print(epoch, loss, ' ', datetime.datetime.now() - t0) return losses From 6c6a8462a5e14b3e6d471d6c1d6146506fd95abb Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 15:03:09 -0500 Subject: [PATCH 43/91] Debug jit compile ELBO in profile HMM --- pyro/contrib/mue/models.py | 17 +++++++++++------ tests/contrib/mue/test_models.py | 5 +++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 21dc17af11..14aec53afe 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -50,7 +50,7 @@ def __init__(self, latent_seq_length, alphabet_length, # Initialize state arranger. self.statearrange = Profile(latent_seq_length) - def model(self, seq_data, L_data, local_scale, local_length=1): + def model(self, seq_data, L_data, local_scale, local_num=1): # Latent sequence. 
precursor_seq = pyro.sample("precursor_seq", dist.Normal( @@ -85,7 +85,7 @@ def model(self, seq_data, L_data, local_scale, local_length=1): torch.tensor(200.), torch.tensor(1000.))) L_mean = softplus(length) - with pyro.plate("batch", local_length): + with pyro.plate("batch", local_num): with poutine.scale(scale=local_scale): if self.length_model: @@ -97,7 +97,7 @@ def model(self, seq_data, L_data, local_scale, local_length=1): observation_logits), obs=seq_data) - def guide(self, seq_data, L_data, local_scale, local_length=1): + def guide(self, seq_data, L_data, local_scale, local_num=1): # Sequence. precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)) @@ -135,7 +135,8 @@ def guide(self, seq_data, L_data, local_scale, local_length=1): pyro.sample("length", dist.Normal( length_q_mn, softplus(length_q_sd))) - def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None): + def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, + jit=False): """Infer model parameters with stochastic variational inference.""" # Setup. @@ -146,7 +147,11 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None): 'optim_args': {'lr': 0.01}, 'milestones': [], 'gamma': 0.5}) - svi = SVI(self.model, self.guide, scheduler, loss=JitTrace_ELBO()) + if jit: + Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + Elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=Elbo) dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) # Run inference. @@ -156,7 +161,7 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None): for seq_data, L_data in dataload: loss = svi.step(seq_data, L_data, torch.tensor(dataset.data_size/L_data.shape[0]), - local_length=L_data.shape[0]) + local_num=L_data.shape[0]) losses.append(loss) scheduler.step() print(epoch, loss, ' ', datetime.datetime.now() - t0) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 5f716a9c33..c0b864ea1c 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -12,7 +12,8 @@ @pytest.mark.parametrize('length_model', [False, True]) -def test_ProfileHMM_smoke(length_model): +@pytest.mark.parametrize('jit', [False, True]) +def test_ProfileHMM_smoke(length_model, jit): # Setup dataset. seqs = ['BABBA', 'BAAB', 'BABBB'] alph = ['A', 'B'] @@ -27,7 +28,7 @@ def test_ProfileHMM_smoke(length_model): length_model) n_epochs = 5 batch_size = 2 - losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) assert not np.isnan(losses[-1]) From ede13d42175f53f2e72666e7c94bf65263dfcbd9 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 15:49:36 -0500 Subject: [PATCH 44/91] Reconfigure FactorMuE for jit compilation --- pyro/contrib/mue/models.py | 146 +++++++++++++++++-------------- tests/contrib/mue/test_models.py | 17 ++-- 2 files changed, 87 insertions(+), 76 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 14aec53afe..88050e3c8d 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -50,7 +50,7 @@ def __init__(self, latent_seq_length, alphabet_length, # Initialize state arranger. self.statearrange = Profile(latent_seq_length) - def model(self, seq_data, L_data, local_scale, local_num=1): + def model(self, seq_data, L_data, local_scale): # Latent sequence. 
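The jit toggle just introduced in fit_svi can be isolated into a few lines. This is a sketch with a hypothetical make_svi helper, not the patch's code:

    from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO

    def make_svi(model, guide, optim, jit=False):
        # JitTrace_ELBO traces the model/guide once and compiles the trace
        # with torch.jit; it assumes static control flow and tensor inputs,
        # which is why the patch passes ignore_jit_warnings=True.
        elbo = JitTrace_ELBO(ignore_jit_warnings=True) if jit else Trace_ELBO()
        return SVI(model, guide, optim, loss=elbo)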
precursor_seq = pyro.sample("precursor_seq", dist.Normal( @@ -85,7 +85,7 @@ def model(self, seq_data, L_data, local_scale, local_num=1): torch.tensor(200.), torch.tensor(1000.))) L_mean = softplus(length) - with pyro.plate("batch", local_num): + with pyro.plate("batch", L_data.shape[0]): with poutine.scale(scale=local_scale): if self.length_model: @@ -97,7 +97,7 @@ def model(self, seq_data, L_data, local_scale, local_num=1): observation_logits), obs=seq_data) - def guide(self, seq_data, L_data, local_scale, local_num=1): + def guide(self, seq_data, L_data, local_scale): # Sequence. precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)) @@ -160,8 +160,7 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, for epoch in range(epochs): for seq_data, L_data in dataload: loss = svi.step(seq_data, L_data, - torch.tensor(dataset.data_size/L_data.shape[0]), - local_num=L_data.shape[0]) + torch.tensor(dataset.data_size/L_data.shape[0])) losses.append(loss) scheduler.step() print(epoch, loss, ' ', datetime.datetime.now() - t0) @@ -295,7 +294,7 @@ def decoder(self, z, W, B, inverse_temp): return out - def model(self, data): + def model(self, seq_data, L_data, local_scale): # ARD prior. if self.ARD_prior: @@ -342,48 +341,47 @@ def model(self, data): self.latent_alphabet_length, self.alphabet_length]) ).to_event(2)) - with pyro.plate("batch", len(data), - subsample_size=self.batch_size) as ind: - # Sample latent variable from prior. - if self.z_prior_distribution == 'Normal': - z = pyro.sample("latent", dist.Normal( - torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) - elif self.z_prior_distribution == 'Laplace': - z = pyro.sample("latent", dist.Laplace( - torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) - - # Decode latent sequence. - decoded = self.decoder(z, W, B, inverse_temp) - if self.indel_factor_dependence: - insert_logits = decoded['insert_logits'] - delete_logits = decoded['delete_logits'] - - # Construct HMM parameters. - if self.substitution_matrix: - initial_logits, transition_logits, observation_logits = ( - self.statearrange(decoded['precursor_seq_logits'], - decoded['insert_seq_logits'], - insert_logits, delete_logits, - substitute)) - else: - initial_logits, transition_logits, observation_logits = ( - self.statearrange(decoded['precursor_seq_logits'], - decoded['insert_seq_logits'], - insert_logits, delete_logits)) - # Draw samples. - seq_data_ind, L_data_ind = data[ind] - if self.length_model: - pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), - obs=L_data_ind) - pyro.sample("obs_seq", - MissingDataDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=seq_data_ind) - - def guide(self, data): + with pyro.plate("batch", L_data.shape[0]): + with poutine.scale(scale=local_scale): + # Sample latent variable from prior. + if self.z_prior_distribution == 'Normal': + z = pyro.sample("latent", dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + z = pyro.sample("latent", dist.Laplace( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + + # Decode latent sequence. + decoded = self.decoder(z, W, B, inverse_temp) + if self.indel_factor_dependence: + insert_logits = decoded['insert_logits'] + delete_logits = decoded['delete_logits'] + + # Construct HMM parameters. 
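Note that fit_svi wraps the N/B scale in torch.tensor before handing it to svi.step. A minimal sketch of why, using a hypothetical step_scaled helper and assuming a jitted ELBO:

    import torch

    def step_scaled(svi, seq_batch, length_batch, data_size):
        # With a jit-compiled ELBO, non-tensor arguments are effectively
        # baked into the compiled trace, so a plain float that changes on
        # the last (smaller) batch of an epoch would force a recompile;
        # a tensor argument does not.
        scale = torch.tensor(data_size / length_batch.shape[0])
        return svi.step(seq_batch, length_batch, scale)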
+ if self.substitution_matrix: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits, + substitute)) + else: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits)) + # Draw samples. + if self.length_model: + pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), + obs=L_data) + pyro.sample("obs_seq", + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=seq_data) + + def guide(self, seq_data, L_data, local_scale): # Register encoder with pyro. pyro.module("encoder", self.encoder) @@ -436,17 +434,21 @@ def guide(self, data): substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) # Per data latent variables. - with pyro.plate("batch", len(data), - subsample_size=self.batch_size) as ind: + with pyro.plate("batch", L_data.shape[0]): # Encode sequences. - z_loc, z_scale = self.encoder(data[ind][0]) - # Sample. - if self.z_prior_distribution == 'Normal': - pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) - elif self.z_prior_distribution == 'Laplace': - pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) - - def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): + z_loc, z_scale = self.encoder(seq_data) + # Scale log likelihood since mini-batching. + with poutine.scale(scale=local_scale): + # Sample. + if self.z_prior_distribution == 'Normal': + pyro.sample("latent", + dist.Normal(z_loc, z_scale).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + pyro.sample("latent", + dist.Laplace(z_loc, z_scale).to_event(1)) + + def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None, + jit=False): """Infer model parameters with stochastic variational inference.""" # Setup. @@ -457,20 +459,28 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None): 'optim_args': {'lr': 0.01}, 'milestones': [], 'gamma': 0.5}) - n_steps = int(np.ceil(torch.tensor(len(dataset)/self.batch_size)) - )*epochs - svi = SVI(self.model, self.guide, scheduler, loss=Trace_ELBO()) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) + # Initialize guide. + for seq_data, L_data in dataload: + self.guide(seq_data, L_data, torch.tensor(1.)) + break + if jit: + Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + Elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=Elbo) # Run inference. 
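The one-batch loop that calls the guide before constructing SVI above is a warm-up step: pyro.param sites must exist in the param store before the ELBO is jit-traced. A standalone sketch, with an illustrative init_guide_once helper:

    import torch
    from torch.utils.data import DataLoader

    def init_guide_once(guide, dataset, batch_size):
        # A single throwaway guide call registers every pyro.param site
        # before compilation.
        seq_batch, length_batch = next(iter(
            DataLoader(dataset, batch_size=batch_size, shuffle=True)))
        guide(seq_batch, length_batch, torch.tensor(1.))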
losses = [] t0 = datetime.datetime.now() - for step in range(n_steps): - loss = svi.step(dataset) - losses.append(loss) - scheduler.step() - if (step + 1) % (n_steps/epochs) == 0: - print(int(epochs*(step+1)/n_steps), loss, ' ', - datetime.datetime.now() - t0) + for epoch in range(epochs): + for seq_data, L_data in dataload: + print(seq_data) + loss = svi.step(seq_data, L_data, + torch.tensor(dataset.data_size/L_data.shape[0])) + losses.append(loss) + scheduler.step() + print(epoch, loss, ' ', datetime.datetime.now() - t0) return losses def reconstruct_precursor_seq(self, data, ind, param): diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index c0b864ea1c..4fe707385d 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -33,13 +33,14 @@ def test_ProfileHMM_smoke(length_model, jit): assert not np.isnan(losses[-1]) -@pytest.mark.parametrize('indel_factor_dependence', [False, True]) -@pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) -@pytest.mark.parametrize('ARD_prior', [False, True]) -@pytest.mark.parametrize('substitution_matrix', [False, True]) -@pytest.mark.parametrize('length_model', [False, True]) +@pytest.mark.parametrize('indel_factor_dependence', [False])#, True]) +@pytest.mark.parametrize('z_prior_distribution', ['Normal'])#, 'Laplace']) +@pytest.mark.parametrize('ARD_prior', [False])#, True]) +@pytest.mark.parametrize('substitution_matrix', [False])#, True]) +@pytest.mark.parametrize('length_model', [False])#, True]) +@pytest.mark.parametrize('jit', [False, True]) def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, - ARD_prior, substitution_matrix, length_model): + ARD_prior, substitution_matrix, length_model, jit): # Setup dataset. seqs = ['BABBA', 'BAAB', 'BABBB'] alph = ['A', 'B'] @@ -58,8 +59,8 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, substitution_matrix=substitution_matrix, length_model=length_model) n_epochs = 5 - batch_size = 2 - losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler) + batch_size = 3 + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) # Reconstruct. recon = model.reconstruct_precursor_seq(dataset, 1, pyro.param) From 1dcd4a83dec52df0e4c99218473519582c9d689b Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 17:01:27 -0500 Subject: [PATCH 45/91] Debug jit compilation for FactorMuE. --- pyro/contrib/mue/models.py | 35 ++++++++++++++++---------------- tests/contrib/mue/test_models.py | 10 ++++----- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 88050e3c8d..71a9be2570 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -147,6 +147,7 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, 'optim_args': {'lr': 0.01}, 'milestones': [], 'gamma': 0.5}) + self.guide(None, None, None) if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: @@ -266,13 +267,14 @@ def decoder(self, z, W, B, inverse_temp): out = dict() if self.length_model: # Extract expected length. - v, L_v = v.split([self.total_factor_size-1, 1], dim=1) - out['L_mean'] = softplus(L_v).squeeze(1) + L_v = v[:, -1] + out['L_mean'] = softplus(L_v) if self.indel_factor_dependence: # Extract insertion and deletion parameters. 
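The decoder edits just below replace tensor.split with explicit index arithmetic. The two are equivalent; a small self-contained check (the sizes here are made up):

    import torch

    v = torch.randn(4, 10)
    a, b, c = v.split([3, 3, 4], dim=1)            # before
    a2, b2, c2 = v[:, :3], v[:, 3:6], v[:, 6:10]   # after
    assert all(torch.equal(x, y) for x, y in [(a, a2), (b, b2), (c, c2)])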
- v, insert_v, delete_v = v.split([ - (2*self.latent_seq_length+1)*self.latent_alphabet_length, - self.latent_seq_length*3*2, self.latent_seq_length*3*2], dim=1) + ind0 = (2*self.latent_seq_length+1)*self.latent_alphabet_length + ind1 = ind0 + self.latent_seq_length*3*2 + ind2 = ind1 + self.latent_seq_length*3*2 + insert_v, delete_v = v[:, ind0:ind1], v[:, ind1:ind2] insert_v = (insert_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior) out['insert_logits'] = insert_v - insert_v.logsumexp(-1, True) @@ -280,14 +282,14 @@ def decoder(self, z, W, B, inverse_temp): + self.indel_prior) out['delete_logits'] = delete_v - delete_v.logsumexp(-1, True) # Extract precursor and insertion sequences. - precursor_seq_v, insert_seq_v = (v*softplus(inverse_temp)).split([ - self.latent_seq_length*self.latent_alphabet_length, - (self.latent_seq_length+1)*self.latent_alphabet_length], dim=1) - precursor_seq_v = precursor_seq_v.reshape([ + ind0 = self.latent_seq_length*self.latent_alphabet_length + ind1 = ind0 + (self.latent_seq_length+1)*self.latent_alphabet_length + precursor_seq_v, insert_seq_v = v[:, :ind0], v[:, ind0:ind1] + precursor_seq_v = (precursor_seq_v*softplus(inverse_temp)).reshape([ -1, self.latent_seq_length, self.latent_alphabet_length]) out['precursor_seq_logits'] = ( precursor_seq_v - precursor_seq_v.logsumexp(-1, True)) - insert_seq_v = insert_seq_v.reshape([ + insert_seq_v = (insert_seq_v*softplus(inverse_temp)).reshape([ -1, self.latent_seq_length+1, self.latent_alphabet_length]) out['insert_seq_logits'] = ( insert_seq_v - insert_seq_v.logsumexp(-1, True)) @@ -346,8 +348,8 @@ def model(self, seq_data, L_data, local_scale): # Sample latent variable from prior. if self.z_prior_distribution == 'Normal': z = pyro.sample("latent", dist.Normal( - torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) elif self.z_prior_distribution == 'Laplace': z = pyro.sample("latent", dist.Laplace( torch.zeros(self.z_dim), torch.ones(self.z_dim) @@ -394,11 +396,11 @@ def guide(self, seq_data, L_data, local_scale): # Factors. W_q_mn = pyro.param("W_q_mn", torch.randn([ self.z_dim, self.total_factor_size])) - W_q_sd = pyro.param("W_q_sd", torch.randn([ + W_q_sd = pyro.param("W_q_sd", torch.ones([ self.z_dim, self.total_factor_size])) pyro.sample("W", dist.Normal(W_q_mn, softplus(W_q_sd)).to_event(2)) B_q_mn = pyro.param("B_q_mn", torch.randn(self.total_factor_size)) - B_q_sd = pyro.param("B_q_sd", torch.randn(self.total_factor_size)) + B_q_sd = pyro.param("B_q_sd", torch.ones(self.total_factor_size)) pyro.sample("B", dist.Normal(B_q_mn, softplus(B_q_sd)).to_event(1)) # Indel probabilities. @@ -433,11 +435,11 @@ def guide(self, seq_data, L_data, local_scale): pyro.sample("substitute", dist.Normal( substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) - # Per data latent variables. + # Per datapoint local latent variables. with pyro.plate("batch", L_data.shape[0]): # Encode sequences. z_loc, z_scale = self.encoder(seq_data) - # Scale log likelihood since mini-batching. + # Scale log likelihood to account for mini-batching. with poutine.scale(scale=local_scale): # Sample. 
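The W_q_sd and B_q_sd initializations above change from randn to ones. Since these parameters pass through softplus to become posterior scales, ones give every element the same well-conditioned starting scale, whereas a random negative draw can start a scale near zero:

    import torch
    from torch.nn.functional import softplus

    print(softplus(torch.ones(2)))          # tensor([1.3133, 1.3133])
    print(softplus(torch.tensor([-3.0])))   # tensor([0.0486]), nearly collapsed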
if self.z_prior_distribution == 'Normal': @@ -475,7 +477,6 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None, t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: - print(seq_data) loss = svi.step(seq_data, L_data, torch.tensor(dataset.data_size/L_data.shape[0])) losses.append(loss) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 4fe707385d..76452a22c9 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -33,11 +33,11 @@ def test_ProfileHMM_smoke(length_model, jit): assert not np.isnan(losses[-1]) -@pytest.mark.parametrize('indel_factor_dependence', [False])#, True]) -@pytest.mark.parametrize('z_prior_distribution', ['Normal'])#, 'Laplace']) -@pytest.mark.parametrize('ARD_prior', [False])#, True]) -@pytest.mark.parametrize('substitution_matrix', [False])#, True]) -@pytest.mark.parametrize('length_model', [False])#, True]) +@pytest.mark.parametrize('indel_factor_dependence', [False, True]) +@pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) +@pytest.mark.parametrize('ARD_prior', [False, True]) +@pytest.mark.parametrize('substitution_matrix', [False, True]) +@pytest.mark.parametrize('length_model', [False, True]) @pytest.mark.parametrize('jit', [False, True]) def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, ARD_prior, substitution_matrix, length_model, jit): From 07acd22e417675a8bf62f67b40e07728ae6f761a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 17:23:37 -0500 Subject: [PATCH 46/91] Beta annealing in FactorMuE. --- pyro/contrib/mue/models.py | 40 +++++++++++++++++++++----------- tests/contrib/mue/test_models.py | 10 ++++++-- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 71a9be2570..2d9681fd27 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -296,7 +296,7 @@ def decoder(self, z, W, B, inverse_temp): return out - def model(self, seq_data, L_data, local_scale): + def model(self, seq_data, L_data, local_scale, local_prior_scale): # ARD prior. if self.ARD_prior: @@ -345,15 +345,16 @@ def model(self, seq_data, L_data, local_scale): with pyro.plate("batch", L_data.shape[0]): with poutine.scale(scale=local_scale): - # Sample latent variable from prior. - if self.z_prior_distribution == 'Normal': - z = pyro.sample("latent", dist.Normal( + with poutine.scale(scale=local_prior_scale): + # Sample latent variable from prior. + if self.z_prior_distribution == 'Normal': + z = pyro.sample("latent", dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + z = pyro.sample("latent", dist.Laplace( torch.zeros(self.z_dim), torch.ones(self.z_dim) ).to_event(1)) - elif self.z_prior_distribution == 'Laplace': - z = pyro.sample("latent", dist.Laplace( - torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) # Decode latent sequence. decoded = self.decoder(z, W, B, inverse_temp) @@ -383,7 +384,7 @@ def model(self, seq_data, L_data, local_scale): observation_logits), obs=seq_data) - def guide(self, seq_data, L_data, local_scale): + def guide(self, seq_data, L_data, local_scale, local_prior_scale): # Register encoder with pyro. pyro.module("encoder", self.encoder) @@ -440,7 +441,7 @@ def guide(self, seq_data, L_data, local_scale): # Encode sequences. 
z_loc, z_scale = self.encoder(seq_data) # Scale log likelihood to account for mini-batching. - with poutine.scale(scale=local_scale): + with poutine.scale(scale=local_scale*local_prior_scale): # Sample. if self.z_prior_distribution == 'Normal': pyro.sample("latent", @@ -449,8 +450,8 @@ def guide(self, seq_data, L_data, local_scale): pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) - def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None, - jit=False): + def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, + scheduler=None, jit=False): """Infer model parameters with stochastic variational inference.""" # Setup. @@ -464,8 +465,9 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None, dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) # Initialize guide. for seq_data, L_data in dataload: - self.guide(seq_data, L_data, torch.tensor(1.)) + self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) break + # Setup stochastic variational inference. if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: @@ -474,16 +476,26 @@ def fit_svi(self, dataset, epochs=1, batch_size=None, scheduler=None, # Run inference. losses = [] + step_i = 1 t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: loss = svi.step(seq_data, L_data, - torch.tensor(dataset.data_size/L_data.shape[0])) + torch.tensor(dataset.data_size/L_data.shape[0]), + self._beta_anneal(step_i, batch_size, + dataset.data_size, + anneal_length)) losses.append(loss) scheduler.step() + step_i += 1 print(epoch, loss, ' ', datetime.datetime.now() - t0) return losses + def _beta_anneal(self, step, batch_size, data_size, anneal_length): + """Annealing schedule for prior KL term (beta annealing).""" + anneal_frac = step*batch_size/(anneal_length*data_size) + return torch.tensor(min([anneal_frac, 1.])) + def reconstruct_precursor_seq(self, data, ind, param): # Encode seq. z_loc = self.encoder(data[ind][0])[0] diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 76452a22c9..391b4538ad 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import torch from torch.optim import Adam import pyro @@ -59,11 +60,16 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, substitution_matrix=substitution_matrix, length_model=length_model) n_epochs = 5 - batch_size = 3 - losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) + anneal_length = 2 + batch_size = 2 + losses = model.fit_svi(dataset, n_epochs, anneal_length, batch_size, + scheduler, jit) # Reconstruct. 
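The annealing schedule, written out standalone. This beta_anneal sketch mirrors _beta_anneal above, plus the guard for a zero annealing length that patch 49 later in this series adds:

    import torch

    def beta_anneal(step, batch_size, data_size, anneal_epochs):
        """Ramp the weight on the prior KL term linearly from 0 to 1."""
        if anneal_epochs <= 0:
            return torch.tensor(1.)
        return torch.tensor(
            min(step * batch_size / (anneal_epochs * data_size), 1.))

    # Matches the unit test added below: 3 steps of batch size 2 through
    # 6 datapoints is one epoch, i.e. halfway through a 2-epoch anneal.
    assert torch.allclose(beta_anneal(3, 2, 6, 2), torch.tensor(0.5))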
recon = model.reconstruct_precursor_seq(dataset, 1, pyro.param) assert not np.isnan(losses[-1]) assert recon.shape == (1, max([len(seq) for seq in seqs]), len(alph)) + + assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5)) + assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.)) From 150669725d731dd321b76ba847a72fc9ff126de5 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 18:27:50 -0500 Subject: [PATCH 47/91] Evaluate train test elbo and perplexity for ProfileHMM --- pyro/contrib/mue/models.py | 40 ++++++++++++++++++++++++++++---- tests/contrib/mue/test_models.py | 7 ++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 2d9681fd27..c462405344 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -167,6 +167,36 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, print(epoch, loss, ' ', datetime.datetime.now() - t0) return losses + def evaluate(self, dataset_train, dataset_test, jit=False): + """Evaluate performance on train and test datasets.""" + self.guide(None, None, None) + if jit: + Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + Elbo = Trace_ELBO() + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.01}, + 'milestones': [], + 'gamma': 0.5}) + svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) + dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + train_lp, train_perplex = 0., 0. + for seq_data, L_data in dataload_train: + lp = svi.evaluate_loss( + seq_data, L_data, torch.tensor(dataset_train.data_size)) + train_lp += -lp + train_perplex += lp / (L_data[0] + int(self.length_model)) + train_perplex = np.exp(train_perplex) + test_lp, test_perplex = 0., 0. + for seq_data, L_data in dataload_test: + lp = svi.evaluate_loss( + seq_data, L_data, torch.tensor(dataset_test.data_size)) + test_lp += -lp + test_perplex += lp / (L_data[0].numpy() + int(self.length_model)) + test_perplex = np.exp(test_perplex) + return train_lp, test_lp, train_perplex, test_perplex + class Encoder(nn.Module): def __init__(self, data_length, alphabet_length, z_dim): @@ -480,11 +510,11 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: - loss = svi.step(seq_data, L_data, - torch.tensor(dataset.data_size/L_data.shape[0]), - self._beta_anneal(step_i, batch_size, - dataset.data_size, - anneal_length)) + loss = svi.step( + seq_data, L_data, + torch.tensor(dataset.data_size/L_data.shape[0]), + self._beta_anneal(step_i, batch_size, dataset.data_size, + anneal_length)) losses.append(loss) scheduler.step() step_i += 1 diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 391b4538ad..e0960dc37f 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -33,6 +33,13 @@ def test_ProfileHMM_smoke(length_model, jit): assert not np.isnan(losses[-1]) + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset, dataset, jit) + assert train_lp < 0. + assert test_lp < 0. + assert train_perplex > 0. + assert test_perplex > 0. 
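The perplexity that evaluate() reports is the exponential of the average negative log-probability per residue, so lower is better and 1.0 would be a perfect model. In isolation (illustrative helper):

    import numpy as np

    def perplexity(log_prob, total_residues):
        return np.exp(-log_prob / total_residues)

    print(perplexity(-100.0, 50))  # e**2, about 7.39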
+ @pytest.mark.parametrize('indel_factor_dependence', [False, True]) @pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) From a4a620e5e02b54e3b4a1544294ff110244e8b9a1 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Feb 2021 18:54:55 -0500 Subject: [PATCH 48/91] Heldout likelihood evaluation for factormue. --- pyro/contrib/mue/dataloaders.py | 8 +++- pyro/contrib/mue/models.py | 73 +++++++++++++++++++++++--------- tests/contrib/mue/test_models.py | 9 ++++ 3 files changed, 68 insertions(+), 22 deletions(-) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 9f3df9580c..743d18edc5 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -16,7 +16,8 @@ class BiosequenceDataset(Dataset): """Load biological sequence data.""" - def __init__(self, source, source_type='list', alphabet='amino-acid'): + def __init__(self, source, source_type='list', alphabet='amino-acid', + max_length=None): super().__init__() @@ -28,7 +29,10 @@ def __init__(self, source, source_type='list', alphabet='amino-acid'): # Get lengths. self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) - self.max_length = int(torch.max(self.L_data)) + if max_length is None: + self.max_length = int(torch.max(self.L_data)) + else: + self.max_length = max_length self.data_size = len(self.L_data) # Get alphabet. diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index c462405344..d4fcdace2b 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -169,34 +169,33 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, def evaluate(self, dataset_train, dataset_test, jit=False): """Evaluate performance on train and test datasets.""" + dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) + dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) self.guide(None, None, None) if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: Elbo = Trace_ELBO() - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.01}, - 'milestones': [], - 'gamma': 0.5}) + scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) svi = SVI(self.model, self.guide, scheduler, loss=Elbo) - dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) - dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) - train_lp, train_perplex = 0., 0. - for seq_data, L_data in dataload_train: - lp = svi.evaluate_loss( - seq_data, L_data, torch.tensor(dataset_train.data_size)) - train_lp += -lp - train_perplex += lp / (L_data[0] + int(self.length_model)) - train_perplex = np.exp(train_perplex) - test_lp, test_perplex = 0., 0. - for seq_data, L_data in dataload_test: - lp = svi.evaluate_loss( - seq_data, L_data, torch.tensor(dataset_test.data_size)) - test_lp += -lp - test_perplex += lp / (L_data[0].numpy() + int(self.length_model)) - test_perplex = np.exp(test_perplex) + # Compute elbo and perplexity. + train_lp, train_perplex = self._evaluate_elbo( + svi, dataload_train, dataset_train.data_size, self.length_model) + test_lp, test_perplex = self._evaluate_elbo( + svi, dataload_test, dataset_test.data_size, self.length_model) return train_lp, test_lp, train_perplex, test_perplex + def _evaluate_elbo(self, svi, dataload, data_size, length_model): + """Evaluate elbo and average per residue perplexity.""" + lp, perplex = 0., 0. 
+ for seq_data, L_data in dataload: + lp_i = svi.evaluate_loss( + seq_data, L_data, torch.tensor(data_size)) / data_size + lp += -lp_i + perplex += lp_i / (L_data[0].numpy() + int(self.length_model)) + perplex = np.exp(perplex / data_size) + return lp, perplex + class Encoder(nn.Module): def __init__(self, data_length, alphabet_length, z_dim): @@ -526,6 +525,40 @@ def _beta_anneal(self, step, batch_size, data_size, anneal_length): anneal_frac = step*batch_size/(anneal_length*data_size) return torch.tensor(min([anneal_frac, 1.])) + def evaluate(self, dataset_train, dataset_test, jit=False): + """Evaluate performance on train and test datasets.""" + dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) + dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + # Initialize guide. + for seq_data, L_data in dataload_train: + self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + break + if jit: + Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + Elbo = Trace_ELBO() + scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + + # Compute elbo and perplexity. + train_lp, train_perplex = self._evaluate_elbo( + svi, dataload_train, dataset_train.data_size, self.length_model) + test_lp, test_perplex = self._evaluate_elbo( + svi, dataload_test, dataset_test.data_size, self.length_model) + return train_lp, test_lp, train_perplex, test_perplex + + def _evaluate_elbo(self, svi, dataload, data_size, length_model): + """Evaluate elbo and average per residue perplexity.""" + lp, perplex = 0., 0. + for seq_data, L_data in dataload: + lp_i = svi.evaluate_loss( + seq_data, L_data, torch.tensor(data_size), + torch.tensor(1.)) / data_size + lp += -lp_i + perplex += lp_i / (L_data[0].numpy() + int(self.length_model)) + perplex = np.exp(perplex / data_size) + return lp, perplex + def reconstruct_precursor_seq(self, data, ind, param): # Encode seq. z_loc = self.encoder(data[ind][0])[0] diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index e0960dc37f..ebe8781f18 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -33,6 +33,7 @@ def test_ProfileHMM_smoke(length_model, jit): assert not np.isnan(losses[-1]) + # Evaluate. train_lp, test_lp, train_perplex, test_perplex = model.evaluate( dataset, dataset, jit) assert train_lp < 0. @@ -80,3 +81,11 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5)) assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.)) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset, dataset, jit) + assert train_lp < 0. + assert test_lp < 0. + assert train_perplex > 0. + assert test_perplex > 0. From 26061b71be11c32e6b60fd13d478904ca6660482 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 15:37:30 -0500 Subject: [PATCH 49/91] switch to local elbo evaluation and add more options to factor example. 
--- examples/contrib/mue/FactorMuE.py | 58 ++++++++++++++-- pyro/contrib/mue/models.py | 110 ++++++++++++++++++++++++------ tests/contrib/mue/test_models.py | 6 ++ 3 files changed, 146 insertions(+), 28 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 95c7169b4f..ac3456b79a 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -8,6 +8,7 @@ import argparse import datetime import json +import numpy as np import os import matplotlib.pyplot as plt @@ -19,6 +20,8 @@ from pyro.contrib.mue.models import FactorMuE from pyro.optim import MultiStepLR +import pdb + def generate_data(small_test): """Generate example dataset.""" @@ -35,17 +38,26 @@ def generate_data(small_test): def main(args): - pyro.set_rng_seed(args.rng_seed) - # Load dataset. if args.test: dataset = generate_data(args.small) else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) args.batch_size = min([dataset.data_size, args.batch_size]) + if args.split is not None: + pyro.set_rng_seed(args.rng_data_seed) + heldout_num = int(np.ceil(args.split*len(dataset))) + dataset_train, dataset_test = torch.utils.data.random_split( + dataset, [dataset.data_size - heldout_num, heldout_num]) + else: + dataset_test = dataset + + # Random sampler. + pyro.set_rng_seed(args.rng_seed) # Construct model. - model = FactorMuE(dataset.max_length, dataset.alphabet_length, args.z_dim, + model = FactorMuE(dataset.max_length, dataset.alphabet_length, + args.z_dim, batch_size=args.batch_size, latent_seq_length=args.latent_seq_length, indel_factor_dependence=args.indel_factor, @@ -70,7 +82,17 @@ def main(args): n_epochs = 100 else: n_epochs = args.n_epochs - losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler) + losses = model.fit_svi(dataset_train, n_epochs, args.anneal, + args.batch_size, scheduler, args.jit) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset_train, dataset_test, args.jit) + print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) + print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) + + # Embed. + z_locs, z_scales = model.embed(dataset_train, dataset_test) # Plot and save. 
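The heldout split above seeds the RNG with a dedicated data seed before calling random_split, so the train/test partition is reproducible independently of the inference seed. A minimal sketch with toy data (the fraction is illustrative):

    import pyro
    import torch
    from torch.utils.data import TensorDataset, random_split

    pyro.set_rng_seed(0)  # fix the torch RNG that random_split draws from
    data = TensorDataset(torch.randn(10, 3))
    n_test = int(0.2 * len(data))
    train, test = random_split(data, [len(data) - n_test, n_test])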
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -85,8 +107,7 @@ def main(args): 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) plt.figure(figsize=(6, 6)) - latent = model.encoder(dataset.seq_data)[0].detach() - plt.scatter(latent[:, 0], latent[:, 1]) + plt.scatter(z_locs[:, 0], z_locs[:, 1]) plt.xlabel('z_1') plt.ylabel('z_2') if args.save: @@ -119,6 +140,23 @@ def main(args): pyro.get_param_store().save(os.path.join( args.out_folder, 'FactorMuE_results.params_{}.out'.format(time_stamp))) + with open(os.path.join( + args.out_folder, + 'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('train_lp,test_lp,train_perplex,test_perplex\n') + ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, + test_perplex)) + np.savetxt(os.path.join( + args.out_folder, + 'FactorMuE_results.embed_loc_{}.txt'.format( + time_stamp)), + z_locs.numpy()) + np.savetxt(os.path.join( + args.out_folder, + 'FactorMuE_results.embed_scale_{}.txt'.format( + time_stamp)), + z_scales.numpy()) with open(os.path.join( args.out_folder, 'FactorMuE_results.input_{}.txt'.format(time_stamp)), @@ -135,6 +173,7 @@ def main(args): parser.add_argument("--small", action='store_true', default=False, help='Run with small example dataset.') parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("--rng-data-seed", default=0, type=int) parser.add_argument("-f", "--file", default=None, type=str, help='Input file (fasta format).') parser.add_argument("-a", "--alphabet", default='amino-acid', @@ -178,12 +217,19 @@ def main(args): help='Gamma parameter for multistage learning rate.') parser.add_argument("-e", "--n-epochs", default=10, type=int, help='Number of epochs of training.') + parser.add_argument("--anneal", default=0., type=float, + help='Number of epochs to anneal beta over.') parser.add_argument("-p", "--plots", default=True, type=bool, help='Make plots.') parser.add_argument("-s", "--save", default=True, type=bool, help='Save plots and results.') parser.add_argument("-outf", "--out-folder", default='.', help='Folder to save plots.') + parser.add_argument("--split", default=0.2, type=float, + help=('Fraction of dataset to holdout for testing' + + '(float or None).')) + parser.add_argument("--jit", default=False, type=bool, + help='JIT compile the ELBO.') args = parser.parse_args() torch.set_default_dtype(torch.float64) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index d4fcdace2b..7daa4f9e6c 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -161,16 +161,18 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, for epoch in range(epochs): for seq_data, L_data in dataload: loss = svi.step(seq_data, L_data, - torch.tensor(dataset.data_size/L_data.shape[0])) + torch.tensor(len(dataset)/L_data.shape[0])) losses.append(loss) scheduler.step() print(epoch, loss, ' ', datetime.datetime.now() - t0) return losses - def evaluate(self, dataset_train, dataset_test, jit=False): + def evaluate(self, dataset_train, dataset_test=None, jit=False): """Evaluate performance on train and test datasets.""" dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) - dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=1, + shuffle=False) self.guide(None, None, None) if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) @@ -181,9 +183,12 @@ def evaluate(self, dataset_train, dataset_test, 
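A caveat on the boolean options these example scripts define with type=bool (left as in the original patches): argparse applies bool() to the raw command-line string, and any non-empty string is truthy, so '-p False' actually parses to True. A demonstration, with a hypothetical str2bool helper as one common workaround (not part of this patch):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--plots", default=True, type=bool)
    print(parser.parse_args(["-p", "False"]).plots)  # True: bool("False") is truthy

    def str2bool(s):
        # One common workaround: interpret the string explicitly.
        return s.lower() in ("1", "true", "yes")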
jit=False): # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_elbo( svi, dataload_train, dataset_train.data_size, self.length_model) - test_lp, test_perplex = self._evaluate_elbo( - svi, dataload_test, dataset_test.data_size, self.length_model) - return train_lp, test_lp, train_perplex, test_perplex + if dataset_test is not None: + test_lp, test_perplex = self._evaluate_elbo( + svi, dataload_test, dataset_test.data_size, self.length_model) + return train_lp, test_lp, train_perplex, test_perplex + else: + return train_lp, None, train_perplex, None def _evaluate_elbo(self, svi, dataload, data_size, length_model): """Evaluate elbo and average per residue perplexity.""" @@ -511,24 +516,46 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, for seq_data, L_data in dataload: loss = svi.step( seq_data, L_data, - torch.tensor(dataset.data_size/L_data.shape[0]), - self._beta_anneal(step_i, batch_size, dataset.data_size, + torch.tensor(len(dataset)/L_data.shape[0]), + self._beta_anneal(step_i, batch_size, len(dataset), anneal_length)) losses.append(loss) scheduler.step() step_i += 1 print(epoch, loss, ' ', datetime.datetime.now() - t0) + + """for seq_data, L_data in dataload: + conditioned_model = poutine.condition(self.model, data={ + "obs_L": L_data, "obs_seq": seq_data}) + args = (seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + guide_tr = poutine.trace(self.guide).get_trace(*args) + model_tr = poutine.trace(poutine.replay( + conditioned_model, trace=guide_tr)).get_trace(*args) + model_tr.compute_log_prob() + for ke in list(model_tr.nodes.keys()): + if ke[0] != '_': + try: + print(ke, model_tr.nodes[ke]['log_prob']) + except: + print('no log prob for ', ke) + pdb.set_trace() + print('here')""" + return losses def _beta_anneal(self, step, batch_size, data_size, anneal_length): """Annealing schedule for prior KL term (beta annealing).""" + if np.allclose(anneal_length, 0.): + return torch.tensor(1.) anneal_frac = step*batch_size/(anneal_length*data_size) return torch.tensor(min([anneal_frac, 1.])) - def evaluate(self, dataset_train, dataset_test, jit=False): + def evaluate(self, dataset_train, dataset_test=None, jit=False): """Evaluate performance on train and test datasets.""" dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) - dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=1, + shuffle=False) # Initialize guide. for seq_data, L_data in dataload_train: self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) @@ -541,24 +568,63 @@ def evaluate(self, dataset_train, dataset_test, jit=False): svi = SVI(self.model, self.guide, scheduler, loss=Elbo) # Compute elbo and perplexity. 
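The per-datapoint evaluation defined just below is built on Pyro's trace/replay machinery. Stripped of the model-specific conditioning, the core identity is: sample the latents from the guide, replay them through the model, and take the difference of log-probability sums, which is a single-sample ELBO estimate. A sketch with an illustrative single_sample_elbo helper:

    from pyro import poutine

    def single_sample_elbo(model, guide, *args):
        guide_tr = poutine.trace(guide).get_trace(*args)
        model_tr = poutine.trace(
            poutine.replay(model, trace=guide_tr)).get_trace(*args)
        return model_tr.log_prob_sum() - guide_tr.log_prob_sum()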
- train_lp, train_perplex = self._evaluate_elbo( - svi, dataload_train, dataset_train.data_size, self.length_model) - test_lp, test_perplex = self._evaluate_elbo( - svi, dataload_test, dataset_test.data_size, self.length_model) - return train_lp, test_lp, train_perplex, test_perplex + train_lp, train_perplex = self._evaluate_local_elbo( + svi, dataload_train, len(dataset_train), self.length_model) + if dataset_test is not None: + test_lp, test_perplex = self._evaluate_local_elbo( + svi, dataload_test, len(dataset_test), + self.length_model) + return train_lp, test_lp, train_perplex, test_perplex + else: + return train_lp, None, train_perplex, None - def _evaluate_elbo(self, svi, dataload, data_size, length_model): + def _local_variables(self, name, site): + """Return per datapoint random variables in model.""" + return name in ['latent', 'obs_L', 'obs_seq'] + + def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): """Evaluate elbo and average per residue perplexity.""" lp, perplex = 0., 0. - for seq_data, L_data in dataload: - lp_i = svi.evaluate_loss( - seq_data, L_data, torch.tensor(data_size), - torch.tensor(1.)) / data_size - lp += -lp_i - perplex += lp_i / (L_data[0].numpy() + int(self.length_model)) + with torch.no_grad(): + for seq_data, L_data in dataload: + conditioned_model = poutine.condition(self.model, data={ + "obs_L": L_data, "obs_seq": seq_data}) + args = (seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + guide_tr = poutine.trace(self.guide).get_trace(*args) + model_tr = poutine.trace(poutine.replay( + conditioned_model, trace=guide_tr)).get_trace(*args) + local_elbo = (model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ).numpy() + lp += local_elbo + perplex += -local_elbo / (L_data[0].numpy() + + int(self.length_model)) perplex = np.exp(perplex / data_size) return lp, perplex + def embed(self, dataset_train, dataset_test=None, batch_size=None): + """Get latent space embedding.""" + if batch_size is None: + batch_size = self.batch_size + dataload_train = DataLoader(dataset_train, batch_size=batch_size, + shuffle=False) + + z_locs, z_scales = [], [] + for seq_data, L_data in dataload_train: + z_loc, z_scale = self.encoder(seq_data) + z_locs.append(z_loc.detach()) + z_scales.append(z_scale.detach()) + + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=batch_size, + shuffle=False) + for seq_data, L_data in dataload_test: + z_loc, z_scale = self.encoder(seq_data) + z_locs.append(z_loc.detach()) + z_scales.append(z_scale.detach()) + + return torch.cat(z_locs), torch.cat(z_scales) + def reconstruct_precursor_seq(self, data, ind, param): # Encode seq. z_loc = self.encoder(data[ind][0])[0] diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index ebe8781f18..e04fcd8804 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -89,3 +89,9 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, assert test_lp < 0. assert train_perplex > 0. assert test_perplex > 0. + + # Embedding. + z_locs, z_scales = model.embed(dataset, dataset) + assert z_locs.shape == (len(dataset)*2, z_dim) + assert z_scales.shape == (len(dataset)*2, z_dim) + assert torch.all(z_scales > 0.) From 33730bd9fe37c50cd6d9f3a80d72450cccb4e5da Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 16:33:02 -0500 Subject: [PATCH 50/91] Cuda option. 
--- examples/contrib/mue/FactorMuE.py | 14 ++++--- pyro/contrib/mue/dataloaders.py | 8 ++-- pyro/contrib/mue/models.py | 69 +++++++++++++++---------------- 3 files changed, 46 insertions(+), 45 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index ac3456b79a..25365856ad 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -71,17 +71,15 @@ def main(args): substitution_matrix=args.substitution_matrix, substitution_prior_scale=args.substitution_prior_scale, latent_alphabet_length=args.latent_alphabet, - length_model=args.length_model) + length_model=args.length_model, + cuda=args.cuda) # Infer. scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': args.learning_rate}, 'milestones': json.loads(args.milestones), 'gamma': args.learning_gamma}) - if args.test and not args.small: - n_epochs = 100 - else: - n_epochs = args.n_epochs + n_epochs = args.n_epochs losses = model.fit_svi(dataset_train, n_epochs, args.anneal, args.batch_size, scheduler, args.jit) @@ -230,8 +228,12 @@ def main(args): '(float or None).')) parser.add_argument("--jit", default=False, type=bool, help='JIT compile the ELBO.') + parser.add_argument("--cuda", default=False, type=bool, help='Use GPU.') args = parser.parse_args() - torch.set_default_dtype(torch.float64) + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) main(args) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 743d18edc5..31160a6f1f 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -17,9 +17,10 @@ class BiosequenceDataset(Dataset): """Load biological sequence data.""" def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None): + max_length=None, device=torch.device('cpu')): super().__init__() + self.device = device # Get sequences. if source_type == 'list': @@ -28,7 +29,8 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', seqs = self._load_fasta(source) # Get lengths. - self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) + self.L_data = torch.tensor([float(len(seq)) for seq in seqs], + device=device) if max_length is None: self.max_length = int(torch.max(self.L_data)) else: @@ -68,7 +70,7 @@ def _one_hot(self, seq, alphabet, length): """One hot encode and pad with zeros to max length.""" # One hot encode. oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] - ).astype(np.float64)) + ).astype(np.float64), device=self.device) # Pad. x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 7daa4f9e6c..50955ed667 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -236,8 +236,11 @@ def __init__(self, data_length, alphabet_length, z_dim, substitution_prior_scale=10., latent_alphabet_length=None, length_model=False, + cuda=False, epsilon=1e-32): super().__init__() + assert isinstance(cuda, bool) + self.cuda = cuda # Constants. 
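The device handling in this commit is a global switch in the example script: make double-precision CUDA tensors the default when requested, otherwise stay on CPU in float64. Extracted into a hypothetical setup_device helper:

    import torch

    def setup_device(cuda):
        if cuda:
            torch.set_default_tensor_type(torch.cuda.DoubleTensor)
        else:
            torch.set_default_dtype(torch.float64)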
assert isinstance(data_length, int) and data_length > 0 @@ -514,6 +517,8 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() loss = svi.step( seq_data, L_data, torch.tensor(len(dataset)/L_data.shape[0]), @@ -524,23 +529,6 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, step_i += 1 print(epoch, loss, ' ', datetime.datetime.now() - t0) - """for seq_data, L_data in dataload: - conditioned_model = poutine.condition(self.model, data={ - "obs_L": L_data, "obs_seq": seq_data}) - args = (seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) - guide_tr = poutine.trace(self.guide).get_trace(*args) - model_tr = poutine.trace(poutine.replay( - conditioned_model, trace=guide_tr)).get_trace(*args) - model_tr.compute_log_prob() - for ke in list(model_tr.nodes.keys()): - if ke[0] != '_': - try: - print(ke, model_tr.nodes[ke]['log_prob']) - except: - print('no log prob for ', ke) - pdb.set_trace() - print('here')""" - return losses def _beta_anneal(self, step, batch_size, data_size, anneal_length): @@ -558,6 +546,8 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): shuffle=False) # Initialize guide. for seq_data, L_data in dataload_train: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) break if jit: @@ -587,6 +577,8 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): lp, perplex = 0., 0. with torch.no_grad(): for seq_data, L_data in dataload: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() conditioned_model = poutine.condition(self.model, data={ "obs_L": L_data, "obs_seq": seq_data}) args = (seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) @@ -608,27 +600,32 @@ def embed(self, dataset_train, dataset_test=None, batch_size=None): batch_size = self.batch_size dataload_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False) - - z_locs, z_scales = [], [] - for seq_data, L_data in dataload_train: - z_loc, z_scale = self.encoder(seq_data) - z_locs.append(z_loc.detach()) - z_scales.append(z_scale.detach()) - - if dataset_test is not None: - dataload_test = DataLoader(dataset_test, batch_size=batch_size, - shuffle=False) - for seq_data, L_data in dataload_test: + with torch.no_grad(): + z_locs, z_scales = [], [] + for seq_data, L_data in dataload_train: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() z_loc, z_scale = self.encoder(seq_data) - z_locs.append(z_loc.detach()) - z_scales.append(z_scale.detach()) + z_locs.append(z_loc) + z_scales.append(z_scale) + + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=batch_size, + shuffle=False) + for seq_data, L_data in dataload_test: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() + z_loc, z_scale = self.encoder(seq_data) + z_locs.append(z_loc) + z_scales.append(z_scale) return torch.cat(z_locs), torch.cat(z_scales) def reconstruct_precursor_seq(self, data, ind, param): - # Encode seq. - z_loc = self.encoder(data[ind][0])[0] - # Reconstruct - decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), - param("inverse_temp_q_mn")) - return torch.exp(decoded['precursor_seq_logits']).detach() + with torch.no_grad(): + # Encode seq. 
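The evaluation-side pattern these commits converge on, in one place: disable autograd, move inputs to the GPU, and move outputs back to the CPU for numpy and plotting. The embed_batch function is illustrative, not the patch's API:

    import torch

    def embed_batch(encoder, seq_batch, use_cuda=False):
        with torch.no_grad():        # no autograd graph needed at eval time
            if use_cuda:
                seq_batch = seq_batch.cuda()
            z_loc, z_scale = encoder(seq_batch)
        return z_loc.cpu(), z_scale.cpu()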
+ z_loc = self.encoder(data[ind][0])[0] + # Reconstruct + decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), + param("inverse_temp_q_mn")) + return torch.exp(decoded['precursor_seq_logits']) From 2304b0e726f0c839aafc613df31cf782aea1b543 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 16:36:01 -0500 Subject: [PATCH 51/91] pin memory option --- examples/contrib/mue/FactorMuE.py | 5 ++++- pyro/contrib/mue/models.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 25365856ad..4bb8d4a94e 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -72,7 +72,8 @@ def main(args): substitution_prior_scale=args.substitution_prior_scale, latent_alphabet_length=args.latent_alphabet, length_model=args.length_model, - cuda=args.cuda) + cuda=args.cuda, + pin_memory=args.pin_mem) # Infer. scheduler = MultiStepLR({'optimizer': Adam, @@ -229,6 +230,8 @@ def main(args): parser.add_argument("--jit", default=False, type=bool, help='JIT compile the ELBO.') parser.add_argument("--cuda", default=False, type=bool, help='Use GPU.') + parser.add_argument("--pin-mem", default=False, type=bool, + help='Use pin_memory for faster GPU transfer.') args = parser.parse_args() if args.cuda: diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 50955ed667..616b338c73 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -237,10 +237,13 @@ def __init__(self, data_length, alphabet_length, z_dim, latent_alphabet_length=None, length_model=False, cuda=False, + pin_memory=False, epsilon=1e-32): super().__init__() assert isinstance(cuda, bool) self.cuda = cuda + assert isinstance(pin_memory, bool) + self.pin_memory = pin_memory # Constants. assert isinstance(data_length, int) and data_length > 0 @@ -499,7 +502,8 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, 'optim_args': {'lr': 0.01}, 'milestones': [], 'gamma': 0.5}) - dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, + pin_memory=self.pin_memory) # Initialize guide. for seq_data, L_data in dataload: self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) From 04e9a22c4799d6da48a31d95e87d64e33d5a5826 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 18:58:41 -0500 Subject: [PATCH 52/91] Fix data tensor initialization. --- pyro/contrib/mue/dataloaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 31160a6f1f..6b970ae98b 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -72,7 +72,8 @@ def _one_hot(self, seq, alphabet, length): oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] ).astype(np.float64), device=self.device) # Pad. - x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)], + device=self.device)]) return x From e07cca79eb872d981a45dea8f79090fec81c77d7 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 19:04:06 -0500 Subject: [PATCH 53/91] Move data to cuda. 
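The pattern these patches converge on is the standard PyTorch recipe: keep the Dataset on the CPU, let the DataLoader pin its batches, and move each minibatch to the GPU inside the training loop. A minimal sketch of that recipe, with hypothetical stand-in tensors for the one-hot sequence data (names here are illustrative, not the module's API):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-ins for one-hot sequences and their lengths.
    dataset = TensorDataset(torch.zeros(6, 5, 2), torch.full((6,), 5.0))
    # pin_memory=True keeps batches in page-locked host memory, which makes
    # host-to-device copies faster and lets them run asynchronously.
    dataload = DataLoader(dataset, batch_size=2, shuffle=True, pin_memory=True)
    for seq_data, L_data in dataload:
        if torch.cuda.is_available():
            # non_blocking=True only overlaps the copy when memory is pinned.
            seq_data = seq_data.cuda(non_blocking=True)
            L_data = L_data.cuda(non_blocking=True)
        # ... svi.step(seq_data, L_data, ...) as in fit_svi above.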
--- pyro/contrib/mue/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 616b338c73..e5a00e239e 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -506,6 +506,8 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None, pin_memory=self.pin_memory) # Initialize guide. for seq_data, L_data in dataload: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) break # Setup stochastic variational inference. From fd52f5b91aab4579e9e81ba6cb503f8791e438a2 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 19:30:47 -0500 Subject: [PATCH 54/91] Transfer results back to cpu. --- pyro/contrib/mue/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index e5a00e239e..e7b12760fd 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -593,7 +593,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): conditioned_model, trace=guide_tr)).get_trace(*args) local_elbo = (model_tr.log_prob_sum(self._local_variables) - guide_tr.log_prob_sum(self._local_variables) - ).numpy() + ).cpu().numpy() lp += local_elbo perplex += -local_elbo / (L_data[0].numpy() + int(self.length_model)) @@ -612,8 +612,8 @@ def embed(self, dataset_train, dataset_test=None, batch_size=None): if self.cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() z_loc, z_scale = self.encoder(seq_data) - z_locs.append(z_loc) - z_scales.append(z_scale) + z_locs.append(z_loc.cpu()) + z_scales.append(z_scale.cpu()) if dataset_test is not None: dataload_test = DataLoader(dataset_test, batch_size=batch_size, @@ -622,8 +622,8 @@ def embed(self, dataset_train, dataset_test=None, batch_size=None): if self.cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() z_loc, z_scale = self.encoder(seq_data) - z_locs.append(z_loc) - z_scales.append(z_scale) + z_locs.append(z_loc.cpu()) + z_scales.append(z_scale.cpu()) return torch.cat(z_locs), torch.cat(z_scales) From 15e4b17c5890f15d5bead4a96ae65792550116f4 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 19:32:23 -0500 Subject: [PATCH 55/91] Move results back to cpu. --- pyro/contrib/mue/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index e7b12760fd..54050a2074 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -595,7 +595,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): - guide_tr.log_prob_sum(self._local_variables) ).cpu().numpy() lp += local_elbo - perplex += -local_elbo / (L_data[0].numpy() + + perplex += -local_elbo / (L_data[0].cpu().numpy() + int(self.length_model)) perplex = np.exp(perplex / data_size) return lp, perplex From 028e305fc1b57fa69020bd7f0cb954cf62e41cc9 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Thu, 25 Feb 2021 19:33:46 -0500 Subject: [PATCH 56/91] Move more results to cpu. 
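Anything headed for numpy or matplotlib has to make the round trip back to host memory first, since calling .numpy() on a CUDA tensor raises a TypeError. A short sketch of the idiom these patches apply (the tensor is illustrative):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logits = torch.randn(10, 3, device=device, requires_grad=True)
    probs = torch.softmax(logits, dim=-1)
    # detach() drops the autograd graph, cpu() copies device -> host,
    # and only then is numpy() legal.
    arr = probs.detach().cpu().numpy()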
--- examples/contrib/mue/FactorMuE.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 4bb8d4a94e..ebc87235e8 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -118,7 +118,7 @@ def main(args): plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) - plt.plot(insert_expect[:, :, 1].numpy()) + plt.plot(insert_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of insert') if args.save: @@ -128,7 +128,7 @@ def main(args): plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) - plt.plot(delete_expect[:, :, 1].numpy()) + plt.plot(delete_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of delete') if args.save: @@ -150,12 +150,12 @@ def main(args): args.out_folder, 'FactorMuE_results.embed_loc_{}.txt'.format( time_stamp)), - z_locs.numpy()) + z_locs.cpu().numpy()) np.savetxt(os.path.join( args.out_folder, 'FactorMuE_results.embed_scale_{}.txt'.format( time_stamp)), - z_scales.numpy()) + z_scales.cpu().numpy()) with open(os.path.join( args.out_folder, 'FactorMuE_results.input_{}.txt'.format(time_stamp)), From 2311edcd4ad682ef8ab500a4a256b49ac0c0a163 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 09:18:59 -0500 Subject: [PATCH 57/91] Speed up initialization. --- pyro/contrib/mue/statearrangers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index a54d7bf361..a6194d6811 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -106,9 +106,8 @@ def _make_transfer(self): elif m + 1 - g < mp and gp == 0: self.r_transf[m+1-g, g, 0, k, kp] = 1 self.u_transf[m+1-g, g, 1, k, kp] = 1 - for mpp in range(m+2-g, mp): - self.r_transf[mpp, 2, 0, k, kp] = 1 - self.u_transf[mpp, 2, 1, k, kp] = 1 + self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 + self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 self.r_transf[mp, 2, 0, k, kp] = 1 self.u_transf[mp, 2, 0, k, kp] = 1 @@ -119,9 +118,8 @@ def _make_transfer(self): elif m + 1 - g < mp and gp == 1: self.r_transf[m+1-g, g, 0, k, kp] = 1 self.u_transf[m+1-g, g, 1, k, kp] = 1 - for mpp in range(m+2-g, mp): - self.r_transf[mpp, 2, 0, k, kp] = 1 - self.u_transf[mpp, 2, 1, k, kp] = 1 + self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 + self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 if mp < M: self.r_transf[mp, 2, 1, k, kp] = 1 From 8b02abec2f1aa2cbfaff5924da1cd96aa54bc0de Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 09:56:40 -0500 Subject: [PATCH 58/91] Move to device in generator. 
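For context, torch.utils.data.random_split accepts a generator keyword (in the PyTorch versions current at the time), and torch.Generator can be constructed for a specific device, which is what this patch attempts. A sketch of the intended call, under those assumptions:

    import torch
    from torch.utils.data import TensorDataset, random_split

    dataset = TensorDataset(torch.zeros(10, 4))
    # The patch passes device='cuda' here when --cuda is set.
    gen = torch.Generator(device='cpu').manual_seed(0)
    dataset_train, dataset_test = random_split(dataset, [8, 2], generator=gen)

As the next several patches show, combining this generator with a CUDA default tensor type proved fragile in practice.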
--- examples/contrib/mue/FactorMuE.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index ebc87235e8..1a8ac17b08 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -45,10 +45,15 @@ def main(args): dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split is not None: - pyro.set_rng_seed(args.rng_data_seed) heldout_num = int(np.ceil(args.split*len(dataset))) + if args.cuda: + device = 'cuda' + else: + device = 'cpu' dataset_train, dataset_test = torch.utils.data.random_split( - dataset, [dataset.data_size - heldout_num, heldout_num]) + dataset, [dataset.data_size - heldout_num, heldout_num], + generator=torch.Generator(device=device).manual_seed( + args.rng_data_seed)) else: dataset_test = dataset From a7b452b38d1b57ba1dc19b5d9355bb5c94649348 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 09:59:33 -0500 Subject: [PATCH 59/91] Adjust cuda device transfer. --- examples/contrib/mue/FactorMuE.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 1a8ac17b08..a84aa6d5d5 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -47,7 +47,7 @@ def main(args): if args.split is not None: heldout_num = int(np.ceil(args.split*len(dataset))) if args.cuda: - device = 'cuda' + device = 'cuda:0' else: device = 'cpu' dataset_train, dataset_test = torch.utils.data.random_split( From 490be6c1d1837209275209433e4683daf8c0b1b5 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:04:20 -0500 Subject: [PATCH 60/91] Move data to device for now for ease. --- examples/contrib/mue/FactorMuE.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index a84aa6d5d5..1d2d78a23f 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -23,7 +23,7 @@ import pdb -def generate_data(small_test): +def generate_data(small_test, device=torch.device('cpu')): """Generate example dataset.""" if small_test: mult_dat = 1 @@ -31,7 +31,7 @@ def generate_data(small_test): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) + dataset = BiosequenceDataset(seqs, 'list', ['A', 'B'], device=device) return dataset @@ -39,17 +39,18 @@ def generate_data(small_test): def main(args): # Load dataset. 
+ if args.cuda: + device = torch.device('cuda') + else: + device = torch.device('cpu') if args.test: - dataset = generate_data(args.small) + dataset = generate_data(args.small, device=device) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, + device=device) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split is not None: heldout_num = int(np.ceil(args.split*len(dataset))) - if args.cuda: - device = 'cuda:0' - else: - device = 'cpu' dataset_train, dataset_test = torch.utils.data.random_split( dataset, [dataset.data_size - heldout_num, heldout_num], generator=torch.Generator(device=device).manual_seed( From c6e7f2dc4d5a870990708ce6a58da733b7c67619 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:08:27 -0500 Subject: [PATCH 61/91] Try another way of fixing cuda error. --- examples/contrib/mue/FactorMuE.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 1d2d78a23f..1dba3c24d5 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -40,7 +40,7 @@ def main(args): # Load dataset. if args.cuda: - device = torch.device('cuda') + device = torch.device('cuda:0') else: device = torch.device('cpu') if args.test: @@ -49,7 +49,7 @@ def main(args): dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, device=device) args.batch_size = min([dataset.data_size, args.batch_size]) - if args.split is not None: + if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) dataset_train, dataset_test = torch.utils.data.random_split( dataset, [dataset.data_size - heldout_num, heldout_num], @@ -231,8 +231,7 @@ def main(args): parser.add_argument("-outf", "--out-folder", default='.', help='Folder to save plots.') parser.add_argument("--split", default=0.2, type=float, - help=('Fraction of dataset to holdout for testing' + - '(float or None).')) + help=('Fraction of dataset to holdout for testing')) parser.add_argument("--jit", default=False, type=bool, help='JIT compile the ELBO.') parser.add_argument("--cuda", default=False, type=bool, help='Use GPU.') From f550447d4a0e98759bddfb620ca7042ed3751b3a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:13:25 -0500 Subject: [PATCH 62/91] Try disabling device transfer in dataloader entirely. --- pyro/contrib/mue/dataloaders.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 6b970ae98b..0ee3c7b7d3 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -17,10 +17,10 @@ class BiosequenceDataset(Dataset): """Load biological sequence data.""" def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None, device=torch.device('cpu')): + max_length=None): # , device=torch.device('cpu')): super().__init__() - self.device = device + # self.device = device # Get sequences. if source_type == 'list': @@ -29,8 +29,8 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', seqs = self._load_fasta(source) # Get lengths. 
- self.L_data = torch.tensor([float(len(seq)) for seq in seqs], - device=device) + self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) + # , device=device) if max_length is None: self.max_length = int(torch.max(self.L_data)) else: @@ -70,10 +70,10 @@ def _one_hot(self, seq, alphabet, length): """One hot encode and pad with zeros to max length.""" # One hot encode. oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] - ).astype(np.float64), device=self.device) + ).astype(np.float64)) # , device=self.device) # Pad. - x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)], - device=self.device)]) + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) #, + # device=self.device)]) return x From ef943ccb5df926010e359b4ea7b6a6308847c553 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:14:48 -0500 Subject: [PATCH 63/91] Disable device handling entirely. --- examples/contrib/mue/FactorMuE.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 1dba3c24d5..18ad75e1f6 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -23,7 +23,7 @@ import pdb -def generate_data(small_test, device=torch.device('cpu')): +def generate_data(small_test): # , device=torch.device('cpu')): """Generate example dataset.""" if small_test: mult_dat = 1 @@ -31,7 +31,7 @@ def generate_data(small_test, device=torch.device('cpu')): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', ['A', 'B'], device=device) + dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) # , device=device) return dataset @@ -44,16 +44,16 @@ def main(args): else: device = torch.device('cpu') if args.test: - dataset = generate_data(args.small, device=device) + dataset = generate_data(args.small) # , device=device) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, - device=device) + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + # , device=device) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) dataset_train, dataset_test = torch.utils.data.random_split( dataset, [dataset.data_size - heldout_num, heldout_num], - generator=torch.Generator(device=device).manual_seed( + generator=torch.Generator().manual_seed( args.rng_data_seed)) else: dataset_test = dataset From ec9b21639526bf6b8ace2341bdf2f136f7df797b Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:15:52 -0500 Subject: [PATCH 64/91] Add back in device handling in random split --- examples/contrib/mue/FactorMuE.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 18ad75e1f6..3bb764e4e7 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -53,7 +53,7 @@ def main(args): heldout_num = int(np.ceil(args.split*len(dataset))) dataset_train, dataset_test = torch.utils.data.random_split( dataset, [dataset.data_size - heldout_num, heldout_num], - generator=torch.Generator().manual_seed( + generator=torch.Generator(device=device).manual_seed( args.rng_data_seed)) else: dataset_test = dataset From c247c99e0649edbc5f839715f567c7e6b6f46dc5 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:18:22 -0500 Subject: [PATCH 65/91] Try moving data 
lengths to cuda? --- examples/contrib/mue/FactorMuE.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 3bb764e4e7..5ca531162b 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -52,7 +52,8 @@ def main(args): if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) dataset_train, dataset_test = torch.utils.data.random_split( - dataset, [dataset.data_size - heldout_num, heldout_num], + dataset, torch.tensor([dataset.data_size - heldout_num, + heldout_num]), generator=torch.Generator(device=device).manual_seed( args.rng_data_seed)) else: From 50e214903737c144fed3e03f8458ab111bac86ad Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:20:41 -0500 Subject: [PATCH 66/91] Try another combination of device calls. --- examples/contrib/mue/FactorMuE.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 5ca531162b..d5a6af0d16 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -54,10 +54,10 @@ def main(args): dataset_train, dataset_test = torch.utils.data.random_split( dataset, torch.tensor([dataset.data_size - heldout_num, heldout_num]), - generator=torch.Generator(device=device).manual_seed( + generator=torch.Generator().manual_seed( args.rng_data_seed)) else: - dataset_test = dataset + dataset_train = dataset # Random sampler. pyro.set_rng_seed(args.rng_seed) From f06a1b4caf89492da59285b4e305176b64221b93 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:31:25 -0500 Subject: [PATCH 67/91] Try adjusting generator again. --- examples/contrib/mue/FactorMuE.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index d5a6af0d16..3a7b72c2e9 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -54,10 +54,11 @@ def main(args): dataset_train, dataset_test = torch.utils.data.random_split( dataset, torch.tensor([dataset.data_size - heldout_num, heldout_num]), - generator=torch.Generator().manual_seed( + generator=torch.Generator(device=device).manual_seed( args.rng_data_seed)) else: dataset_train = dataset + dataset_test = None # Random sampler. pyro.set_rng_seed(args.rng_seed) From a455ba83056c8438bf9388c03a2555638b3dadf9 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:34:46 -0500 Subject: [PATCH 68/91] Try removing generator statement entirely. 
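The likely root cause of the churn above is a device mismatch: random_split draws its permutation with torch.randperm, and once torch.set_default_tensor_type(torch.cuda.DoubleTensor) is in effect, the permutation and the supplied generator can end up on different devices. The resolution in the next patch sidesteps the generator entirely, seeding the global RNG and building Subsets by hand; a minimal sketch of that split (toy dataset for illustration):

    import torch
    from torch.utils.data import Subset, TensorDataset

    dataset = TensorDataset(torch.zeros(10, 4))
    torch.manual_seed(0)  # pyro.set_rng_seed(...) also seeds torch
    indices = torch.randperm(len(dataset)).tolist()
    dataset_train = Subset(dataset, indices[:8])
    dataset_test = Subset(dataset, indices[8:])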
--- examples/contrib/mue/FactorMuE.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 3a7b72c2e9..df6d70b993 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -51,11 +51,10 @@ def main(args): args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) + pyro.set_rng_seed(args.rng_data_seed) dataset_train, dataset_test = torch.utils.data.random_split( dataset, torch.tensor([dataset.data_size - heldout_num, - heldout_num]), - generator=torch.Generator(device=device).manual_seed( - args.rng_data_seed)) + heldout_num])) else: dataset_train = dataset dataset_test = None From 511b51e4351045e2c6676f0f387ee4326d9f82a3 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:39:01 -0500 Subject: [PATCH 69/91] Try removing data util random split call. --- examples/contrib/mue/FactorMuE.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index df6d70b993..7a37e5517a 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -51,10 +51,15 @@ def main(args): args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) + data_lengths = [len(dataset) - heldout_num, heldout_num] pyro.set_rng_seed(args.rng_data_seed) - dataset_train, dataset_test = torch.utils.data.random_split( - dataset, torch.tensor([dataset.data_size - heldout_num, - heldout_num])) + indices = torch.randperm(sum(data_lengths)).tolist() + dataset_train, dataset_test = [ + torch.utils.data.Subset(dataset, indices[(offset - length):offset]) + for offset, length in zip(torch._utils._accumulate(data_lengths), + data_lengths)] + """dataset_train, dataset_test = torch.utils.data.random_split( + dataset, torch.tensor())""" else: dataset_train = dataset dataset_test = None From 1301e97190cf00f7cc0653b9c23978b480c67e76 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Fri, 26 Feb 2021 10:41:01 -0500 Subject: [PATCH 70/91] Clean up comments. --- examples/contrib/mue/FactorMuE.py | 13 +++---------- pyro/contrib/mue/dataloaders.py | 9 +++------ 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 7a37e5517a..d48c30c4e7 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -23,7 +23,7 @@ import pdb -def generate_data(small_test): # , device=torch.device('cpu')): +def generate_data(small_test): """Generate example dataset.""" if small_test: mult_dat = 1 @@ -31,7 +31,7 @@ def generate_data(small_test): # , device=torch.device('cpu')): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) # , device=device) + dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) return dataset @@ -39,15 +39,10 @@ def generate_data(small_test): # , device=torch.device('cpu')): def main(args): # Load dataset. 
- if args.cuda: - device = torch.device('cuda:0') - else: - device = torch.device('cpu') if args.test: - dataset = generate_data(args.small) # , device=device) + dataset = generate_data(args.small) else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) - # , device=device) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: heldout_num = int(np.ceil(args.split*len(dataset))) @@ -58,8 +53,6 @@ def main(args): torch.utils.data.Subset(dataset, indices[(offset - length):offset]) for offset, length in zip(torch._utils._accumulate(data_lengths), data_lengths)] - """dataset_train, dataset_test = torch.utils.data.random_split( - dataset, torch.tensor())""" else: dataset_train = dataset dataset_test = None diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 0ee3c7b7d3..743d18edc5 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -17,10 +17,9 @@ class BiosequenceDataset(Dataset): """Load biological sequence data.""" def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None): # , device=torch.device('cpu')): + max_length=None): super().__init__() - # self.device = device # Get sequences. if source_type == 'list': @@ -30,7 +29,6 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', # Get lengths. self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) - # , device=device) if max_length is None: self.max_length = int(torch.max(self.L_data)) else: @@ -70,10 +68,9 @@ def _one_hot(self, seq, alphabet, length): """One hot encode and pad with zeros to max length.""" # One hot encode. oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] - ).astype(np.float64)) # , device=self.device) + ).astype(np.float64)) # Pad. - x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) #, - # device=self.device)]) + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) return x From 690dae866f873b0698388dc38f4f013c462ebc26 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 27 Feb 2021 10:53:54 -0500 Subject: [PATCH 71/91] Fix embedding ordering. --- examples/contrib/mue/FactorMuE.py | 2 +- pyro/contrib/mue/models.py | 17 +++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index d48c30c4e7..0c1284f6ab 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -96,7 +96,7 @@ def main(args): print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) # Embed. - z_locs, z_scales = model.embed(dataset_train, dataset_test) + z_locs, z_scales = model.embed(dataset) # Plot and save. 
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 54050a2074..ad6bde108c 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -600,31 +600,20 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): perplex = np.exp(perplex / data_size) return lp, perplex - def embed(self, dataset_train, dataset_test=None, batch_size=None): + def embed(self, dataset, batch_size=None): """Get latent space embedding.""" if batch_size is None: batch_size = self.batch_size - dataload_train = DataLoader(dataset_train, batch_size=batch_size, - shuffle=False) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=False) with torch.no_grad(): z_locs, z_scales = [], [] - for seq_data, L_data in dataload_train: + for seq_data, L_data in dataload: if self.cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() z_loc, z_scale = self.encoder(seq_data) z_locs.append(z_loc.cpu()) z_scales.append(z_scale.cpu()) - if dataset_test is not None: - dataload_test = DataLoader(dataset_test, batch_size=batch_size, - shuffle=False) - for seq_data, L_data in dataload_test: - if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() - z_loc, z_scale = self.encoder(seq_data) - z_locs.append(z_loc.cpu()) - z_scales.append(z_scale.cpu()) - return torch.cat(z_locs), torch.cat(z_scales) def reconstruct_precursor_seq(self, data, ind, param): From 999bbe2dcc67af34be526fa4cab16be3a0529f0c Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 27 Feb 2021 13:02:43 -0500 Subject: [PATCH 72/91] Update profile HMM example. --- examples/contrib/mue/FactorMuE.py | 4 ++- examples/contrib/mue/ProfileHMM.py | 50 ++++++++++++++++++++++----- pyro/contrib/mue/models.py | 55 +++++++++++++++++++++++------- 3 files changed, 87 insertions(+), 22 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 0c1284f6ab..aed000775c 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -45,8 +45,10 @@ def main(args): dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: + # Train test split. heldout_num = int(np.ceil(args.split*len(dataset))) data_lengths = [len(dataset) - heldout_num, heldout_num] + # Specific data split seed. pyro.set_rng_seed(args.rng_data_seed) indices = torch.randperm(sum(data_lengths)).tolist() dataset_train, dataset_test = [ @@ -57,7 +59,7 @@ def main(args): dataset_train = dataset dataset_test = None - # Random sampler. + # Training seed. pyro.set_rng_seed(args.rng_seed) # Construct model. 
diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index d9e5c47dd1..ba4840e2e5 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -8,6 +8,7 @@ import argparse import datetime import json +import numpy as np import os import matplotlib.pyplot as plt @@ -43,26 +44,44 @@ def main(args): else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) args.batch_size = min([dataset.data_size, args.batch_size]) + if args.split > 0.: + heldout_num = int(np.ceil(args.split*len(dataset))) + data_lengths = [len(dataset) - heldout_num, heldout_num] + pyro.set_rng_seed(args.rng_data_seed) + indices = torch.randperm(sum(data_lengths)).tolist() + dataset_train, dataset_test = [ + torch.utils.data.Subset(dataset, indices[(offset - length):offset]) + for offset, length in zip(torch._utils._accumulate(data_lengths), + data_lengths)] + else: + dataset_train = dataset + dataset_test = None # Construct model. latent_seq_length = args.latent_seq_length - if args.latent_seq_length is None: - latent_seq_length = dataset.max_length + if latent_seq_length is None: + latent_seq_length = int(dataset.max_length * 1.1) model = ProfileHMM(latent_seq_length, dataset.alphabet_length, length_model=args.length_model, prior_scale=args.prior_scale, - indel_prior_bias=args.indel_prior_bias) + indel_prior_bias=args.indel_prior_bias, + cuda=args.cuda, + pin_memory=args.pin_mem) # Infer. scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': args.learning_rate}, 'milestones': json.loads(args.milestones), 'gamma': args.learning_gamma}) - if args.test and not args.small: - n_epochs = 100 - else: - n_epochs = args.n_epochs - losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler) + n_epochs = args.n_epochs + losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler, + args.jit) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset_train, dataset_test, args.jit) + print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) + print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) # Plots. 
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -114,6 +133,13 @@ def main(args):
 pyro.get_param_store().save(os.path.join(
 args.out_folder,
 'ProfileHMM_results.params_{}.out'.format(time_stamp)))
+ with open(os.path.join(
+ args.out_folder,
+ 'ProfileHMM_results.evaluation_{}.txt'.format(time_stamp)),
+ 'w') as ow:
+ ow.write('train_lp,test_lp,train_perplex,test_perplex\n')
+ ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex,
+ test_perplex))
 with open(os.path.join(
 args.out_folder,
 'ProfileHMM_results.input_{}.txt'.format(time_stamp)),
@@ -130,6 +156,7 @@ def main(args):
 parser.add_argument("--small", action='store_true', default=False,
 help='Run with small example dataset.')
 parser.add_argument("-r", "--rng-seed", default=0, type=int)
+ parser.add_argument("--rng-data-seed", default=0, type=int)
 parser.add_argument("-f", "--file", default=None, type=str,
 help='Input file (fasta format).')
 parser.add_argument("-a", "--alphabet", default='amino-acid',
@@ -158,6 +185,13 @@ def main(args):
 help='Save plots and results.')
 parser.add_argument("-outf", "--out-folder", default='.',
 help='Folder to save plots.')
+ parser.add_argument("--split", default=0.2, type=float,
+ help=('Fraction of dataset to holdout for testing'))
+ parser.add_argument("--jit", default=False, type=bool,
+ help='JIT compile the ELBO.')
+ parser.add_argument("--cuda", default=False, type=bool, help='Use GPU.')
+ parser.add_argument("--pin-mem", default=False, type=bool,
+ help='Use pin_memory for faster GPU transfer.')
 args = parser.parse_args()

 torch.set_default_dtype(torch.float64)

diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py
index ad6bde108c..2e8542f268 100644
--- a/pyro/contrib/mue/models.py
+++ b/pyro/contrib/mue/models.py
@@ -28,8 +28,13 @@ class ProfileHMM(nn.Module):
 """Model: Constant + MuE. """
 def __init__(self, latent_seq_length, alphabet_length,
- length_model=False, prior_scale=1., indel_prior_bias=10.):
+ length_model=False, prior_scale=1., indel_prior_bias=10.,
+ cuda=False, pin_memory=False):
 super().__init__()
+ assert isinstance(cuda, bool)
+ self.cuda = cuda
+ assert isinstance(pin_memory, bool)
+ self.pin_memory = pin_memory

 assert isinstance(latent_seq_length, int) and latent_seq_length > 0
 self.latent_seq_length = latent_seq_length
@@ -147,19 +152,23 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None,
 'optim_args': {'lr': 0.01},
 'milestones': [],
 'gamma': 0.5})
+ # Initialize guide.
 self.guide(None, None, None)
+ dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+ # Setup stochastic variational inference.
 if jit:
 Elbo = JitTrace_ELBO(ignore_jit_warnings=True)
 else:
 Elbo = Trace_ELBO()
 svi = SVI(self.model, self.guide, scheduler, loss=Elbo)
- dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 # Run inference.
 losses = []
 t0 = datetime.datetime.now()
 for epoch in range(epochs):
 for seq_data, L_data in dataload:
+ if self.cuda:
+ seq_data, L_data = seq_data.cuda(), L_data.cuda()
 loss = svi.step(seq_data, L_data,
 torch.tensor(len(dataset)/L_data.shape[0]))
 losses.append(loss)
@@ -173,31 +182,50 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False):
 if dataset_test is not None:
 dataload_test = DataLoader(dataset_test, batch_size=1,
 shuffle=False)
+ # Initialize guide.
self.guide(None, None, None) if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: Elbo = Trace_ELBO() scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + # Setup stochastic variational inference. svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + # Compute elbo and perplexity. - train_lp, train_perplex = self._evaluate_elbo( - svi, dataload_train, dataset_train.data_size, self.length_model) + train_lp, train_perplex = self._evaluate_local_elbo( + svi, dataload_train, len(dataset_train), self.length_model) if dataset_test is not None: - test_lp, test_perplex = self._evaluate_elbo( - svi, dataload_test, dataset_test.data_size, self.length_model) + test_lp, test_perplex = self._evaluate_local_elbo( + svi, dataload_test, len(dataset_test), + self.length_model) return train_lp, test_lp, train_perplex, test_perplex else: return train_lp, None, train_perplex, None - def _evaluate_elbo(self, svi, dataload, data_size, length_model): + def _local_variables(self, name, site): + """Return per datapoint random variables in model.""" + return name in ['obs_L', 'obs_seq'] + + def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): """Evaluate elbo and average per residue perplexity.""" lp, perplex = 0., 0. - for seq_data, L_data in dataload: - lp_i = svi.evaluate_loss( - seq_data, L_data, torch.tensor(data_size)) / data_size - lp += -lp_i - perplex += lp_i / (L_data[0].numpy() + int(self.length_model)) + with torch.no_grad(): + for seq_data, L_data in dataload: + if self.cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() + conditioned_model = poutine.condition(self.model, data={ + "obs_L": L_data, "obs_seq": seq_data}) + args = (seq_data, L_data, torch.tensor(1.)) + guide_tr = poutine.trace(self.guide).get_trace(*args) + model_tr = poutine.trace(poutine.replay( + conditioned_model, trace=guide_tr)).get_trace(*args) + local_elbo = (model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ).cpu().numpy() + lp += local_elbo + perplex += -local_elbo / (L_data[0].cpu().numpy() + + int(self.length_model)) perplex = np.exp(perplex / data_size) return lp, perplex @@ -249,7 +277,7 @@ def __init__(self, data_length, alphabet_length, z_dim, assert isinstance(data_length, int) and data_length > 0 self.data_length = data_length if latent_seq_length is None: - latent_seq_length = data_length + latent_seq_length = int(data_length * 1.1) else: assert isinstance(latent_seq_length, int) and latent_seq_length > 0 self.latent_seq_length = latent_seq_length @@ -561,6 +589,7 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): else: Elbo = Trace_ELBO() scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + # Setup stochastic variational inference. svi = SVI(self.model, self.guide, scheduler, loss=Elbo) # Compute elbo and perplexity. From 1fcbd670d82dfb8b0a4a6003beaeccfdd78f269f Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 27 Feb 2021 14:31:58 -0500 Subject: [PATCH 73/91] Update tests, improve alphabet handling. 
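With the change below, the alphabet argument accepts either a preset name or the letters spelled out directly, since any unrecognized string is split into characters. A usage sketch against the loader's signature as it appears in this series:

    from pyro.contrib.mue.dataloaders import BiosequenceDataset

    # Preset alphabets: 'amino-acid' or 'dna'. Anything else, e.g. 'AB',
    # is interpreted as the alphabet itself, one character per symbol.
    dataset = BiosequenceDataset(['BABBA', 'BAAB'], 'list', 'AB')
    assert dataset.alphabet_length == 2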
--- examples/contrib/mue/FactorMuE.py | 4 ++-- examples/contrib/mue/ProfileHMM.py | 4 ++-- pyro/contrib/mue/dataloaders.py | 6 ++---- tests/contrib/mue/test_dataloaders.py | 8 ++++---- tests/contrib/mue/test_models.py | 12 ++++++------ 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index aed000775c..39d6084991 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -31,7 +31,7 @@ def generate_data(small_test): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) + dataset = BiosequenceDataset(seqs, 'list', 'AB') return dataset @@ -183,7 +183,7 @@ def main(args): parser.add_argument("-f", "--file", default=None, type=str, help='Input file (fasta format).') parser.add_argument("-a", "--alphabet", default='amino-acid', - help='Alphabet (amino-acid OR dna).') + help='Alphabet (amino-acid OR dna OR ATGC ...).') parser.add_argument("-zdim", "--z-dim", default=2, type=int, help='z space dimension.') parser.add_argument("-b", "--batch-size", default=10, type=int, diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index ba4840e2e5..5cff9a1d0c 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -29,7 +29,7 @@ def generate_data(small_test): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', ['A', 'B']) + dataset = BiosequenceDataset(seqs, 'list', 'AB') return dataset @@ -160,7 +160,7 @@ def main(args): parser.add_argument("-f", "--file", default=None, type=str, help='Input file (fasta format).') parser.add_argument("-a", "--alphabet", default='amino-acid', - help='Alphabet (amino-acid OR dna).') + help='Alphabet (amino-acid OR dna OR ATGC ...).') parser.add_argument("-b", "--batch-size", default=10, type=int, help='Batch size.') parser.add_argument("-M", "--latent-seq-length", default=None, type=int, diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 743d18edc5..c69bd97331 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -36,12 +36,10 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', self.data_size = len(self.L_data) # Get alphabet. - if type(alphabet) is list: - alphabet = np.array(alphabet) - elif alphabet in alphabets: + if alphabet in alphabets: alphabet = alphabets[alphabet] else: - assert 'Alphabet unavailable, please provide a list of letters.' + alphabet = np.array(list(alphabet)) self.alphabet_length = len(alphabet) # Build dataset. diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 887b63c490..ab1c697f03 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -8,17 +8,17 @@ @pytest.mark.parametrize('source_type', ['list', 'fasta']) -@pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', ['A', 'T', 'C']]) +@pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', 'ATC']) def test_biosequencedataset(source_type, alphabet): # Define dataset. seqs = ['AATC', 'CA', 'T'] # Encode dataset, alternate approach. 
- if type(alphabet) is list: - alphabet_list = alphabet - elif alphabet in alphabets: + if alphabet in alphabets: alphabet_list = list(alphabets[alphabet]) + else: + alphabet_list = list(alphabet) L_data_check = [len(seq) for seq in seqs] max_length_check = max(L_data_check) data_size_check = len(seqs) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index e04fcd8804..0d18e17260 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -17,7 +17,7 @@ def test_ProfileHMM_smoke(length_model, jit): # Setup dataset. seqs = ['BABBA', 'BAAB', 'BABBB'] - alph = ['A', 'B'] + alph = 'AB' dataset = BiosequenceDataset(seqs, 'list', alph) # Infer. @@ -25,7 +25,7 @@ def test_ProfileHMM_smoke(length_model, jit): 'optim_args': {'lr': 0.1}, 'milestones': [20, 100, 1000, 2000], 'gamma': 0.5}) - model = ProfileHMM(dataset.max_length, dataset.alphabet_length, + model = ProfileHMM(int(dataset.max_length*1.1), dataset.alphabet_length, length_model) n_epochs = 5 batch_size = 2 @@ -52,7 +52,7 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, ARD_prior, substitution_matrix, length_model, jit): # Setup dataset. seqs = ['BABBA', 'BAAB', 'BABBB'] - alph = ['A', 'B'] + alph = 'AB' dataset = BiosequenceDataset(seqs, 'list', alph) # Infer. @@ -91,7 +91,7 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, assert test_perplex > 0. # Embedding. - z_locs, z_scales = model.embed(dataset, dataset) - assert z_locs.shape == (len(dataset)*2, z_dim) - assert z_scales.shape == (len(dataset)*2, z_dim) + z_locs, z_scales = model.embed(dataset) + assert z_locs.shape == (len(dataset), z_dim) + assert z_scales.shape == (len(dataset), z_dim) assert torch.all(z_scales > 0.) From aa2850d3b44dad09b494b49d950b3d6336d3ce0c Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 27 Feb 2021 15:14:08 -0500 Subject: [PATCH 74/91] Adjust cuda defaults. --- examples/contrib/mue/ProfileHMM.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 5cff9a1d0c..ee720b77eb 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -194,6 +194,9 @@ def main(args): help='Use pin_memory for faster GPU transfer.') args = parser.parse_args() - torch.set_default_dtype(torch.float64) + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) main(args) From 23ddc45adeb48e4e4dc069b6c2e66189a1afbf8b Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sat, 27 Feb 2021 15:29:58 -0500 Subject: [PATCH 75/91] Move back to cpu for plotting. 
--- examples/contrib/mue/ProfileHMM.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index ee720b77eb..408f176bdd 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -99,7 +99,7 @@ def main(args): precursor_seq = pyro.param("precursor_seq_q_mn").detach() precursor_seq_expect = torch.exp(precursor_seq - precursor_seq.logsumexp(-1, True)) - plt.plot(precursor_seq_expect[:, 1].numpy()) + plt.plot(precursor_seq_expect[:, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of character 1') if args.save: @@ -111,7 +111,7 @@ def main(args): plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) - plt.plot(insert_expect[:, :, 1].numpy()) + plt.plot(insert_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of insert') if args.save: @@ -121,7 +121,7 @@ def main(args): plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) - plt.plot(delete_expect[:, :, 1].numpy()) + plt.plot(delete_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of delete') if args.save: From ca345f06db31a5d045247a5eab9a7ad5c8b83f59 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Mon, 15 Mar 2021 17:51:35 -0400 Subject: [PATCH 76/91] Documentation for profile hmm model. --- examples/contrib/mue/FactorMuE.py | 2 -- pyro/contrib/mue/models.py | 54 ++++++++++++++++++++++++++---- pyro/contrib/mue/statearrangers.py | 2 +- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 39d6084991..41b48d13d4 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -20,8 +20,6 @@ from pyro.contrib.mue.models import FactorMuE from pyro.optim import MultiStepLR -import pdb - def generate_data(small_test): """Generate example dataset.""" diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 2e8542f268..7cd4f2bd20 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -26,15 +26,34 @@ class ProfileHMM(nn.Module): - """Model: Constant + MuE. """ + """Profile HMM. + + This model consists of a constant distribution (a delta function) over the + regressor sequence, plus a MuE observation distribution. The priors + are all Normal distributions, and are pushed through a softmax function + onto the simplex. + + :param int latent_seq_length: Length of the latent regressor sequence M. + Must be greater than or equal to 1. + :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for + amino acids). + :param bool length_model: Model the length of the sequence with a Poisson + distribution. (Default: False.) + :param float prior_scale: Standard deviation of the prior distribution. + (Default: 1.0.) + :param float indel_prior_bias: Offset of the mean of the prior distribution + over the indel probability. Higher values lead to lower probability + of indels. (Default: 10.0.) + :param bool cuda: Transfer data onto the GPU for training. (Default: False.) + :param bool pin_memory: Pin memory for faster GPU transfer. + (Default: False.) 
+ """ def __init__(self, latent_seq_length, alphabet_length, length_model=False, prior_scale=1., indel_prior_bias=10., - cuda=False, pin_memory=False): + cuda=False): super().__init__() assert isinstance(cuda, bool) self.cuda = cuda - assert isinstance(pin_memory, bool) - self.pin_memory = pin_memory assert isinstance(latent_seq_length, int) and latent_seq_length > 0 self.latent_seq_length = latent_seq_length @@ -142,7 +161,20 @@ def guide(self, seq_data, L_data, local_scale): def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, jit=False): - """Infer model parameters with stochastic variational inference.""" + """ + Infer approximate posterior with stochastic variational inference. + + This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference + method useful for quickly iterating on probabilistic models. + + :param torch.utils.data.Dataset dataset: The training dataset. + :param int epochs: Number of epochs of training. (Default: 1.) + :param int batch_size: Minibatch size (number of sequences). + (Default: 1.) + :param pyro.optim.MultiStepLR scheduler: Learning rate scheduler. + (Default: Adam optimizer, 0.01 constant learning rate.) + :param bool jit: Whether to use a jit compiled ELBO. + """ # Setup. if batch_size is not None: @@ -154,7 +186,8 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, 'gamma': 0.5}) # Initialize guide. self.guide(None, None, None) - dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, + pin_memory=self.pin_memory) # Setup stochastic variational inference. if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) @@ -177,7 +210,14 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, return losses def evaluate(self, dataset_train, dataset_test=None, jit=False): - """Evaluate performance on train and test datasets.""" + """ + Evaluate performance on train and test datasets. + + :param torch.utils.data.Dataset dataset: The training dataset. + :param torch.utils.data.Dataset dataset: The testing dataset. + (Default: None.) + :param bool jit: Whether to use a jit compiled ELBO. + """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) if dataset_test is not None: dataload_test = DataLoader(dataset_test, batch_size=1, diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index a6194d6811..51f2aa3989 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -16,7 +16,7 @@ class Profile(nn.Module): [1] E. N. Weinstein, D. S. Marks (2020) "Generative probabilistic biological sequence models that account for mutational variability" - https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf + https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf [2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) "Biological sequence analysis: probabilistic models of proteins and nucleic acids" From cdea31d88948f199ffbbc97e2c1638159c2f6ac6 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 11:51:13 -0400 Subject: [PATCH 77/91] Docs for FactorMuE model. 
--- docs/source/contrib.mue.rst | 2 +- pyro/contrib/mue/models.py | 123 +++++++++++++++++++++++++------ tests/contrib/mue/test_models.py | 2 +- 3 files changed, 102 insertions(+), 25 deletions(-) diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 9dbc3d05b2..e2e020ecd6 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -12,7 +12,7 @@ preprocessing. Reference: MuE models were described in Weinstein and Marks (2020), -https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1. +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2. Example MuE Models ------------------ diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 7cd4f2bd20..5a54480f98 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -38,15 +38,13 @@ class ProfileHMM(nn.Module): :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for amino acids). :param bool length_model: Model the length of the sequence with a Poisson - distribution. (Default: False.) + distribution. :param float prior_scale: Standard deviation of the prior distribution. - (Default: 1.0.) - :param float indel_prior_bias: Offset of the mean of the prior distribution - over the indel probability. Higher values lead to lower probability - of indels. (Default: 10.0.) - :param bool cuda: Transfer data onto the GPU for training. (Default: False.) + :param float indel_prior_bias: Mean of the prior distribution over the + log probability of an indel not occurring. Higher values lead to lower + probability of indels. + :param bool cuda: Transfer data onto the GPU for training. :param bool pin_memory: Pin memory for faster GPU transfer. - (Default: False.) """ def __init__(self, latent_seq_length, alphabet_length, length_model=False, prior_scale=1., indel_prior_bias=10., @@ -159,7 +157,7 @@ def guide(self, seq_data, L_data, local_scale): pyro.sample("length", dist.Normal( length_q_mn, softplus(length_q_sd))) - def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, + def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, jit=False): """ Infer approximate posterior with stochastic variational inference. @@ -167,12 +165,13 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference method useful for quickly iterating on probabilistic models. - :param torch.utils.data.Dataset dataset: The training dataset. - :param int epochs: Number of epochs of training. (Default: 1.) + :param dataset: The training dataset, with type + :class:`~torch.utils.data.Dataset`. + :param int epochs: Number of epochs of training. :param int batch_size: Minibatch size (number of sequences). - (Default: 1.) - :param pyro.optim.MultiStepLR scheduler: Learning rate scheduler. - (Default: Adam optimizer, 0.01 constant learning rate.) + :param scheduler: Learning rate scheduler, with type + :class:`~pyro.optim.MultiStepLR`. (Default: Adam optimizer, + 0.01 constant learning rate.) :param bool jit: Whether to use a jit compiled ELBO. """ @@ -211,11 +210,13 @@ def fit_svi(self, dataset, epochs=1, batch_size=1, scheduler=None, def evaluate(self, dataset_train, dataset_test=None, jit=False): """ - Evaluate performance on train and test datasets. + Evaluate performance (log probability and per residue perplexity) on + train and test datasets. - :param torch.utils.data.Dataset dataset: The training dataset. - :param torch.utils.data.Dataset dataset: The testing dataset. 
- (Default: None.)
+ :class:`~torch.utils.data.Dataset` or None. (Default: None.)
 :param bool jit: Whether to use a jit compiled ELBO.
 """
 dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
 if dataset_test is not None:
 dataload_test = DataLoader(dataset_test, batch_size=1,
@@ -288,7 +289,49 @@ class FactorMuE(nn.Module):

- """Model: pPCA + MuE."""
+ """FactorMuE
+
+ This model consists of probabilistic PCA plus a MuE output distribution.
+
+ The priors are all Normal distributions, and where relevant pushed through
+ a softmax to produce a prior over the simplex.
+
+ :param int data_length: Length of the input sequence matrix, including
+ zero padding at the end.
+ :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for
+ amino acids).
+ :param int z_dim: Number of dimensions of the z space.
+ :param int batch_size: Minibatch size.
+ :param int latent_seq_length: Length of the latent regressor sequence (M).
+ Must be greater than or equal to 1. (Default: 1.1 x data_length.)
+ :param bool indel_factor_dependence: Indel probabilities depend on the
+ latent variable z.
+ :param float indel_prior_scale: Standard deviation of the prior
+ distribution on indel parameters.
+ :param float indel_prior_bias: Mean of the prior distribution over the
+ log probability of an indel not occurring. Higher values lead to lower
+ probability of indels.
+ :param float inverse_temp_prior: Mean of the prior distribution over the
+ inverse temperature parameter.
+ :param float weights_prior_scale: Standard deviation of the prior
+ distribution over the factors.
+ :param float offset_prior_scale: Standard deviation of the prior
+ distribution over the offset (constant) pPCA model.
+ :param str z_prior_distribution: Prior distribution over the latent
+ variable z. Either 'Normal' (pPCA model) or 'Laplace' (an ICA model).
+ :param bool ARD_prior: Use automatic relevance determination prior on
+ factors.
+ :param bool substitution_matrix: Use a learnable substitution matrix (l)
+ rather than the identity matrix.
+ :param float substitution_prior_scale: Standard deviation of the prior
+ distribution over substitution matrix parameters (when
+ substitution_matrix is True).
+ :param int latent_alphabet_length: Length of the alphabet in the latent
+ regressor sequence.
+ :param bool length_model: Model the length of the sequence with a Poisson
+ distribution.
+ :param bool cuda: Transfer data onto the GPU for training.
+ :param bool pin_memory: Pin memory for faster GPU transfer.
+ :param float epsilon: A small value for numerical stability.
+ """

 def __init__(self, data_length, alphabet_length, z_dim,
 batch_size=10,
 latent_seq_length=None,
@@ -558,9 +601,26 @@ def guide(self, seq_data, L_data, local_scale, local_prior_scale):
 pyro.sample("latent",
 dist.Laplace(z_loc, z_scale).to_event(1))

- def fit_svi(self, dataset, epochs=2, anneal_length=1, batch_size=None,
+ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None,
 scheduler=None, jit=False):
- """Infer model parameters with stochastic variational inference."""
+ """
+ Infer approximate posterior with stochastic variational inference.
+
+ This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference
+ method useful for quickly iterating on probabilistic models.
+
+ :param dataset: The training dataset, with type
+ :class:`~torch.utils.data.Dataset`.
+ :param int epochs: Number of epochs of training.
+ :param float anneal_length: Number of epochs over which to linearly + anneal the prior KL divergence weight from 0 to 1, for improved + convergence. + :param int batch_size: Minibatch size (number of sequences). + :param scheduler: Learning rate scheduler, with type + :class:`~pyro.optim.MultiStepLR`. (Default: Adam optimizer, + 0.01 constant learning rate.) + :param bool jit: Whether to use a jit compiled ELBO. + """ # Setup. if batch_size is not None: @@ -613,7 +673,16 @@ def _beta_anneal(self, step, batch_size, data_size, anneal_length): return torch.tensor(min([anneal_frac, 1.])) def evaluate(self, dataset_train, dataset_test=None, jit=False): - """Evaluate performance on train and test datasets.""" + """ + Evaluate performance (log probability and per residue perplexity) on + train and test datasets. + + :param dataset: The training dataset, with type + :class:`~torch.utils.data.Dataset`. + :param torch.utils.data.Dataset dataset: The testing dataset, with type + :class:`~torch.utils.data.Dataset` or None. (Default: None.) + :param bool jit: Whether to use a jit compiled ELBO. + """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) if dataset_test is not None: dataload_test = DataLoader(dataset_test, batch_size=1, @@ -670,7 +739,14 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): return lp, perplex def embed(self, dataset, batch_size=None): - """Get latent space embedding.""" + """ + Get the latent space embedding (mean posterior value of z). + + :param dataset: The dataset to embed, with type + :class:`~torch.utils.data.Dataset`. + :param int batch_size: Minibatch size (number of sequences). (Defaults + to batch_size of the model object.) + """ if batch_size is None: batch_size = self.batch_size dataload = DataLoader(dataset, batch_size=batch_size, shuffle=False) @@ -685,7 +761,8 @@ def embed(self, dataset, batch_size=None): return torch.cat(z_locs), torch.cat(z_scales) - def reconstruct_precursor_seq(self, data, ind, param): + def _reconstruct_regressor_seq(self, data, ind, param): + "Reconstruct the latent regressor sequence given data." with torch.no_grad(): # Encode seq. z_loc = self.encoder(data[ind][0])[0] diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 0d18e17260..127af5043a 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -74,7 +74,7 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, scheduler, jit) # Reconstruct. - recon = model.reconstruct_precursor_seq(dataset, 1, pyro.param) + recon = model._reconstruct_regressor_seq(dataset, 1, pyro.param) assert not np.isnan(losses[-1]) assert recon.shape == (1, max([len(seq) for seq in seqs]), len(alph)) From d8c954627857467acea58fddbc7499c8a8e8b44c Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 14:29:58 -0400 Subject: [PATCH 78/91] Cleaned up docs. --- docs/source/contrib.mue.rst | 14 ++++++-- pyro/contrib/mue/dataloaders.py | 14 +++++++- pyro/contrib/mue/missingdatahmm.py | 4 +-- pyro/contrib/mue/models.py | 56 +++++++++++++----------------- pyro/contrib/mue/statearrangers.py | 9 ++--- 5 files changed, 56 insertions(+), 41 deletions(-) diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index e2e020ecd6..9c3f8edc94 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -11,7 +11,7 @@ a fully probabilistic alternative to multiple sequence alignment-based preprocessing. 
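The anneal_length argument above implements linear KL annealing. Since the full body of _beta_anneal is elided by the hunk, here is a plausible standalone version consistent with its signature and the documented 0-to-1 schedule; the exact anneal_frac formula is an assumption:

    import torch

    def beta_anneal(step, batch_size, data_size, anneal_length):
        """Linearly ramp the prior KL weight from 0 to 1 over anneal_length epochs."""
        if anneal_length == 0:
            return torch.tensor(1.)
        # Fraction of the annealing schedule completed after `step` minibatches.
        anneal_frac = step * batch_size / (anneal_length * data_size)
        return torch.tensor(min([anneal_frac, 1.]))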
Reference: -MuE models were described in Weinstein and Marks (2020), +MuE models were described in Weinstein and Marks (2021), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2. Example MuE Models @@ -28,9 +28,17 @@ State Arrangers for Parameterizing MuEs :show-inheritance: :member-order: bysource -Missing Variable Length Data HMM --------------------------------- +Missing or Variable Length Data HMM +----------------------------------- .. automodule:: pyro.contrib.mue.missingdatahmm :members: :show-inheritance: :member-order: bysource + + +Biosequence Dataset Loading +--------------------------- +.. automodule:: pyro.contrib.mue.dataloaders + :members: + :show-inheritance: + :member-order: bysource diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index c69bd97331..c36cff4bd2 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -14,7 +14,19 @@ class BiosequenceDataset(Dataset): - """Load biological sequence data.""" + """ + Load biological sequence data, either from a fasta file or a python list. + + :param source: Either the input fasta file path (str) or the input list + of sequences (list of str). + :param str source_type: Type of input, either 'list' or 'fasta'. + :param str alphabet: Alphabet to use. Alphabets 'amino-acid' and 'dna' are + preset; any other input will be interpreted as the alphabet itself, + i.e. you can use 'ACGU' for RNA. + :param int max_length: Total length of the one-hot representation of the + sequences, including zero padding. Defaults to the maximum sequence + length in the dataset. + """ def __init__(self, source, source_type='list', alphabet='amino-acid', max_length=None): diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index f858dd68cb..e4f49aa879 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -15,8 +15,8 @@ class MissingDataDiscreteHMM(TorchDistribution): missing data or variable length sequences. Observations are assumed to be one hot encoded; rows with all zeros indicate missing data. - .. warning:: Unlike in pyro's DiscreteHMM, which computes the - probability of the first state as + .. warning:: Unlike in pyro's pyro.distributions.DiscreteHMM, which + computes the probability of the first state as initial.T @ transition @ emission this distribution uses the standard HMM convention, initial.T @ emission diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 5a54480f98..fadfd15f05 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -22,8 +22,6 @@ from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import MultiStepLR -import pdb - class ProfileHMM(nn.Module): """Profile HMM. @@ -43,15 +41,17 @@ class ProfileHMM(nn.Module): :param float indel_prior_bias: Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower probability of indels. - :param bool cuda: Transfer data onto the GPU for training. + :param bool cuda: Transfer data onto the GPU during training. :param bool pin_memory: Pin memory for faster GPU transfer. 
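A quick usage sketch of the loader documented above, mirroring the package's own tests; items come back as (one-hot sequences, lengths), with all-zero rows marking padding, which is the missing-data convention MissingDataDiscreteHMM expects:

    import torch

    from pyro.contrib.mue.dataloaders import BiosequenceDataset

    # Custom alphabet, as the docstring suggests for RNA.
    dataset = BiosequenceDataset(['ACGU', 'GG'], source_type='list',
                                 alphabet='ACGU')
    seqs, lengths = dataset[torch.tensor([0, 1])]
    print(seqs.shape)   # (2, max_length, alphabet_length) == (2, 4, 4)
    print(lengths)      # tensor([4., 2.])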
""" def __init__(self, latent_seq_length, alphabet_length, length_model=False, prior_scale=1., indel_prior_bias=10., - cuda=False): + cuda=False, pin_memory=False): super().__init__() assert isinstance(cuda, bool) self.cuda = cuda + assert isinstance(pin_memory, bool) + self.pin_memory = pin_memory assert isinstance(latent_seq_length, int) and latent_seq_length > 0 self.latent_seq_length = latent_seq_length @@ -165,13 +165,11 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference method useful for quickly iterating on probabilistic models. - :param dataset: The training dataset, with type - :class:`~torch.utils.data.Dataset`. + :param ~torch.utils.data.Dataset dataset: The training dataset. :param int epochs: Number of epochs of training. :param int batch_size: Minibatch size (number of sequences). - :param scheduler: Learning rate scheduler, with type - :class:`~pyro.optim.MultiStepLR`. (Default: Adam optimizer, - 0.01 constant learning rate.) + :param pyro.optim.MultiStepLR scheduler: Optimization scheduler. + (Default: Adam optimizer, 0.01 constant learning rate.) :param bool jit: Whether to use a jit compiled ELBO. """ @@ -213,10 +211,8 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): Evaluate performance (log probability and per residue perplexity) on train and test datasets. - :param dataset: The training dataset, with type - :class:`~torch.utils.data.Dataset`. - :param torch.utils.data.Dataset dataset: The testing dataset, with type - :class:`~torch.utils.data.Dataset` or None. (Default: None.) + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param ~torch.utils.data.Dataset dataset: The testing dataset. :param bool jit: Whether to use a jit compiled ELBO. """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) @@ -294,10 +290,12 @@ class FactorMuE(nn.Module): This model consists of probabilistic PCA plus a MuE output distribution. The priors are all Normal distributions, and where relevant pushed through - a softmax to produce a prior over the simplex. + a softmax onto the simplex. :param int data_length: Length of the input sequence matrix, including zero padding at the end. + :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for + amino acids). :param int z_dim: Number of dimensions of the z space. :param int batch_size: Minibatch size. :param int latent_seq_length: Length of the latent regressor sequence (M). @@ -314,12 +312,12 @@ class FactorMuE(nn.Module): :param float weights_prior_scale: Standard deviation of the prior distribution over the factors. :param float offset_prior_scale: Standard deviation of the prior - distribution over the offset (constant) pPCA model. + distribution over the offset (constant) in the pPCA model. :param str z_prior_distribution: Prior distribution over the latent variable z. Either 'Normal' (pPCA model) or 'Laplace' (an ICA model). :param bool ARD_prior: Use automatic relevance determination prior on factors. - :param bool substitution_matrix: Use a learnable substitution matrix (l) + :param bool substitution_matrix: Use a learnable substitution matrix rather than the identity matrix. :param float substitution_prior_scale: Standard deviation of the prior distribution over substitution matrix parameters (when @@ -327,10 +325,10 @@ class FactorMuE(nn.Module): :param int latent_alphabet_length: Length of the alphabet in the latent regressor sequence. 
:param bool length_model: Model the length of the sequence with a Poisson - distribution. - :param bool cuda: Transfer data onto the GPU for training. + distribution, with mean dependent on the latent pPCA model. + :param bool cuda: Transfer data onto the GPU during training. :param bool pin_memory: Pin memory for faster GPU transfer. - :epsilon float epsilon: A small value for numerical stability. + :param float epsilon: A small value for numerical stability. """ def __init__(self, data_length, alphabet_length, z_dim, batch_size=10, @@ -609,16 +607,14 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference method useful for quickly iterating on probabilistic models. - :param dataset: The training dataset, with type - :class:`~torch.utils.data.Dataset`. + :param ~torch.utils.data.Dataset dataset: The training dataset. :param int epochs: Number of epochs of training. :param float anneal_length: Number of epochs over which to linearly anneal the prior KL divergence weight from 0 to 1, for improved - convergence. + training. :param int batch_size: Minibatch size (number of sequences). - :param scheduler: Learning rate scheduler, with type - :class:`~pyro.optim.MultiStepLR`. (Default: Adam optimizer, - 0.01 constant learning rate.) + :param pyro.optim.MultiStepLR scheduler: Optimization scheduler. + (Default: Adam optimizer, 0.01 constant learning rate.) :param bool jit: Whether to use a jit compiled ELBO. """ @@ -677,10 +673,9 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): Evaluate performance (log probability and per residue perplexity) on train and test datasets. - :param dataset: The training dataset, with type - :class:`~torch.utils.data.Dataset`. - :param torch.utils.data.Dataset dataset: The testing dataset, with type - :class:`~torch.utils.data.Dataset` or None. (Default: None.) + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param ~torch.utils.data.Dataset dataset: The testing dataset + (optional). :param bool jit: Whether to use a jit compiled ELBO. """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) @@ -742,8 +737,7 @@ def embed(self, dataset, batch_size=None): """ Get the latent space embedding (mean posterior value of z). - :param dataset: The dataset to embed, with type - :class:`~torch.utils.data.Dataset`. + :param ~torch.utils.data.Dataset dataset: The dataset to embed. :param int batch_size: Minibatch size (number of sequences). (Defaults to batch_size of the model object.) """ diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 51f2aa3989..2c384ec720 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -13,18 +13,19 @@ class Profile(nn.Module): **References** - [1] E. N. Weinstein, D. S. Marks (2020) + [1] E. N. Weinstein, D. S. Marks (2021) "Generative probabilistic biological sequence models that account for mutational variability" https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf + [2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) "Biological sequence analysis: probabilistic models of proteins and nucleic acids" Cambridge university press - :param M: Length of precursor (ancestral) sequence. + :param M: Length of regressor sequence. :type M: int - :param epsilon: Small value for approximate zeros in log space. + :param epsilon: A small value for numerical stability. 
:type epsilon: float """ def __init__(self, M, epsilon=1e-32): @@ -143,7 +144,7 @@ def forward(self, precursor_seq_logits, insert_seq_logits, """ Assemble HMM parameters given profile parameters. - :param ~torch.Tensor precursor_seq_logits: Initial (relaxed) sequence + :param ~torch.Tensor precursor_seq_logits: Regressor sequence *log(x)*. Should have rightmost dimension ``(M, D)`` and be broadcastable to ``(batch_size, M, D)``, where D is the latent alphabet size. Should be normalized to one along the From 8542820f273cd407be8bb883fa287d433b3eb8f9 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 17:30:02 -0400 Subject: [PATCH 79/91] Tutorials. --- examples/contrib/mue/FactorMuE.py | 36 ++++++++++++++++++++++++------ examples/contrib/mue/ProfileHMM.py | 35 +++++++++++++++++++++++++---- tutorial/source/index.rst | 8 +++++++ tutorial/source/mue_factor.rst | 11 +++++++++ tutorial/source/mue_profile.rst | 11 +++++++++ 5 files changed, 90 insertions(+), 11 deletions(-) create mode 100644 tutorial/source/mue_factor.rst create mode 100644 tutorial/source/mue_profile.rst diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 41b48d13d4..6c95a5df77 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -2,7 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 """ -A PCA model with a MuE emission (FactorMuE). +A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model +[1]. This is a generative model of variable-length biological sequences (e.g. +proteins) which does not require preprocessing the data by building a +multiple sequence alignment. It can be used to infer a latent representation +of sequences and the principal components of sequence variation, while +accounting for alignment uncertainty. + +An example dataset consisting of proteins similar to the human papillomavirus E6 +protein, collected from a non-redundant sequence dataset using jackhmmer, can +be found at +https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta + +Reference: +[1] E. N. Weinstein, D. S. Marks (2021) +"Generative probabilistic biological sequence models that account for +mutational variability" +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf """ import argparse @@ -22,7 +38,7 @@ def generate_data(small_test): - """Generate example dataset.""" + """Generate mini example dataset.""" if small_test: mult_dat = 1 else: @@ -46,7 +62,8 @@ def main(args): # Train test split. heldout_num = int(np.ceil(args.split*len(dataset))) data_lengths = [len(dataset) - heldout_num, heldout_num] - # Specific data split seed. + # Specific data split seed, for comparability across models and + # parameter initializations. pyro.set_rng_seed(args.rng_data_seed) indices = torch.randperm(sum(data_lengths)).tolist() dataset_train, dataset_test = [ @@ -80,7 +97,7 @@ def main(args): cuda=args.cuda, pin_memory=args.pin_mem) - # Infer. + # Infer with SVI. scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': args.learning_rate}, 'milestones': json.loads(args.milestones), @@ -95,7 +112,7 @@ def main(args): print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) - # Embed. + # Get latent space embedding. z_locs, z_scales = model.embed(dataset) # Plot and save. 
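The train/test split in the example above is cut off mid-expression by the hunk. A minimal sketch of the same deterministic split; the use of torch.utils.data.Subset for the elided part is an assumption:

    import numpy as np
    import torch
    from torch.utils.data import Subset

    import pyro

    def split_dataset(dataset, split=0.2, rng_data_seed=0):
        """Deterministic split, for comparability across models and runs."""
        heldout_num = int(np.ceil(split * len(dataset)))
        train_num = len(dataset) - heldout_num
        pyro.set_rng_seed(rng_data_seed)
        indices = torch.randperm(len(dataset)).tolist()
        return (Subset(dataset, indices[:train_num]),
                Subset(dataset, indices[train_num:]))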
@@ -112,20 +129,23 @@ def main(args): plt.figure(figsize=(6, 6)) plt.scatter(z_locs[:, 0], z_locs[:, 1]) - plt.xlabel('z_1') - plt.ylabel('z_2') + plt.xlabel(r'$z_1$') + plt.ylabel(r'$z_2$') if args.save: plt.savefig(os.path.join( args.out_folder, 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) if not args.indel_factor: + # Plot indel parameters. See statearrangers.py for details on the + # r and u parameters. plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) plt.plot(insert_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of insert') + plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) if args.save: plt.savefig(os.path.join( args.out_folder, @@ -136,6 +156,7 @@ def main(args): plt.plot(delete_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of delete') + plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) if args.save: plt.savefig(os.path.join( args.out_folder, @@ -171,6 +192,7 @@ def main(args): if __name__ == '__main__': + # Parse command line arguments. parser = argparse.ArgumentParser(description="Factor MuE model.") parser.add_argument("--test", action='store_true', default=False, help='Run with generated example dataset.') diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 408f176bdd..ff4124724a 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -2,7 +2,28 @@ # SPDX-License-Identifier: Apache-2.0 """ -A standard profile HMM model. +A standard profile HMM model [1], which corresponds to a constant (delta +function) distribution with a MuE observation [2]. This is a standard +generative model of variable-length biological sequences (e.g. proteins) which +does not require preprocessing the data by building a multiple sequence +alignment. It can be compared to a more complex MuE model in this package, +the FactorMuE. + +An example dataset consisting of proteins similar to the human papillomavirus E6 +protein, collected from a non-redundant sequence dataset using jackhmmer, can +be found at +https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta + +References: +[1] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) +"Biological sequence analysis: probabilistic models of proteins and nucleic +acids" +Cambridge university press + +[2] E. N. Weinstein, D. S. Marks (2021) +"Generative probabilistic biological sequence models that account for +mutational variability" +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf """ import argparse @@ -22,7 +43,7 @@ def generate_data(small_test): - """Generate example dataset.""" + """Generate mini example dataset.""" if small_test: mult_dat = 1 else: @@ -45,8 +66,11 @@ def main(args): dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: + # Train test split. heldout_num = int(np.ceil(args.split*len(dataset))) data_lengths = [len(dataset) - heldout_num, heldout_num] + # Specific data split seed, for comparability across models and + # parameter initializations. pyro.set_rng_seed(args.rng_data_seed) indices = torch.randperm(sum(data_lengths)).tolist() dataset_train, dataset_test = [ @@ -68,7 +92,7 @@ def main(args): cuda=args.cuda, pin_memory=args.pin_mem) - # Infer. + # Infer with SVI. 
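The plotting code above maps unconstrained variational means onto the simplex by subtracting a log-sum-exp before exponentiating; this is exactly a numerically stable softmax, as a small check confirms:

    import torch

    insert = torch.randn(5, 3, 2)  # stand-in for pyro.param("insert_q_mn")
    expect = torch.exp(insert - insert.logsumexp(-1, True))
    assert torch.allclose(expect, torch.softmax(insert, dim=-1))
    assert torch.allclose(expect.sum(-1), torch.ones(5, 3))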
scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': args.learning_rate}, 'milestones': json.loads(args.milestones), @@ -114,6 +138,7 @@ def main(args): plt.plot(insert_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of insert') + plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) if args.save: plt.savefig(os.path.join( args.out_folder, @@ -124,6 +149,7 @@ def main(args): plt.plot(delete_expect[:, :, 1].cpu().numpy()) plt.xlabel('position') plt.ylabel('probability of delete') + plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) if args.save: plt.savefig(os.path.join( args.out_folder, @@ -150,7 +176,8 @@ def main(args): if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Factor MuE model.") + # Parse command line arguments. + parser = argparse.ArgumentParser(description="Profile HMM model.") parser.add_argument("--test", action='store_true', default=False, help='Run with generated example dataset.') parser.add_argument("--small", action='store_true', default=False, diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst index f24dd6b7e7..47d27424bd 100644 --- a/tutorial/source/index.rst +++ b/tutorial/source/index.rst @@ -171,6 +171,14 @@ List of Tutorials epi_regional sir_hmc +.. toctree:: + :maxdepth: 1 + :caption: Application: Biological sequences + :name: biological-sequences + + mue_profile + mue_factor + .. toctree:: :maxdepth: 1 :caption: Application: Experimental Design diff --git a/tutorial/source/mue_factor.rst b/tutorial/source/mue_factor.rst new file mode 100644 index 0000000000..d4ec19ae4b --- /dev/null +++ b/tutorial/source/mue_factor.rst @@ -0,0 +1,11 @@ +Example: Probabilistic PCA + MuE (FactorMuE) +============================================ + +`View FactorHMM.py on github`__ + +.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/mue/FactorMuE.py + +__ github_ + +.. literalinclude:: ../../examples/contrib/mue/FactorMuE.py + :language: python diff --git a/tutorial/source/mue_profile.rst b/tutorial/source/mue_profile.rst new file mode 100644 index 0000000000..41d1b704f0 --- /dev/null +++ b/tutorial/source/mue_profile.rst @@ -0,0 +1,11 @@ +Example: Constant + MuE (Profile HMM) +===================================== + +`View ProfileHMM.py on github`__ + +.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/mue/ProfileHMM.py + +__ github_ + +.. literalinclude:: ../../examples/contrib/mue/ProfileHMM.py + :language: python From 5199844693dc6260faec1a9c3512ac400c3c3cef Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 18:05:57 -0400 Subject: [PATCH 80/91] Add example run scripts. --- examples/contrib/mue/FactorMuE.py | 7 +++++++ examples/contrib/mue/ProfileHMM.py | 21 +++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 6c95a5df77..0f0c5cf279 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -14,6 +14,13 @@ be found at https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta +Example run: +python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25 + --indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace + --jit True --cuda True +This should take about 8 minutes to run on a GPU. The latent space should show +multiple small clusters, and the perplexity should be around 4.0. + Reference: [1] E. N. Weinstein, D. S. 
Marks (2021) "Generative probabilistic biological sequence models that account for diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index ff4124724a..ef9c18fb93 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -14,6 +14,12 @@ be found at https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta +Example run: +python ProfileHMM.py -f PATH/ve6_full.fasta -b 10 -M 174 --indel-prior-bias 10. + -e 15 -lr 0.01 --jit True --cuda True +This should take about 9 minutes to run on a GPU. The perplexity should be +around 5.8. + References: [1] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) "Biological sequence analysis: probabilistic models of proteins and nucleic @@ -119,19 +125,6 @@ def main(args): args.out_folder, 'ProfileHMM_plot.loss_{}.pdf'.format(time_stamp))) - plt.figure(figsize=(6, 6)) - precursor_seq = pyro.param("precursor_seq_q_mn").detach() - precursor_seq_expect = torch.exp(precursor_seq - - precursor_seq.logsumexp(-1, True)) - plt.plot(precursor_seq_expect[:, 1].cpu().numpy()) - plt.xlabel('position') - plt.ylabel('probability of character 1') - if args.save: - plt.savefig(os.path.join( - args.out_folder, - 'ProfileHMM_plot.precursor_seq_prob_{}.pdf'.format( - time_stamp))) - plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) @@ -161,7 +154,7 @@ def main(args): 'ProfileHMM_results.params_{}.out'.format(time_stamp))) with open(os.path.join( args.out_folder, - 'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)), + 'ProfileHMM_results.evaluation_{}.txt'.format(time_stamp)), 'w') as ow: ow.write('train_lp,test_lp,train_perplex,test_perplex\n') ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, From 008528340727be759cd21235bbc794b4ff8180d6 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 18:17:37 -0400 Subject: [PATCH 81/91] Make format changes. 
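The format patch that follows only regroups imports in the example scripts and models.py: standard library first, then third-party, each group alphabetized (presumably enforced by isort; that tooling detail is an assumption). The resulting example-script header looks like:

    # Standard library, alphabetized.
    import argparse
    import datetime
    import json
    import os

    # Third-party, alphabetized.
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torch.optim import Adam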
--- examples/contrib/mue/FactorMuE.py | 2 +- examples/contrib/mue/ProfileHMM.py | 2 +- pyro/contrib/mue/models.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 0f0c5cf279..a913239f35 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -31,10 +31,10 @@ import argparse import datetime import json -import numpy as np import os import matplotlib.pyplot as plt +import numpy as np import torch from torch.optim import Adam diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index ef9c18fb93..53ecc3ec1d 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -35,10 +35,10 @@ import argparse import datetime import json -import numpy as np import os import matplotlib.pyplot as plt +import numpy as np import torch from torch.optim import Adam diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index fadfd15f05..ef270d4a99 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -15,8 +15,8 @@ from torch.utils.data import DataLoader import pyro -from pyro import poutine import pyro.distributions as dist +from pyro import poutine from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM from pyro.contrib.mue.statearrangers import Profile from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO From b708557ebba70ae5e45b542524c24cf53beb9bca Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 19:02:51 -0400 Subject: [PATCH 82/91] Fix example boolean inputs --- examples/contrib/mue/FactorMuE.py | 44 +++++++++++++++++------------- examples/contrib/mue/ProfileHMM.py | 28 ++++++++++--------- tests/test_examples.py | 8 +++--- 3 files changed, 44 insertions(+), 36 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index a913239f35..cf262dbe08 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -17,7 +17,7 @@ Example run: python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25 --indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace - --jit True --cuda True + --jit --cuda This should take about 8 minutes to run on a GPU. The latent space should show multiple small clusters, and the perplexity should be around 4.0. @@ -97,7 +97,7 @@ def main(args): offset_prior_scale=args.offset_prior_scale, z_prior_distribution=args.z_prior, ARD_prior=args.ARD_prior, - substitution_matrix=args.substitution_matrix, + substitution_matrix=(not args.no_substitution_matrix), substitution_prior_scale=args.substitution_prior_scale, latent_alphabet_length=args.latent_alphabet, length_model=args.length_model, @@ -124,12 +124,12 @@ def main(args): # Plot and save. 
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - if args.plots: + if not args.no_plots: plt.figure(figsize=(6, 6)) plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) @@ -138,7 +138,7 @@ def main(args): plt.scatter(z_locs[:, 0], z_locs[:, 1]) plt.xlabel(r'$z_1$') plt.ylabel(r'$z_2$') - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) @@ -153,7 +153,7 @@ def main(args): plt.xlabel('position') plt.ylabel('probability of insert') plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) @@ -164,11 +164,12 @@ def main(args): plt.xlabel('position') plt.ylabel('probability of delete') plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) - if args.save: + + if not args.no_save: pyro.get_param_store().save(os.path.join( args.out_folder, 'FactorMuE_results.params_{}.out'.format(time_stamp))) @@ -217,17 +218,21 @@ def main(args): help='Batch size.') parser.add_argument("-M", "--latent-seq-length", default=None, type=int, help='Latent sequence length.') - parser.add_argument("-idfac", "--indel-factor", default=False, type=bool, + parser.add_argument("-idfac", "--indel-factor", default=False, + action='store_true', help='Indel parameters depend on latent variable.') parser.add_argument("-zdist", "--z-prior", default='Normal', help='Latent prior distribution (normal or Laplace).') - parser.add_argument("-ard", "--ARD-prior", default=False, type=bool, + parser.add_argument("-ard", "--ARD-prior", default=False, + action='store_true', help='Use automatic relevance detection prior.') - parser.add_argument("-sub", "--substitution-matrix", default=True, type=bool, - help='Use substitution matrix.') + parser.add_argument("--no-substitution-matrix", default=False, + action='store_true', + help='Do not use substitution matrix.') parser.add_argument("-D", "--latent-alphabet", default=None, type=int, help='Latent alphabet length.') - parser.add_argument("-L", "--length-model", default=False, type=bool, + parser.add_argument("-L", "--length-model", default=False, + action='store_true', help='Model sequence length.') parser.add_argument("--indel-prior-scale", default=1., type=float, help=('Indel prior scale parameter ' + @@ -252,18 +257,19 @@ def main(args): help='Number of epochs of training.') parser.add_argument("--anneal", default=0., type=float, help='Number of epochs to anneal beta over.') - parser.add_argument("-p", "--plots", default=True, type=bool, + parser.add_argument("--no-plots", default=False, action='store_true', help='Make plots.') - parser.add_argument("-s", "--save", default=True, type=bool, - help='Save plots and results.') + parser.add_argument("--no-save", default=False, action='store_true', + help='Do not save plots and results.') parser.add_argument("-outf", "--out-folder", default='.', help='Folder to save plots.') parser.add_argument("--split", default=0.2, type=float, help=('Fraction of dataset to holdout for testing')) - parser.add_argument("--jit", default=False, type=bool, + parser.add_argument("--jit", default=False, action='store_true', help='JIT compile the ELBO.') - parser.add_argument("--cuda", 
default=False, type=bool, help='Use GPU.') - parser.add_argument("--pin-mem", default=False, type=bool, + parser.add_argument("--cuda", default=False, action='store_true', + help='Use GPU.') + parser.add_argument("--pin-mem", default=False, action='store_true', help='Use pin_memory for faster GPU transfer.') args = parser.parse_args() diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 53ecc3ec1d..9313672485 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -16,7 +16,7 @@ Example run: python ProfileHMM.py -f PATH/ve6_full.fasta -b 10 -M 174 --indel-prior-bias 10. - -e 15 -lr 0.01 --jit True --cuda True + -e 15 -lr 0.01 --jit --cuda This should take about 9 minutes to run on a GPU. The perplexity should be around 5.8. @@ -115,12 +115,12 @@ def main(args): # Plots. time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - if args.plots: + if not args.no_plots: plt.figure(figsize=(6, 6)) plt.plot(losses) plt.xlabel('step') plt.ylabel('loss') - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'ProfileHMM_plot.loss_{}.pdf'.format(time_stamp))) @@ -132,7 +132,7 @@ def main(args): plt.xlabel('position') plt.ylabel('probability of insert') plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'ProfileHMM_plot.insert_prob_{}.pdf'.format(time_stamp))) @@ -143,12 +143,12 @@ def main(args): plt.xlabel('position') plt.ylabel('probability of delete') plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) - if args.save: + if not args.no_save: plt.savefig(os.path.join( args.out_folder, 'ProfileHMM_plot.delete_prob_{}.pdf'.format(time_stamp))) - if args.save: + if not args.no_save: pyro.get_param_store().save(os.path.join( args.out_folder, 'ProfileHMM_results.params_{}.out'.format(time_stamp))) @@ -185,7 +185,8 @@ def main(args): help='Batch size.') parser.add_argument("-M", "--latent-seq-length", default=None, type=int, help='Latent sequence length.') - parser.add_argument("-L", "--length-model", default=False, type=bool, + parser.add_argument("-L", "--length-model", default=False, + action='store_true', help='Model sequence length.') parser.add_argument("--prior-scale", default=1., type=float, help='Prior scale parameter (all parameters).') @@ -199,18 +200,19 @@ def main(args): help='Gamma parameter for multistage learning rate.') parser.add_argument("-e", "--n-epochs", default=10, type=int, help='Number of epochs of training.') - parser.add_argument("-p", "--plots", default=True, type=bool, + parser.add_argument("--no-plots", default=False, action='store_true', help='Make plots.') - parser.add_argument("-s", "--save", default=True, type=bool, - help='Save plots and results.') + parser.add_argument("--no-save", default=False, action='store_true', + help='Do not save plots and results.') parser.add_argument("-outf", "--out-folder", default='.', help='Folder to save plots.') parser.add_argument("--split", default=0.2, type=float, help=('Fraction of dataset to holdout for testing')) - parser.add_argument("--jit", default=False, type=bool, + parser.add_argument("--jit", default=False, action='store_true', help='JIT compile the ELBO.') - parser.add_argument("--cuda", default=False, type=bool, help='Use GPU.') - parser.add_argument("--pin-mem", default=False, type=bool, + parser.add_argument("--cuda", default=False, action='store_true', + help='Use GPU.') + parser.add_argument("--pin-mem", default=False, action='store_true', help='Use 
pin_memory for faster GPU transfer.') args = parser.parse_args() diff --git a/tests/test_examples.py b/tests/test_examples.py index 320aaa4b28..c31bfe74a2 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -59,10 +59,10 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/mue/FactorMuE.py --test --small -p False -s False', - 'contrib/mue/FactorMuE.py --test --small -ard True -idfac True -sub False -p False -s False', - 'contrib/mue/ProfileHMM.py --test --small -p False -s False', - 'contrib/mue/ProfileHMM.py --test --small -L True -p False -s False', + 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save', + 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', From a4fc2e71af7090ba6a554d29574c777416703613 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 16 Mar 2021 19:11:44 -0400 Subject: [PATCH 83/91] Wording edit. --- docs/source/contrib.mue.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 9c3f8edc94..681b72efb7 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -7,7 +7,7 @@ MuE ``pyro.contrib.mue`` provides modeling tools for working with biological sequence data. In particular it implements MuE distributions, which are used as -a fully probabilistic alternative to multiple sequence alignment-based +a fully generative alternative to multiple sequence alignment-based preprocessing. Reference: From 30898610b3c479888084a03d5a1e44ecd3e260ed Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 09:43:55 -0400 Subject: [PATCH 84/91] Add jit and cuda calls to test_examples. 
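The boolean-flag fix in patch 82 above exists because argparse applies type=bool to the raw command-line string, and bool() of any non-empty string, including the literal 'False', is True; action='store_true' avoids the trap. A minimal demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--broken", default=False, type=bool)
    parser.add_argument("--fixed", default=False, action='store_true')

    args = parser.parse_args(["--broken", "False"])
    print(args.broken)  # True: bool("False") is True, the string is non-empty
    print(args.fixed)   # False: the flag was not passed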
--- tests/test_examples.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_examples.py b/tests/test_examples.py index c31bfe74a2..9cd3e81e9c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -146,6 +146,10 @@ 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', + 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save --cuda', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda', + 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm.py --num-epochs=1 --cuda', 'dmm.py --num-epochs=1 --num-iafs=1 --cuda', @@ -214,6 +218,10 @@ def xfail_jit(*args, **kwargs): 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), + 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save --jit', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --jit', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --jit', + 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save --jit', xfail_jit('dmm.py --num-epochs=1 --jit'), xfail_jit('dmm.py --num-epochs=1 --num-iafs=1 --jit'), 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit', From aa388a3cbbe1b8bf7d53f86d4ee8aaa1fc8dba54 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 13:22:09 -0400 Subject: [PATCH 85/91] Add stop codon handling --- examples/contrib/mue/FactorMuE.py | 11 +++++++---- pyro/contrib/mue/dataloaders.py | 12 ++++++++++-- tests/contrib/mue/test_dataloaders.py | 17 ++++++++++------- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index cf262dbe08..647cb6a122 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -44,7 +44,7 @@ from pyro.optim import MultiStepLR -def generate_data(small_test): +def generate_data(small_test, include_stop): """Generate mini example dataset.""" if small_test: mult_dat = 1 @@ -52,7 +52,7 @@ def generate_data(small_test): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB') + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop) return dataset @@ -61,9 +61,10 @@ def main(args): # Load dataset. if args.test: - dataset = generate_data(args.small) + dataset = generate_data(args.small, args.include_stop) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, + include_stop=args.include_stop) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: # Train test split. 
@@ -234,6 +235,8 @@ def main(args): parser.add_argument("-L", "--length-model", default=False, action='store_true', help='Model sequence length.') + parser.add_argument("--include-stop", default=False, action='store_true', + help='Model sequence length.') parser.add_argument("--indel-prior-scale", default=1., type=float, help=('Indel prior scale parameter ' + '(when indel-factor=False).')) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index c36cff4bd2..a17035ef0d 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -29,13 +29,14 @@ class BiosequenceDataset(Dataset): """ def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None): + max_length=None, include_stop=False): super().__init__() # Get sequences. + self.include_stop = include_stop if source_type == 'list': - seqs = source + seqs = [seq + include_stop*'*' for seq in source] elif source_type == 'fasta': seqs = self._load_fasta(source) @@ -52,6 +53,9 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', alphabet = alphabets[alphabet] else: alphabet = np.array(list(alphabet)) + if self.include_stop: + alphabet = np.array(list(alphabet) + ['*']) + self.alphabet = alphabet self.alphabet_length = len(alphabet) # Build dataset. @@ -66,11 +70,15 @@ def _load_fasta(self, source): for line in fr: if line[0] == '>': if seq != '': + if self.include_stop: + seq += '*' seqs.append(seq) seq = '' else: seq += line.strip('\n') if seq != '': + if self.include_stop: + seq += '*' seqs.append(seq) return seqs diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index ab1c697f03..889bd20d9b 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -9,23 +9,24 @@ @pytest.mark.parametrize('source_type', ['list', 'fasta']) @pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', 'ATC']) -def test_biosequencedataset(source_type, alphabet): +@pytest.mark.parametrize('include_stop', [False, True]) +def test_biosequencedataset(source_type, alphabet, include_stop): # Define dataset. seqs = ['AATC', 'CA', 'T'] # Encode dataset, alternate approach. if alphabet in alphabets: - alphabet_list = list(alphabets[alphabet]) + alphabet_list = list(alphabets[alphabet]) + include_stop*['*'] else: - alphabet_list = list(alphabet) - L_data_check = [len(seq) for seq in seqs] + alphabet_list = list(alphabet) + include_stop*['*'] + L_data_check = [len(seq) + include_stop for seq in seqs] max_length_check = max(L_data_check) data_size_check = len(seqs) seq_data_check = torch.zeros([len(seqs), max_length_check, len(alphabet_list)]) for i in range(len(seqs)): - for j, s in enumerate(seqs[i]): + for j, s in enumerate(seqs[i] + include_stop*'*'): seq_data_check[i, j, list(alphabet_list).index(s)] = 1 # Setup data source. @@ -46,7 +47,8 @@ def test_biosequencedataset(source_type, alphabet): source = seqs # Load dataset. - dataset = BiosequenceDataset(source, source_type, alphabet) + dataset = BiosequenceDataset(source, source_type, alphabet, + include_stop=include_stop) # Check. assert torch.allclose(dataset.L_data, @@ -60,4 +62,5 @@ def test_biosequencedataset(source_type, alphabet): assert torch.allclose(dataset[ind][0], torch.cat([seq_data_check[0, None, :, :], seq_data_check[2, None, :, :]])) - assert torch.allclose(dataset[ind][1], torch.tensor([4., 1.])) + assert torch.allclose(dataset[ind][1], torch.tensor([4. + include_stop, + 1. 
+ include_stop])) From df10c7b7bf9d150f58e2b7affedf02fc7dd6c058 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 15:01:41 -0400 Subject: [PATCH 86/91] Remove old length modeling mechanism in favor of using stop symbols --- examples/contrib/mue/FactorMuE.py | 8 +- examples/contrib/mue/ProfileHMM.py | 17 ++-- pyro/contrib/mue/models.py | 111 +++++++++----------------- tests/contrib/mue/test_dataloaders.py | 3 + tests/contrib/mue/test_models.py | 12 +-- tests/test_examples.py | 12 +-- 6 files changed, 61 insertions(+), 102 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index 647cb6a122..a01b7594e3 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -19,7 +19,7 @@ --indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace --jit --cuda This should take about 8 minutes to run on a GPU. The latent space should show -multiple small clusters, and the perplexity should be around 4.0. +multiple small clusters, and the perplexity should be around 4. Reference: [1] E. N. Weinstein, D. S. Marks (2021) @@ -101,7 +101,6 @@ def main(args): substitution_matrix=(not args.no_substitution_matrix), substitution_prior_scale=args.substitution_prior_scale, latent_alphabet_length=args.latent_alphabet, - length_model=args.length_model, cuda=args.cuda, pin_memory=args.pin_mem) @@ -232,11 +231,8 @@ def main(args): help='Do not use substitution matrix.') parser.add_argument("-D", "--latent-alphabet", default=None, type=int, help='Latent alphabet length.') - parser.add_argument("-L", "--length-model", default=False, - action='store_true', - help='Model sequence length.') parser.add_argument("--include-stop", default=False, action='store_true', - help='Model sequence length.') + help='Include stop codon symbol.') parser.add_argument("--indel-prior-scale", default=1., type=float, help=('Indel prior scale parameter ' + '(when indel-factor=False).')) diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 9313672485..7f33b2b060 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -18,7 +18,7 @@ python ProfileHMM.py -f PATH/ve6_full.fasta -b 10 -M 174 --indel-prior-bias 10. -e 15 -lr 0.01 --jit --cuda This should take about 9 minutes to run on a GPU. The perplexity should be -around 5.8. +around 6. References: [1] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) @@ -48,7 +48,7 @@ from pyro.optim import MultiStepLR -def generate_data(small_test): +def generate_data(small_test, include_stop): """Generate mini example dataset.""" if small_test: mult_dat = 1 @@ -56,7 +56,7 @@ def generate_data(small_test): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB') + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop) return dataset @@ -67,9 +67,10 @@ def main(args): # Load dataset. if args.test: - dataset = generate_data(args.small) + dataset = generate_data(args.small, args.include_stop) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet) + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, + include_stop=args.include_stop) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: # Train test split. 
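The stop-symbol plumbing above leans on Python's bool-as-int arithmetic: multiplying a string by False yields the empty string and by True the string itself, so seq + include_stop*'*' appends the stop symbol only when requested. For example:

    include_stop = True
    print([seq + include_stop * '*' for seq in ['BABBA', 'BAAB']])
    # ['BABBA*', 'BAAB*']

    include_stop = False
    print('BAAB' + include_stop * '*')  # 'BAAB', since False * '*' == ''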
@@ -92,7 +93,6 @@ def main(args): if latent_seq_length is None: latent_seq_length = int(dataset.max_length * 1.1) model = ProfileHMM(latent_seq_length, dataset.alphabet_length, - length_model=args.length_model, prior_scale=args.prior_scale, indel_prior_bias=args.indel_prior_bias, cuda=args.cuda, @@ -185,9 +185,8 @@ def main(args): help='Batch size.') parser.add_argument("-M", "--latent-seq-length", default=None, type=int, help='Latent sequence length.') - parser.add_argument("-L", "--length-model", default=False, - action='store_true', - help='Model sequence length.') + parser.add_argument("--include-stop", default=False, action='store_true', + help='Include stop codon symbol.') parser.add_argument("--prior-scale", default=1., type=float, help='Prior scale parameter (all parameters).') parser.add_argument("--indel-prior-bias", default=10., type=float, diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index ef270d4a99..53c0dd0b54 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -24,7 +24,8 @@ class ProfileHMM(nn.Module): - """Profile HMM. + """ + Profile HMM. This model consists of a constant distribution (a delta function) over the regressor sequence, plus a MuE observation distribution. The priors @@ -35,8 +36,6 @@ class ProfileHMM(nn.Module): Must be greater than or equal to 1. :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for amino acids). - :param bool length_model: Model the length of the sequence with a Poisson - distribution. :param float prior_scale: Standard deviation of the prior distribution. :param float indel_prior_bias: Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower @@ -45,7 +44,7 @@ class ProfileHMM(nn.Module): :param bool pin_memory: Pin memory for faster GPU transfer. """ def __init__(self, latent_seq_length, alphabet_length, - length_model=False, prior_scale=1., indel_prior_bias=10., + prior_scale=1., indel_prior_bias=10., cuda=False, pin_memory=False): super().__init__() assert isinstance(cuda, bool) @@ -62,8 +61,6 @@ def __init__(self, latent_seq_length, alphabet_length, self.insert_seq_shape = (latent_seq_length+1, alphabet_length) self.indel_shape = (latent_seq_length, 3, 2) - assert isinstance(length_model, bool) - self.length_model = length_model assert isinstance(prior_scale, float) self.prior_scale = prior_scale assert isinstance(indel_prior_bias, float) @@ -72,7 +69,7 @@ def __init__(self, latent_seq_length, alphabet_length, # Initialize state arranger. self.statearrange = Profile(latent_seq_length) - def model(self, seq_data, L_data, local_scale): + def model(self, seq_data, local_scale): # Latent sequence. precursor_seq = pyro.sample("precursor_seq", dist.Normal( @@ -101,25 +98,16 @@ def model(self, seq_data, L_data, local_scale): self.statearrange(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits)) - # Length model. - if self.length_model: - length = pyro.sample("length", dist.Normal( - torch.tensor(200.), torch.tensor(1000.))) - L_mean = softplus(length) - - with pyro.plate("batch", L_data.shape[0]): + with pyro.plate("batch", seq_data.shape[0]): with poutine.scale(scale=local_scale): - - if self.length_model: - pyro.sample("obs_L", dist.Poisson(L_mean), - obs=L_data) + # Observations. pyro.sample("obs_seq", MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits), obs=seq_data) - def guide(self, seq_data, L_data, local_scale): + def guide(self, seq_data, local_scale): # Sequence. 
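The guides that resume below map unconstrained variational parameters to positive scale parameters with softplus (assumed here to be torch.nn.functional.softplus, matching the calls in models.py); unlike exp, it grows only linearly for large inputs, which is gentler on optimization:

    import torch
    from torch.nn.functional import softplus

    raw = torch.tensor([-3., 0., 3.])
    print(softplus(raw))  # tensor([0.0486, 0.6931, 3.0486]), strictly positive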
precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)) @@ -150,13 +138,6 @@ def guide(self, seq_data, L_data, local_scale): pyro.sample("delete", dist.Normal( delete_q_mn, softplus(delete_q_sd)).to_event(3)) - # Length. - if self.length_model: - length_q_mn = pyro.param("length_q_mn", torch.zeros(1)) - length_q_sd = pyro.param("length_q_sd", torch.zeros(1)) - pyro.sample("length", dist.Normal( - length_q_mn, softplus(length_q_sd))) - def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, jit=False): """ @@ -182,7 +163,7 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, 'milestones': [], 'gamma': 0.5}) # Initialize guide. - self.guide(None, None, None) + self.guide(None, None) dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=self.pin_memory) # Setup stochastic variational inference. @@ -198,9 +179,9 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, for epoch in range(epochs): for seq_data, L_data in dataload: if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() - loss = svi.step(seq_data, L_data, - torch.tensor(len(dataset)/L_data.shape[0])) + seq_data = seq_data.cuda() + loss = svi.step(seq_data, + torch.tensor(len(dataset)/seq_data.shape[0])) losses.append(loss) scheduler.step() print(epoch, loss, ' ', datetime.datetime.now() - t0) @@ -220,7 +201,7 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) # Initialize guide. - self.guide(None, None, None) + self.guide(None, None) if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: @@ -231,11 +212,10 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( - svi, dataload_train, len(dataset_train), self.length_model) + svi, dataload_train, len(dataset_train)) if dataset_test is not None: test_lp, test_perplex = self._evaluate_local_elbo( - svi, dataload_test, len(dataset_test), - self.length_model) + svi, dataload_test, len(dataset_test)) return train_lp, test_lp, train_perplex, test_perplex else: return train_lp, None, train_perplex, None @@ -244,7 +224,7 @@ def _local_variables(self, name, site): """Return per datapoint random variables in model.""" return name in ['obs_L', 'obs_seq'] - def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): + def _evaluate_local_elbo(self, svi, dataload, data_size): """Evaluate elbo and average per residue perplexity.""" lp, perplex = 0., 0. 
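The local_scale of len(dataset)/seq_data.shape[0] passed to svi.step above is the standard minibatch correction: upweighting each datapoint's log-likelihood by N/batch_size inside poutine.scale makes the stochastic ELBO an unbiased estimate of the full-data objective. Schematically, with a toy likelihood standing in for the MuE observation:

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def model(data, local_scale):
        p = pyro.sample("p", dist.Beta(1., 1.))
        with pyro.plate("batch", data.shape[0]):
            with poutine.scale(scale=local_scale):
                pyro.sample("obs", dist.Bernoulli(p), obs=data)

    N, batch = 100, 10
    minibatch = torch.ones(batch)
    model(minibatch, torch.tensor(N / batch))  # minibatch stands in for all N points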
with torch.no_grad(): @@ -252,8 +232,8 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): if self.cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() conditioned_model = poutine.condition(self.model, data={ - "obs_L": L_data, "obs_seq": seq_data}) - args = (seq_data, L_data, torch.tensor(1.)) + "obs_seq": seq_data}) + args = (seq_data, torch.tensor(1.)) guide_tr = poutine.trace(self.guide).get_trace(*args) model_tr = poutine.trace(poutine.replay( conditioned_model, trace=guide_tr)).get_trace(*args) @@ -261,8 +241,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): - guide_tr.log_prob_sum(self._local_variables) ).cpu().numpy() lp += local_elbo - perplex += -local_elbo / (L_data[0].cpu().numpy() + - int(self.length_model)) + perplex += -local_elbo / L_data[0].cpu().numpy() perplex = np.exp(perplex / data_size) return lp, perplex @@ -285,7 +264,8 @@ def forward(self, data): class FactorMuE(nn.Module): - """FactorMuE + """ + FactorMuE This model consists of probabilistic PCA plus a MuE output distribution. @@ -324,8 +304,6 @@ class FactorMuE(nn.Module): substitution_matrix is True). :param int latent_alphabet_length: Length of the alphabet in the latent regressor sequence. - :param bool length_model: Model the length of the sequence with a Poisson - distribution, with mean dependent on the latent pPCA model. :param bool cuda: Transfer data onto the GPU during training. :param bool pin_memory: Pin memory for faster GPU transfer. :param float epsilon: A small value for numerical stability. @@ -344,7 +322,6 @@ def __init__(self, data_length, alphabet_length, z_dim, substitution_matrix=True, substitution_prior_scale=10., latent_alphabet_length=None, - length_model=False, cuda=False, pin_memory=False, epsilon=1e-32): @@ -374,14 +351,12 @@ def __init__(self, data_length, alphabet_length, z_dim, self.indel_shape = (latent_seq_length, 3, 2) self.total_factor_size = ( (2*latent_seq_length+1)*latent_alphabet_length + - 2*indel_factor_dependence*latent_seq_length*3*2 + - length_model) + 2*indel_factor_dependence*latent_seq_length*3*2) # Architecture. self.indel_factor_dependence = indel_factor_dependence self.ARD_prior = ARD_prior self.substitution_matrix = substitution_matrix - self.length_model = length_model # Priors. assert isinstance(indel_prior_scale, float) @@ -414,10 +389,6 @@ def decoder(self, z, W, B, inverse_temp): v = torch.mm(z, W) + B out = dict() - if self.length_model: - # Extract expected length. - L_v = v[:, -1] - out['L_mean'] = softplus(L_v) if self.indel_factor_dependence: # Extract insertion and deletion parameters. ind0 = (2*self.latent_seq_length+1)*self.latent_alphabet_length @@ -445,7 +416,7 @@ def decoder(self, z, W, B, inverse_temp): return out - def model(self, seq_data, L_data, local_scale, local_prior_scale): + def model(self, seq_data, local_scale, local_prior_scale): # ARD prior. if self.ARD_prior: @@ -492,7 +463,7 @@ def model(self, seq_data, L_data, local_scale, local_prior_scale): self.latent_alphabet_length, self.alphabet_length]) ).to_event(2)) - with pyro.plate("batch", L_data.shape[0]): + with pyro.plate("batch", seq_data.shape[0]): with poutine.scale(scale=local_scale): with poutine.scale(scale=local_prior_scale): # Sample latent variable from prior. @@ -524,16 +495,13 @@ def model(self, seq_data, L_data, local_scale, local_prior_scale): decoded['insert_seq_logits'], insert_logits, delete_logits)) # Draw samples. 
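The evaluation code above reports a per-residue perplexity: each sequence's ELBO, a lower bound on its log probability, is divided by the sequence length, averaged over the dataset, negated, and exponentiated. As a standalone helper capturing the same arithmetic:

    import numpy as np

    def per_residue_perplexity(local_elbos, lengths):
        """exp(-mean_i(elbo_i / L_i)); lower is better."""
        local_elbos, lengths = np.asarray(local_elbos), np.asarray(lengths)
        return np.exp(np.mean(-local_elbos / lengths))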
- if self.length_model: - pyro.sample("obs_L", dist.Poisson(decoded['L_mean']), - obs=L_data) pyro.sample("obs_seq", MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits), obs=seq_data) - def guide(self, seq_data, L_data, local_scale, local_prior_scale): + def guide(self, seq_data, local_scale, local_prior_scale): # Register encoder with pyro. pyro.module("encoder", self.encoder) @@ -586,7 +554,7 @@ def guide(self, seq_data, L_data, local_scale, local_prior_scale): substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) # Per datapoint local latent variables. - with pyro.plate("batch", L_data.shape[0]): + with pyro.plate("batch", seq_data.shape[0]): # Encode sequences. z_loc, z_scale = self.encoder(seq_data) # Scale log likelihood to account for mini-batching. @@ -631,8 +599,8 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, # Initialize guide. for seq_data, L_data in dataload: if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() - self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + seq_data = seq_data.cuda() + self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break # Setup stochastic variational inference. if jit: @@ -648,10 +616,9 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, for epoch in range(epochs): for seq_data, L_data in dataload: if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() + seq_data = seq_data.cuda() loss = svi.step( - seq_data, L_data, - torch.tensor(len(dataset)/L_data.shape[0]), + seq_data, torch.tensor(len(dataset)/seq_data.shape[0]), self._beta_anneal(step_i, batch_size, len(dataset), anneal_length)) losses.append(loss) @@ -685,8 +652,8 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): # Initialize guide. for seq_data, L_data in dataload_train: if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() - self.guide(seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + seq_data = seq_data.cuda() + self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break if jit: Elbo = JitTrace_ELBO(ignore_jit_warnings=True) @@ -698,11 +665,10 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( - svi, dataload_train, len(dataset_train), self.length_model) + svi, dataload_train, len(dataset_train)) if dataset_test is not None: test_lp, test_perplex = self._evaluate_local_elbo( - svi, dataload_test, len(dataset_test), - self.length_model) + svi, dataload_test, len(dataset_test)) return train_lp, test_lp, train_perplex, test_perplex else: return train_lp, None, train_perplex, None @@ -711,7 +677,7 @@ def _local_variables(self, name, site): """Return per datapoint random variables in model.""" return name in ['latent', 'obs_L', 'obs_seq'] - def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): + def _evaluate_local_elbo(self, svi, dataload, data_size): """Evaluate elbo and average per residue perplexity.""" lp, perplex = 0., 0. 
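+        # Per-datapoint ELBO via trace/replay: trace the guide, replay the
+        # conditioned model against that trace, and take the difference of
+        # log_prob_sum over the local sites only ('latent', 'obs_L',
+        # 'obs_seq').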
with torch.no_grad(): @@ -719,8 +685,8 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): if self.cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() conditioned_model = poutine.condition(self.model, data={ - "obs_L": L_data, "obs_seq": seq_data}) - args = (seq_data, L_data, torch.tensor(1.), torch.tensor(1.)) + "obs_seq": seq_data}) + args = (seq_data, torch.tensor(1.), torch.tensor(1.)) guide_tr = poutine.trace(self.guide).get_trace(*args) model_tr = poutine.trace(poutine.replay( conditioned_model, trace=guide_tr)).get_trace(*args) @@ -728,8 +694,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size, length_model): - guide_tr.log_prob_sum(self._local_variables) ).cpu().numpy() lp += local_elbo - perplex += -local_elbo / (L_data[0].cpu().numpy() + - int(self.length_model)) + perplex += -local_elbo / L_data[0].cpu().numpy() perplex = np.exp(perplex / data_size) return lp, perplex @@ -748,7 +713,7 @@ def embed(self, dataset, batch_size=None): z_locs, z_scales = [], [] for seq_data, L_data in dataload: if self.cuda: - seq_data, L_data = seq_data.cuda(), L_data.cuda() + seq_data = seq_data.cuda() z_loc, z_scale = self.encoder(seq_data) z_locs.append(z_loc.cpu()) z_scales.append(z_scale.cpu()) diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 889bd20d9b..94ba2fa02a 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -64,3 +64,6 @@ def test_biosequencedataset(source_type, alphabet, include_stop): seq_data_check[2, None, :, :]])) assert torch.allclose(dataset[ind][1], torch.tensor([4. + include_stop, 1. + include_stop])) + dataload = torch.utils.data.DataLoader(dataset, batch_size=2) + for seq_data, L_data in dataload: + assert seq_data.shape[0] == L_data.shape[0] diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 127af5043a..5f2ba9634b 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -12,9 +12,8 @@ from pyro.optim import MultiStepLR -@pytest.mark.parametrize('length_model', [False, True]) @pytest.mark.parametrize('jit', [False, True]) -def test_ProfileHMM_smoke(length_model, jit): +def test_ProfileHMM_smoke(jit): # Setup dataset. seqs = ['BABBA', 'BAAB', 'BABBB'] alph = 'AB' @@ -25,8 +24,7 @@ def test_ProfileHMM_smoke(length_model, jit): 'optim_args': {'lr': 0.1}, 'milestones': [20, 100, 1000, 2000], 'gamma': 0.5}) - model = ProfileHMM(int(dataset.max_length*1.1), dataset.alphabet_length, - length_model) + model = ProfileHMM(int(dataset.max_length*1.1), dataset.alphabet_length) n_epochs = 5 batch_size = 2 losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) @@ -46,10 +44,9 @@ def test_ProfileHMM_smoke(length_model, jit): @pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) @pytest.mark.parametrize('ARD_prior', [False, True]) @pytest.mark.parametrize('substitution_matrix', [False, True]) -@pytest.mark.parametrize('length_model', [False, True]) @pytest.mark.parametrize('jit', [False, True]) def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, - ARD_prior, substitution_matrix, length_model, jit): + ARD_prior, substitution_matrix, jit): # Setup dataset. 
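+    # Toy sequences of different lengths over the two-letter alphabet 'AB',
+    # enough to exercise padding and variable-length handling cheaply.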
seqs = ['BABBA', 'BAAB', 'BABBB'] alph = 'AB' @@ -65,8 +62,7 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, indel_factor_dependence=indel_factor_dependence, z_prior_distribution=z_prior_distribution, ARD_prior=ARD_prior, - substitution_matrix=substitution_matrix, - length_model=length_model) + substitution_matrix=substitution_matrix) n_epochs = 5 anneal_length = 2 batch_size = 2 diff --git a/tests/test_examples.py b/tests/test_examples.py index 9cd3e81e9c..a86bd9a3b3 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -59,10 +59,10 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save', 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save', 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save', - 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', @@ -146,10 +146,10 @@ 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', - 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save --cuda', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda', 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda', 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda', - 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save --cuda', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm.py --num-epochs=1 --cuda', 'dmm.py --num-epochs=1 --num-iafs=1 --cuda', @@ -218,10 +218,10 @@ def xfail_jit(*args, **kwargs): 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), - 'contrib/mue/FactorMuE.py --test --small --no-plots --no-save --jit', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --jit', 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --jit', 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --jit', - 'contrib/mue/ProfileHMM.py --test --small -L --no-plots --no-save --jit', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --jit', xfail_jit('dmm.py --num-epochs=1 --jit'), xfail_jit('dmm.py --num-epochs=1 --num-iafs=1 --jit'), 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit', From 767d9e0bbb80decf5a13cd1547c418209ff2959a Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 15:59:58 -0400 Subject: [PATCH 87/91] Option to keep data on cpu --- examples/contrib/mue/FactorMuE.py | 20 
++++++++++++++------ examples/contrib/mue/ProfileHMM.py | 2 +- pyro/contrib/mue/dataloaders.py | 19 +++++++++++++++---- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index a01b7594e3..a122be7266 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -44,7 +44,7 @@ from pyro.optim import MultiStepLR -def generate_data(small_test, include_stop): +def generate_data(small_test, include_stop, device): """Generate mini example dataset.""" if small_test: mult_dat = 1 @@ -52,7 +52,8 @@ def generate_data(small_test, include_stop): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop) + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, + device=device) return dataset @@ -60,11 +61,16 @@ def generate_data(small_test, include_stop): def main(args): # Load dataset. + if args.cpu_data and args.cuda: + device = torch.device('cpu') + else: + device = None if args.test: - dataset = generate_data(args.small, args.include_stop) + dataset = generate_data(args.small, args.include_stop, device) else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, - include_stop=args.include_stop) + include_stop=args.include_stop, + device=device) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: # Train test split. @@ -232,7 +238,7 @@ def main(args): parser.add_argument("-D", "--latent-alphabet", default=None, type=int, help='Latent alphabet length.') parser.add_argument("--include-stop", default=False, action='store_true', - help='Include stop codon symbol.') + help='Include stop symbol at the end of each sequence.') parser.add_argument("--indel-prior-scale", default=1., type=float, help=('Indel prior scale parameter ' + '(when indel-factor=False).')) @@ -268,8 +274,10 @@ def main(args): help='JIT compile the ELBO.') parser.add_argument("--cuda", default=False, action='store_true', help='Use GPU.') + parser.add_argument("--cpu-data", default=False, action='store_true', + help='Keep data on CPU (for large datasets).') parser.add_argument("--pin-mem", default=False, action='store_true', - help='Use pin_memory for faster GPU transfer.') + help='Use pin_memory for faster CPU to GPU transfer.') args = parser.parse_args() if args.cuda: diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 7f33b2b060..0a202b230a 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -186,7 +186,7 @@ def main(args): parser.add_argument("-M", "--latent-seq-length", default=None, type=int, help='Latent sequence length.') parser.add_argument("--include-stop", default=False, action='store_true', - help='Include stop codon symbol.') + help='Include stop symbol at the end of each sequence.') parser.add_argument("--prior-scale", default=1., type=float, help='Prior scale parameter (all parameters).') parser.add_argument("--indel-prior-bias", default=10., type=float, diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index a17035ef0d..3aa70859ab 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -26,13 +26,22 @@ class BiosequenceDataset(Dataset): :param int max_length: Total length of the one-hot representation of the sequences, including zero padding. Defaults to the maximum sequence length in the dataset. 
+ :param bool include_stop: Append stop symbol to the end of each sequence + and add the stop symbol to the alphabet. """ def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None, include_stop=False): + max_length=None, include_stop=False, device=None): super().__init__() + # Determine device + if device is None: + device = torch.tensor(0.).device + elif type(device) == str: + device = torch.device(device) + self.device = device + # Get sequences. self.include_stop = include_stop if source_type == 'list': @@ -41,7 +50,8 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', seqs = self._load_fasta(source) # Get lengths. - self.L_data = torch.tensor([float(len(seq)) for seq in seqs]) + self.L_data = torch.tensor([float(len(seq)) for seq in seqs], + device=device) if max_length is None: self.max_length = int(torch.max(self.L_data)) else: @@ -86,9 +96,10 @@ def _one_hot(self, seq, alphabet, length): """One hot encode and pad with zeros to max length.""" # One hot encode. oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] - ).astype(np.float64)) + ).astype(np.float64), device=self.device) # Pad. - x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)])]) + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)], + device=self.device)]) return x From 7617acbff675692df49c131da6cd2c03417024e3 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 16:09:45 -0400 Subject: [PATCH 88/91] check device of seq data --- pyro/contrib/mue/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 53c0dd0b54..5058bd3f81 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -599,6 +599,7 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, # Initialize guide. for seq_data, L_data in dataload: if self.cuda: + print(seq_data.device) seq_data = seq_data.cuda() self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break From 839adf58f6177b8abd24eba28a690efe35bf0af0 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 16:19:34 -0400 Subject: [PATCH 89/91] CPU storage option for profile HMM --- examples/contrib/mue/ProfileHMM.py | 16 ++++++++++++---- pyro/contrib/mue/dataloaders.py | 4 ++-- pyro/contrib/mue/models.py | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 0a202b230a..61df67f039 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -48,7 +48,7 @@ from pyro.optim import MultiStepLR -def generate_data(small_test, include_stop): +def generate_data(small_test, include_stop, device): """Generate mini example dataset.""" if small_test: mult_dat = 1 @@ -56,7 +56,8 @@ def generate_data(small_test, include_stop): mult_dat = 10 seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop) + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, + device=device) return dataset @@ -66,11 +67,16 @@ def main(args): pyro.set_rng_seed(args.rng_seed) # Load dataset. 
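+    # With --cpu-data and --cuda, the dataset stays in CPU memory and
+    # minibatches are moved to the GPU inside the training loop; otherwise
+    # the dataset is created on the current default device.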
+ if args.cpu_data and args.cuda: + device = torch.device('cpu') + else: + device = None if args.test: - dataset = generate_data(args.small, args.include_stop) + dataset = generate_data(args.small, args.include_stop, device) else: dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, - include_stop=args.include_stop) + include_stop=args.include_stop, + device=device) args.batch_size = min([dataset.data_size, args.batch_size]) if args.split > 0.: # Train test split. @@ -211,6 +217,8 @@ def main(args): help='JIT compile the ELBO.') parser.add_argument("--cuda", default=False, action='store_true', help='Use GPU.') + parser.add_argument("--cpu-data", default=False, action='store_true', + help='Keep data on CPU (for large datasets).') parser.add_argument("--pin-mem", default=False, action='store_true', help='Use pin_memory for faster GPU transfer.') args = parser.parse_args() diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 3aa70859ab..198044265b 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -28,6 +28,8 @@ class BiosequenceDataset(Dataset): length in the dataset. :param bool include_stop: Append stop symbol to the end of each sequence and add the stop symbol to the alphabet. + :param ~torch.device device: Device on which data should be stored in + memory. """ def __init__(self, source, source_type='list', alphabet='amino-acid', @@ -38,8 +40,6 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', # Determine device if device is None: device = torch.tensor(0.).device - elif type(device) == str: - device = torch.device(device) self.device = device # Get sequences. diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 5058bd3f81..b7879bd8d2 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -179,6 +179,7 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, for epoch in range(epochs): for seq_data, L_data in dataload: if self.cuda: + print(seq_data.device) seq_data = seq_data.cuda() loss = svi.step(seq_data, torch.tensor(len(dataset)/seq_data.shape[0])) @@ -599,7 +600,6 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, # Initialize guide. for seq_data, L_data in dataload: if self.cuda: - print(seq_data.device) seq_data = seq_data.cuda() self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break From 4dcf3854017a09b410071d70232b0af6b9019235 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Wed, 17 Mar 2021 16:23:57 -0400 Subject: [PATCH 90/91] Update example tests --- pyro/contrib/mue/dataloaders.py | 2 +- pyro/contrib/mue/models.py | 1 - tests/test_examples.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index 198044265b..b6fbbb4489 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -28,7 +28,7 @@ class BiosequenceDataset(Dataset): length in the dataset. :param bool include_stop: Append stop symbol to the end of each sequence and add the stop symbol to the alphabet. - :param ~torch.device device: Device on which data should be stored in + :param torch.device device: Device on which data should be stored in memory. 
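+
+    Example (a minimal sketch, assuming a small in-memory dataset)::
+
+        dataset = BiosequenceDataset(['BABBA', 'BAAB'], 'list', 'AB',
+                                     device=torch.device('cpu'))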
""" diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index b7879bd8d2..53c0dd0b54 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -179,7 +179,6 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, for epoch in range(epochs): for seq_data, L_data in dataload: if self.cuda: - print(seq_data.device) seq_data = seq_data.cuda() loss = svi.step(seq_data, torch.tensor(len(dataset)/seq_data.shape[0])) diff --git a/tests/test_examples.py b/tests/test_examples.py index a86bd9a3b3..9298a78f11 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -146,9 +146,9 @@ 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', - 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda --cpu-data --pin-mem', 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda', - 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda --cpu-data --pin-mem', 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm.py --num-epochs=1 --cuda', From 2e192a8804f1c849074ec579559efc110300d312 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Sun, 21 Mar 2021 20:11:39 -0400 Subject: [PATCH 91/91] Addressed fritzo comments --- docs/source/contrib.mue.rst | 4 +-- pyro/contrib/mue/missingdatahmm.py | 18 ++-------- pyro/contrib/mue/models.py | 42 ++++++++++++------------ tests/contrib/mue/test_statearrangers.py | 1 - 4 files changed, 25 insertions(+), 40 deletions(-) diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst index 681b72efb7..116279c3c6 100644 --- a/docs/source/contrib.mue.rst +++ b/docs/source/contrib.mue.rst @@ -1,5 +1,5 @@ -MuE -=== +Biological Sequence Models with MuE +=================================== .. automodule:: pyro.contrib.mue .. warning:: Code in ``pyro.contrib.mue`` is under development. diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index e4f49aa879..eb414bf82c 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -83,22 +83,8 @@ def log_prob(self, value): Variable length observation sequences can be handled by padding the sequence with zeros at the end. """ - # observation_logits: - # batch_shape (option) x state_dim x observation_dim - # value: - # batch_shape (option) x num_steps x observation_dim - # value_logits - # batch_shape (option) x num_steps x state_dim (new) - # transition_logits: - # batch_shape (option) x state_dim (old) x state_dim (new) - # result 1 - # batch_shape (option) x num_steps-1 x state_dim (old) x state_dim (new) - # result 2 - # batch_shape (option) x state_dim (old) x state_dim (new) - # initial_logits - # batch_shape (option) x state_dim - # result 3 - # batch_shape (option) + + assert value.shape[-1] == self.event_shape[1] # Combine observation and transition factors. 
value_logits = torch.matmul( diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index 53c0dd0b54..fb55a2fa9f 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -48,7 +48,7 @@ def __init__(self, latent_seq_length, alphabet_length, cuda=False, pin_memory=False): super().__init__() assert isinstance(cuda, bool) - self.cuda = cuda + self.is_cuda = cuda assert isinstance(pin_memory, bool) self.pin_memory = pin_memory @@ -168,17 +168,17 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, pin_memory=self.pin_memory) # Setup stochastic variational inference. if jit: - Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: - Elbo = Trace_ELBO() - svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Run inference. losses = [] t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data = seq_data.cuda() loss = svi.step(seq_data, torch.tensor(len(dataset)/seq_data.shape[0])) @@ -203,12 +203,12 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): # Initialize guide. self.guide(None, None) if jit: - Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: - Elbo = Trace_ELBO() + elbo = Trace_ELBO() scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) # Setup stochastic variational inference. - svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( @@ -229,7 +229,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size): lp, perplex = 0., 0. with torch.no_grad(): for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() conditioned_model = poutine.condition(self.model, data={ "obs_seq": seq_data}) @@ -327,7 +327,7 @@ def __init__(self, data_length, alphabet_length, z_dim, epsilon=1e-32): super().__init__() assert isinstance(cuda, bool) - self.cuda = cuda + self.is_cuda = cuda assert isinstance(pin_memory, bool) self.pin_memory = pin_memory @@ -598,16 +598,16 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, pin_memory=self.pin_memory) # Initialize guide. for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data = seq_data.cuda() self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break # Setup stochastic variational inference. if jit: - Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: - Elbo = Trace_ELBO() - svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Run inference. losses = [] @@ -615,7 +615,7 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, t0 = datetime.datetime.now() for epoch in range(epochs): for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data = seq_data.cuda() loss = svi.step( seq_data, torch.tensor(len(dataset)/seq_data.shape[0]), @@ -651,17 +651,17 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): shuffle=False) # Initialize guide. 
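+        # (self.is_cuda rather than self.cuda: nn.Module already defines a
+        # .cuda() method, which a bool attribute of the same name would
+        # shadow.)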
for seq_data, L_data in dataload_train: - if self.cuda: + if self.is_cuda: seq_data = seq_data.cuda() self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) break if jit: - Elbo = JitTrace_ELBO(ignore_jit_warnings=True) + elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: - Elbo = Trace_ELBO() + elbo = Trace_ELBO() scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) # Setup stochastic variational inference. - svi = SVI(self.model, self.guide, scheduler, loss=Elbo) + svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( @@ -682,7 +682,7 @@ def _evaluate_local_elbo(self, svi, dataload, data_size): lp, perplex = 0., 0. with torch.no_grad(): for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() conditioned_model = poutine.condition(self.model, data={ "obs_seq": seq_data}) @@ -712,7 +712,7 @@ def embed(self, dataset, batch_size=None): with torch.no_grad(): z_locs, z_scales = [], [] for seq_data, L_data in dataload: - if self.cuda: + if self.is_cuda: seq_data = seq_data.cuda() z_loc, z_scale = self.encoder(seq_data) z_locs.append(z_loc.cpu()) diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 7215edc9c8..1fd4f672a9 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -223,7 +223,6 @@ def test_profile_trivial_cases(M): a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) # --- Compute expected value per step. --- - # TODO: replace with VariableLengthDiscreteHMM function once implemented. Eyln = torch.zeros([batch_size, M, B]) ai = a0ln for j in range(M):
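+        # (Sketch: ai carries the forward log-probabilities, advanced by the
+        # transition logits aln at each step, while Eyln accumulates the
+        # expected emission at step j; this hand-rolled recursion is what the
+        # removed TODO hoped to replace with a library function.)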