Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for PyTorch 1.7 release #2683

Merged
merged 11 commits into from
Nov 17, 2020
2 changes: 1 addition & 1 deletion pyro/distributions/spanning_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ at::Tensor sample_tree_approx(at::Tensor edge_logits) {
// the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2.
auto edge_ids = torch::empty({E}, at::kLong);
// This maps each vertex to whether it is a member of the cumulative tree.
auto components = torch::zeros({V}, at::kByte);
auto components = torch::zeros({V}, at::kBool);

// Sample the first edge at random.
auto probs = (edge_logits - edge_logits.max()).exp();
Expand Down
27 changes: 27 additions & 0 deletions pyro/ops/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

try:
# This works in PyTorch 1.7+
from torch.fft import irfft, rfft
except ModuleNotFoundError:
# This works in PyTorch 1.6
def rfft(input, n=None):
if n is not None:
m = input.size(-1)
if n > m:
input = torch.nn.functional.pad(input, (0, n - m))
elif n < m:
input = input[..., :n]
return torch.view_as_complex(torch.rfft(input, 1))

def irfft(input, n=None):
if torch.is_complex(input):
input = torch.view_as_real(input)
else:
input = torch.nn.functional.pad(input[..., None], (0, 1))
if n is None:
n = 2 * (input.size(-1) - 1)
return torch.irfft(input, 1, signal_sizes=(n,))
2 changes: 1 addition & 1 deletion pyro/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def eig_3d(H):
p = torch.sqrt(p2 / 6)
B = (1 / p).unsqueeze(-1).unsqueeze(-1) * (H - q.unsqueeze(-1).unsqueeze(-1) * torch.eye(3))
r = determinant_3d(B) / 2
phi = (r.acos() / 3).unsqueeze(-1).unsqueeze(-1).expand(r.shape + (3, 3))
phi = (r.acos() / 3).unsqueeze(-1).unsqueeze(-1).expand(r.shape + (3, 3)).clone()
phi[r < -1 + 1e-6] = math.pi / 3
phi[r > 1 - 1e-6] = 0.

Expand Down
11 changes: 4 additions & 7 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from .tensor_utils import next_fast_len
from .fft import rfft, irfft


def _compute_chain_variance_stats(input):
Expand Down Expand Up @@ -105,17 +106,13 @@ def autocorrelation(input, dim=0):

# centering and padding x
centered_signal = input - input.mean(dim=-1, keepdim=True)
pad = torch.zeros(input.shape[:-1] + (M2 - N,), dtype=input.dtype, device=input.device)
centered_signal = torch.cat([centered_signal, pad], dim=-1)

# Fourier transform
freqvec = torch.rfft(centered_signal, signal_ndim=1, onesided=False)
freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
# take square of magnitude of freqvec (or freqvec x freqvec*)
freqvec_gram = freqvec.pow(2).sum(-1, keepdim=True)
freqvec_gram = torch.cat([freqvec_gram, torch.zeros(freqvec_gram.shape, dtype=input.dtype,
device=input.device)], dim=-1)
freqvec_gram = freqvec.pow(2).sum(-1)
# inverse Fourier transform
autocorr = torch.irfft(freqvec_gram, signal_ndim=1, onesided=False)
autocorr = irfft(freqvec_gram, n=M2)

# truncate and normalize the result, then transpose back to original shape
autocorr = autocorr[..., :N]
Expand Down
46 changes: 20 additions & 26 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch

from .fft import irfft, rfft

_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)

Expand Down Expand Up @@ -183,12 +184,6 @@ def next_fast_len(size):
next_size += 1


def _complex_mul(a, b):
ar, ai = a.unbind(-1)
br, bi = b.unbind(-1)
return torch.stack([ar * br - ai * bi, ar * bi + ai * br], dim=-1)


