From 84a84f054baa76df9d8445a82ec0fd0eacbf29af Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 28 Oct 2020 09:53:38 -0400 Subject: [PATCH 1/7] Fixes for PyTorch 1.7 release --- pyro/ops/fft.py | 11 +++++++++++ pyro/ops/linalg.py | 2 +- pyro/ops/stats.py | 5 +++-- pyro/ops/tensor_utils.py | 9 +++++---- setup.cfg | 1 + 5 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 pyro/ops/fft.py diff --git a/pyro/ops/fft.py b/pyro/ops/fft.py new file mode 100644 index 0000000000..93f0cc98c8 --- /dev/null +++ b/pyro/ops/fft.py @@ -0,0 +1,11 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + +try: + # This works in PyTorch 1.7+ + import torch.fft as torch_fft +except ModuleNotFoundError: + # This works in PyTorch 1.6 + torch_fft = torch diff --git a/pyro/ops/linalg.py b/pyro/ops/linalg.py index 014fb320d0..0b2c29a43c 100644 --- a/pyro/ops/linalg.py +++ b/pyro/ops/linalg.py @@ -50,7 +50,7 @@ def eig_3d(H): p = torch.sqrt(p2 / 6) B = (1 / p).unsqueeze(-1).unsqueeze(-1) * (H - q.unsqueeze(-1).unsqueeze(-1) * torch.eye(3)) r = determinant_3d(B) / 2 - phi = (r.acos() / 3).unsqueeze(-1).unsqueeze(-1).expand(r.shape + (3, 3)) + phi = (r.acos() / 3).unsqueeze(-1).unsqueeze(-1).expand(r.shape + (3, 3)).clone() phi[r < -1 + 1e-6] = math.pi / 3 phi[r > 1 - 1e-6] = 0. diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 78f8b9d2ff..7a007117f6 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -7,6 +7,7 @@ import torch from .tensor_utils import next_fast_len +from .fft import irfft, rfft def _compute_chain_variance_stats(input): @@ -109,13 +110,13 @@ def autocorrelation(input, dim=0): centered_signal = torch.cat([centered_signal, pad], dim=-1) # Fourier transform - freqvec = torch.rfft(centered_signal, signal_ndim=1, onesided=False) + freqvec = rfft(centered_signal, signal_ndim=1, onesided=False) # take square of magnitude of freqvec (or freqvec x freqvec*) freqvec_gram = freqvec.pow(2).sum(-1, keepdim=True) freqvec_gram = torch.cat([freqvec_gram, torch.zeros(freqvec_gram.shape, dtype=input.dtype, device=input.device)], dim=-1) # inverse Fourier transform - autocorr = torch.irfft(freqvec_gram, signal_ndim=1, onesided=False) + autocorr = irfft(freqvec_gram, signal_ndim=1, onesided=False) # truncate and normalize the result, then transpose back to original shape autocorr = autocorr[..., :N] diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 2167da7287..cb559e9010 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -5,6 +5,7 @@ import torch +from .fft import irfft, rfft _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0) @@ -220,10 +221,10 @@ def convolve(signal, kernel, mode='full'): padded_size = m + n - 1 # Round up for cheaper fft. fast_ftt_size = next_fast_len(padded_size) - f_signal = torch.rfft(torch.nn.functional.pad(signal, (0, fast_ftt_size - m)), 1, onesided=False) - f_kernel = torch.rfft(torch.nn.functional.pad(kernel, (0, fast_ftt_size - n)), 1, onesided=False) + f_signal = rfft(torch.nn.functional.pad(signal, (0, fast_ftt_size - m)), 1, onesided=False) + f_kernel = rfft(torch.nn.functional.pad(kernel, (0, fast_ftt_size - n)), 1, onesided=False) f_result = _complex_mul(f_signal, f_kernel) - result = torch.irfft(f_result, 1, onesided=False) + result = irfft(f_result, 1, onesided=False) start_idx = (padded_size - truncate) // 2 return result[..., start_idx: start_idx + truncate] @@ -328,7 +329,7 @@ def idct(x, dim=-1): half_size = N // 2 + 1 Y = _complex_mul(coef[..., :half_size, :], X[..., :half_size, :]) # Step 2 - y = torch.irfft(Y, 1, onesided=True, signal_sizes=(N,)) + y = irfft(Y, 1, onesided=True, signal_sizes=(N,)) # Step 3 return torch.stack([y, y.flip(-1)], axis=-1).reshape(x.shape[:-1] + (-1,))[..., :N] diff --git a/setup.cfg b/setup.cfg index 94d043a69c..65cc46305e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ filterwarnings = error ignore:numpy.ufunc size changed:RuntimeWarning ignore:numpy.dtype size changed:RuntimeWarning ignore:Mixed memory format inputs detected:UserWarning + ignore:Setting attributes on ParameterDict:UserWarning ignore::DeprecationWarning once::DeprecationWarning From b4b86a4981730ba949c6017cb477668d396e8674 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 5 Nov 2020 17:40:01 -0500 Subject: [PATCH 2/7] Update spanning_tree.cpp --- pyro/distributions/spanning_tree.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/spanning_tree.cpp b/pyro/distributions/spanning_tree.cpp index 2fd03e14d1..f2a9d22bea 100644 --- a/pyro/distributions/spanning_tree.cpp +++ b/pyro/distributions/spanning_tree.cpp @@ -142,7 +142,7 @@ at::Tensor sample_tree_approx(at::Tensor edge_logits) { // the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2. auto edge_ids = torch::empty({E}, at::kLong); // This maps each vertex to whether it is a member of the cumulative tree. - auto components = torch::zeros({V}, at::kByte); + auto components = torch::zeros({V}, at::kBool); // Sample the first edge at random. auto probs = (edge_logits - edge_logits.max()).exp(); From 5a8fd9479db99bd993af2f86ce5fd7bad9efb185 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 14 Nov 2020 20:56:15 -0600 Subject: [PATCH 3/7] make fft compatible with pytorch 1.6 --- pyro/ops/fft.py | 20 ++++++++++++++++-- pyro/ops/stats.py | 12 ++++------- pyro/ops/tensor_utils.py | 45 +++++++++++++++++----------------------- 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/pyro/ops/fft.py b/pyro/ops/fft.py index 93f0cc98c8..abf7b2601a 100644 --- a/pyro/ops/fft.py +++ b/pyro/ops/fft.py @@ -5,7 +5,23 @@ try: # This works in PyTorch 1.7+ - import torch.fft as torch_fft + from torch.fft import irfft, rfft except ModuleNotFoundError: # This works in PyTorch 1.6 - torch_fft = torch + def rfft(input, n=None): + if n is not None: + m = input.size(-1) + if n > m: + input = torch.nn.functional.pad(input, (0, n - m)) + elif n < m: + input = input[..., :n] + return torch.view_as_complex(torch.rfft(input, 1)) + + def irfft(input, n=None): + if torch.is_complex(input): + input = torch.view_as_real(input) + else: + input = torch.nn.functional.pad(input[..., None], (0, 1)) + if n is None: + n = 2 * (input.size(-1) - 1) + return torch.irfft(input, 1, signal_sizes=(n,)) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 7a007117f6..f05886e22e 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -7,7 +7,7 @@ import torch from .tensor_utils import next_fast_len -from .fft import irfft, rfft +from .fft import rfft, irfft def _compute_chain_variance_stats(input): @@ -106,17 +106,13 @@ def autocorrelation(input, dim=0): # centering and padding x centered_signal = input - input.mean(dim=-1, keepdim=True) - pad = torch.zeros(input.shape[:-1] + (M2 - N,), dtype=input.dtype, device=input.device) - centered_signal = torch.cat([centered_signal, pad], dim=-1) # Fourier transform - freqvec = rfft(centered_signal, signal_ndim=1, onesided=False) + freqvec = torch.view_as_real(rfft(centered_signal, n=M2)) # take square of magnitude of freqvec (or freqvec x freqvec*) - freqvec_gram = freqvec.pow(2).sum(-1, keepdim=True) - freqvec_gram = torch.cat([freqvec_gram, torch.zeros(freqvec_gram.shape, dtype=input.dtype, - device=input.device)], dim=-1) + freqvec_gram = freqvec.pow(2).sum(-1) # inverse Fourier transform - autocorr = irfft(freqvec_gram, signal_ndim=1, onesided=False) + autocorr = irfft(freqvec_gram, n=M2) # truncate and normalize the result, then transpose back to original shape autocorr = autocorr[..., :N] diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index cb559e9010..0ac27aac4e 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -184,12 +184,6 @@ def next_fast_len(size): next_size += 1 -def _complex_mul(a, b): - ar, ai = a.unbind(-1) - br, bi = b.unbind(-1) - return torch.stack([ar * br - ai * bi, ar * bi + ai * br], dim=-1) - - def convolve(signal, kernel, mode='full'): """ Computes the 1-d convolution of signal by kernel using FFTs. @@ -221,10 +215,10 @@ def convolve(signal, kernel, mode='full'): padded_size = m + n - 1 # Round up for cheaper fft. fast_ftt_size = next_fast_len(padded_size) - f_signal = rfft(torch.nn.functional.pad(signal, (0, fast_ftt_size - m)), 1, onesided=False) - f_kernel = rfft(torch.nn.functional.pad(kernel, (0, fast_ftt_size - n)), 1, onesided=False) - f_result = _complex_mul(f_signal, f_kernel) - result = irfft(f_result, 1, onesided=False) + f_signal = rfft(signal, n=fast_ftt_size) + f_kernel = rfft(kernel, n=fast_ftt_size) + f_result = f_signal * f_kernel + result = irfft(f_result, n=fast_ftt_size) start_idx = (padded_size - truncate) // 2 return result[..., start_idx: start_idx + truncate] @@ -257,12 +251,6 @@ def repeated_matmul(M, n): return result[0:n] -def _real_of_complex_mul(a, b): - ar, ai = a.unbind(-1) - br, bi = b.unbind(-1) - return ar * br - ai * bi - - def dct(x, dim=-1): """ Discrete cosine transform of type II, scaled to be orthonormal. @@ -285,11 +273,16 @@ def dct(x, dim=-1): # Step 1 y = torch.cat([x[..., ::2], x[..., 1::2].flip(-1)], dim=-1) # Step 2 - Y = torch.rfft(y, 1, onesided=False) + Y = rfft(y, n=N) # Step 3 coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) - coef = torch.stack([coef_real[:-1], -coef_real[1:].flip(-1)], dim=-1) - X = _real_of_complex_mul(coef, Y) + M = Y.size(-1) + coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1) + X = torch.view_as_complex(coef) * Y + # NB: if we use the full-length version Y_full = fft(y, n=N), then + # the real part of the later half of X will be the flip + # of the negative of the imaginary part of the first half + X = torch.cat([X.real, -X.imag[..., 1:(N - M + 1)].flip(-1)], dim=-1) # orthogonalize scale = torch.cat([x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))]) return X / scale @@ -322,14 +315,14 @@ def idct(x, dim=-1): # and Yi[1:] = sin(k) * X[1:] - cos(k) * X[:0:-1] # In addition, Yi[0] = 0, Yr[0] = X[0] # In other words, Y = complex_mul(e^ik, X - i[0, X[:0:-1]]) - xi = torch.nn.functional.pad(-x[..., 1:], (0, 1)).flip(-1) - X = torch.stack([x, xi], dim=-1) - coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1)) - coef = torch.stack([coef_real[:-1], coef_real[1:].flip(-1)], dim=-1) - half_size = N // 2 + 1 - Y = _complex_mul(coef[..., :half_size, :], X[..., :half_size, :]) + M = N // 2 + 1 # half size + xi = torch.nn.functional.pad(-x[..., N - M + 1:], (0, 1)).flip(-1) + X = torch.stack([x[..., :M], xi], dim=-1) + coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) + coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1) + Y = torch.view_as_complex(coef) * torch.view_as_complex(X) # Step 2 - y = irfft(Y, 1, onesided=True, signal_sizes=(N,)) + y = irfft(Y, n=N) # Step 3 return torch.stack([y, y.flip(-1)], axis=-1).reshape(x.shape[:-1] + (-1,))[..., :N] From 30083cbecabfe1c06dc15e4618044bd2ccd2ba5d Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Mon, 16 Nov 2020 17:42:12 +0000 Subject: [PATCH 4/7] change sparse reg test args --- tests/test_examples.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index a060d67d03..50046f3f1c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -101,8 +101,7 @@ 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy', 'svi_horovod.py --num-epochs=2 --size=400 --no-horovod', 'toy_mixture_model_discrete_enumeration.py --num-steps=1', - xfail_param('sparse_regression.py --num-steps=2 --num-data=50 --num-dimensions 20', - reason='https://github.com/pyro-ppl/pyro/issues/2082'), + 'sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11', 'vae/ss_vae_M2.py --num-epochs=1', 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel', From a9226312e42908c24ca03030c7730ecf3bd68d23 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 16 Nov 2020 16:23:20 -0500 Subject: [PATCH 5/7] Work around PyTorch 1.7 bugs in mcmc --- pyro/infer/mcmc/api.py | 1 + pyro/poutine/enum_messenger.py | 2 +- tests/test_examples.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index dfa55cf15c..253659ff81 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -227,6 +227,7 @@ def run(self, *args, **kwargs): # Ignore sigint in worker processes; they will be shut down # when the main process terminates. sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) + args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args] self.init_workers(*args, **kwargs) # restore original handler signal.signal(signal.SIGINT, sigint_handler) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 03c8b9c8a3..cb1c14162b 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -200,7 +200,7 @@ def _pyro_post_sample(self, msg): value = msg["value"] if value is None: return - shape = value.shape[:value.dim() - msg["fn"].event_dim] + shape = value.data.shape[:value.dim() - msg["fn"].event_dim] dim_to_id = msg["infer"].setdefault("_dim_to_id", {}) dim_to_id.update(self._param_dims.get(msg["name"], {})) with ignore_jit_warnings(): diff --git a/tests/test_examples.py b/tests/test_examples.py index 50046f3f1c..b0d0cb96c8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -201,7 +201,6 @@ def xfail_jit(*args): 'contrib/epidemiology/sir.py --jit -np=128 -ss=2 -n=4 -d=20 -p=1000 -f 2 --svi', 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', - xfail_jit('lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --jit'), xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), xfail_jit('dmm.py --num-epochs=1 --jit'), xfail_jit('dmm.py --num-epochs=1 --num-iafs=1 --jit'), From b4b1edfd20d544f6adaa1c976247b6f3df3bfb45 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 16 Nov 2020 17:16:15 -0600 Subject: [PATCH 6/7] fix mcmc multi-chain bugs --- pyro/distributions/torch_patch.py | 11 +++++++++++ pyro/infer/mcmc/api.py | 9 ++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index c3af775361..0f6ffb58be 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import math import weakref import torch @@ -69,6 +70,16 @@ def _Multinomial_support(self): return torch.distributions.constraints.integer_interval(0, total_count) +# TODO fix https://github.com/pytorch/pytorch/issues/48054 upstream +@patch_dependency('torch.distributions.HalfCauchy.log_prob') +def _HalfCauchy_logprob(self, value): + value = torch.as_tensor(value, dtype=self.base_dist.scale.dtype, + device=self.base_dist.scale.device) + log_prob = self.base_dist.log_prob(value) + math.log(2) + log_prob.masked_fill_(value.expand(log_prob.shape) < 0, -float("inf")) + return log_prob + + # This adds a __call__ method to satisfy sphinx. @patch_dependency('torch.distributions.utils.lazy_property.__call__') def _lazy_property__call__(self): diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index dfa55cf15c..7e62720a61 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -85,9 +85,6 @@ def __init__(self, chain_id, result_queue, log_queue, event, kernel, num_samples def run(self, *args, **kwargs): pyro.set_rng_seed(self.rng_seed) torch.set_default_tensor_type(self.default_tensor_type) - # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer" - # at https://github.com/pytorch/pytorch/issues/10375 - args = [arg.clone().detach() if (torch.is_tensor(arg) and arg.is_cuda) else arg for arg in args] kwargs = kwargs logger = logging.getLogger("pyro.infer.mcmc") logger_id = "CHAIN:{}".format(self.chain_id) @@ -377,6 +374,12 @@ def model(data): z_flat_acc = [[] for _ in range(self.num_chains)] with optional(pyro.validation_enabled(not self.disable_validation), self.disable_validation is not None): + if self.num_chains > 1: + # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer" + # at https://github.com/pytorch/pytorch/issues/10375 + # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which + # requires_grad", which happens with `jit_compile` under PyTorch 1.7 + args = [arg.clone().detach() if torch.is_tensor(arg) else arg for arg in args] for x, chain_id in self.sampler.run(*args, **kwargs): if num_samples[chain_id] == 0: num_samples[chain_id] += 1 From cf1d49f6a488404b945f717c1fc2e2dbc4acff7d Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 16 Nov 2020 22:23:15 -0600 Subject: [PATCH 7/7] only detach --- pyro/infer/mcmc/api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index 256be020f9..410bafbc61 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -224,7 +224,6 @@ def run(self, *args, **kwargs): # Ignore sigint in worker processes; they will be shut down # when the main process terminates. sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) - args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args] self.init_workers(*args, **kwargs) # restore original handler signal.signal(signal.SIGINT, sigint_handler) @@ -375,12 +374,11 @@ def model(data): z_flat_acc = [[] for _ in range(self.num_chains)] with optional(pyro.validation_enabled(not self.disable_validation), self.disable_validation is not None): - if self.num_chains > 1: - # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer" - # at https://github.com/pytorch/pytorch/issues/10375 - # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which - # requires_grad", which happens with `jit_compile` under PyTorch 1.7 - args = [arg.clone().detach() if torch.is_tensor(arg) else arg for arg in args] + # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer" + # at https://github.com/pytorch/pytorch/issues/10375 + # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which + # requires_grad", which happens with `jit_compile` under PyTorch 1.7 + args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args] for x, chain_id in self.sampler.run(*args, **kwargs): if num_samples[chain_id] == 0: num_samples[chain_id] += 1