def convolve(signal, kernel, mode='full'):
"""
Computes the 1-d convolution of signal by kernel using FFTs.
Expand Down Expand Up @@ -220,10 +215,10 @@ def convolve(signal, kernel, mode='full'):
padded_size = m + n - 1
# Round up for cheaper fft.
fast_ftt_size = next_fast_len(padded_size)
f_signal = torch.rfft(torch.nn.functional.pad(signal, (0, fast_ftt_size - m)), 1, onesided=False)
f_kernel = torch.rfft(torch.nn.functional.pad(kernel, (0, fast_ftt_size - n)), 1, onesided=False)
f_result = _complex_mul(f_signal, f_kernel)
result = torch.irfft(f_result, 1, onesided=False)
f_signal = rfft(signal, n=fast_ftt_size)
f_kernel = rfft(kernel, n=fast_ftt_size)
f_result = f_signal * f_kernel
result = irfft(f_result, n=fast_ftt_size)

start_idx = (padded_size - truncate) // 2
return result[..., start_idx: start_idx + truncate]
Expand Down Expand Up @@ -256,12 +251,6 @@ def repeated_matmul(M, n):
return result[0:n]


def _real_of_complex_mul(a, b):
ar, ai = a.unbind(-1)
br, bi = b.unbind(-1)
return ar * br - ai * bi


def dct(x, dim=-1):
"""
Discrete cosine transform of type II, scaled to be orthonormal.
Expand All @@ -284,11 +273,16 @@ def dct(x, dim=-1):
# Step 1
y = torch.cat([x[..., ::2], x[..., 1::2].flip(-1)], dim=-1)
# Step 2
Y = torch.rfft(y, 1, onesided=False)
Y = rfft(y, n=N)
# Step 3
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
coef = torch.stack([coef_real[:-1], -coef_real[1:].flip(-1)], dim=-1)
X = _real_of_complex_mul(coef, Y)
M = Y.size(-1)
coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1)
X = torch.view_as_complex(coef) * Y
# NB: if we use the full-length version Y_full = fft(y, n=N), then
# the real part of the later half of X will be the flip
# of the negative of the imaginary part of the first half
X = torch.cat([X.real, -X.imag[..., 1:(N - M + 1)].flip(-1)], dim=-1)
# orthogonalize
scale = torch.cat([x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))])
return X / scale
Expand Down Expand Up @@ -321,14 +315,14 @@ def idct(x, dim=-1):
# and Yi[1:] = sin(k) * X[1:] - cos(k) * X[:0:-1]
# In addition, Yi[0] = 0, Yr[0] = X[0]
# In other words, Y = complex_mul(e^ik, X - i[0, X[:0:-1]])
xi = torch.nn.functional.pad(-x[..., 1:], (0, 1)).flip(-1)
X = torch.stack([x, xi], dim=-1)
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1))
coef = torch.stack([coef_real[:-1], coef_real[1:].flip(-1)], dim=-1)
half_size = N // 2 + 1
Y = _complex_mul(coef[..., :half_size, :], X[..., :half_size, :])
M = N // 2 + 1 # half size
xi = torch.nn.functional.pad(-x[..., N - M + 1:], (0, 1)).flip(-1)
X = torch.stack([x[..., :M], xi], dim=-1)
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1)
Y = torch.view_as_complex(coef) * torch.view_as_complex(X)
# Step 2
y = torch.irfft(Y, 1, onesided=True, signal_sizes=(N,))
y = irfft(Y, n=N)
# Step 3
return torch.stack([y, y.flip(-1)], axis=-1).reshape(x.shape[:-1] + (-1,))[..., :N]

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ filterwarnings = error
ignore:numpy.ufunc size changed:RuntimeWarning
ignore:numpy.dtype size changed:RuntimeWarning
ignore:Mixed memory format inputs detected:UserWarning
ignore:Setting attributes on ParameterDict:UserWarning
ignore::DeprecationWarning
once::DeprecationWarning

Expand Down
3 changes: 1 addition & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy',
'svi_horovod.py --num-epochs=2 --size=400 --no-horovod',
'toy_mixture_model_discrete_enumeration.py --num-steps=1',
xfail_param('sparse_regression.py --num-steps=2 --num-data=50 --num-dimensions 20',
reason='https://github.com/pyro-ppl/pyro/issues/2082'),
'sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11',
fritzo marked this conversation as resolved.
Show resolved Hide resolved
'vae/ss_vae_M2.py --num-epochs=1',
'vae/ss_vae_M2.py --num-epochs=1 --aux-loss',
'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel',
Expand Down