Update to PyTorch 1.0.0 #1376

Closed
wants to merge 143 commits
Changes from 52 commits

Commits (143)
4c6d9b7
Fix PyTorch 0.4.1 errors
fritzo Jul 27, 2018
cd945f3
fix tests for pytorch 0.4.1
neerajprad Jul 27, 2018
4499b9c
fix test_mask
neerajprad Jul 27, 2018
f707101
skip JIT tests
neerajprad Jul 27, 2018
4737e10
fix NUTS tests
neerajprad Jul 27, 2018
7dc2bc1
fix gaussian_scale_mixture
neerajprad Jul 27, 2018
b58ef93
Mark test_em.py failures as xfail
fritzo Jul 27, 2018
448a760
Allow inf values in assert_tensors_equal
fritzo Jul 27, 2018
62f21d5
update examples
jpchen Jul 27, 2018
385ef94
Merge branch 'dev' into fix-dist-0.4.1
neerajprad Jul 28, 2018
814e025
remove redundant xfail
neerajprad Jul 28, 2018
e1cd9b4
Update JIT usage to PyTorch 0.4.1 (#1276)
fritzo Jul 30, 2018
a302944
Merge branch 'dev' into fix-dist-0.4.1
neerajprad Jul 31, 2018
8cb89a3
use float in arange
neerajprad Jul 31, 2018
eac0b73
Merge branch 'dev' into fix-dist-0.4.1
fritzo Aug 7, 2018
660b8d1
Fix Categorical.enumerate_support to make JitTraceEnum_ELBO work
fritzo Aug 7, 2018
d583dd1
Refactor test_examples.py to allow xfailing examples
fritzo Aug 7, 2018
fc9a4a7
Add xfailing examples that use --jit
fritzo Aug 7, 2018
15fba47
Fix missing import in test_jit.py
fritzo Aug 7, 2018
91bf748
Enable jit in most SVI examples
fritzo Aug 7, 2018
afc0b70
Merge branch 'dev' into pytorch-0.4.1
fritzo Aug 8, 2018
71779ec
Revert changes to torch_patch.py
fritzo Aug 8, 2018
fcdddcc
Work around jit issues; bayesian_regressian example now jits
fritzo Aug 8, 2018
d260fe0
Fix doctests to pass on Python 2.7
fritzo Aug 8, 2018
7ee642b
Merge branch 'fix-doctest-2.7' into pytorch-0.4.1
fritzo Aug 8, 2018
9f7fd54
Fix arange usage
fritzo Aug 8, 2018
d3cafb1
Only patch Categorical if broadcast_tensors is defined
fritzo Aug 8, 2018
17cb636
Add patches to work around bugs in 0.4.1
fritzo Aug 8, 2018
ad2e391
Merge branch 'dev' into pytorch-0.4.1
fritzo Aug 8, 2018
a251d3d
Fix test failures
fritzo Aug 8, 2018
1e6fb98
flake8
fritzo Aug 8, 2018
80cf76f
Fix typo in skipif markers
fritzo Aug 9, 2018
0008725
Work around bugs in torch unwind backward
fritzo Aug 9, 2018
55003ec
Mark xfailing jit test
fritzo Aug 9, 2018
8ea0461
Update all uses of torch.arange
fritzo Aug 9, 2018
09cadfb
Remove obsolete logsumexp implementation
fritzo Aug 9, 2018
2a93a9a
Patch torch.distributions.Categorical.log_prob
fritzo Aug 9, 2018
e91aa86
Work around lack of jit support for torch.eye(_, out=_)
fritzo Aug 9, 2018
5d71162
Add test-jit target to Makefile
fritzo Aug 9, 2018
0fb6870
Fix bug in eye_like when m!=n
fritzo Aug 9, 2018
6f88bbc
Fix jit errors: torch_scale and variable len(args)
fritzo Aug 9, 2018
2f190e5
Patch multivariate normal __init__ methods to be jittable
fritzo Aug 9, 2018
f7ef56e
Patch torch.log
fritzo Aug 9, 2018
b5ba5f1
Patch torch.Tensor.log
fritzo Aug 9, 2018
3f64101
Patch torch.exp and torch.Tensor.exp
fritzo Aug 9, 2018
59c48e1
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 10, 2018
894fa58
Use JIT traced potential energy computation in HMC (#1299)
neerajprad Aug 14, 2018
f303855
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 28, 2018
bf98735
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 31, 2018
9313b0f
add xfailing test
neerajprad Sep 4, 2018
a5bc9bb
Merge branch 'dev' into pytorch-0.4.1
fritzo Sep 11, 2018
3cfba13
Remove obsolete PyTorch patches
fritzo Sep 11, 2018
a677512
Remove patch for Tensor._standard_gamma
fritzo Sep 11, 2018
8f665ba
Fix some jit errors
fritzo Sep 11, 2018
06b0e63
Convert to valid einsum chars in torch_log backend
fritzo Sep 11, 2018
1785b6c
Updating distributions module with PyTorch master (#1377)
neerajprad Sep 11, 2018
29bb3ed
Use native torch.tensordot
fritzo Sep 11, 2018
8c914b8
Remove duplicate implementation of logsumexp
fritzo Sep 11, 2018
1c708b4
Merge branch 'dev' into pytorch-0.5.0
fritzo Sep 11, 2018
35a6965
Ignore jit warnings
fritzo Sep 11, 2018
692b883
Ignore a couple TracerWarnings in pyro.ops.jit.trace
fritzo Sep 11, 2018
0c50243
Fix a tiny test_jit error
fritzo Sep 11, 2018
850432f
Add jit test for OneHotCategorical
fritzo Sep 11, 2018
a03164a
fix JIT errors for HMC
neerajprad Sep 11, 2018
effada7
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 11, 2018
c36a9aa
change assert in torch_log
neerajprad Sep 11, 2018
ecfc995
Work around more jit missing coverage
fritzo Sep 11, 2018
c47e294
Strengthen masked_fill test
fritzo Sep 12, 2018
9ea7fd2
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 12, 2018
c19830d
fix hmc enum test
neerajprad Sep 12, 2018
0c233d3
Fix failing jit tests
fritzo Sep 13, 2018
dceaf9a
Add test for .scatter_() workaround
fritzo Sep 13, 2018
ac122b0
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 13, 2018
e5cd034
add expand for MaskedDistribution
neerajprad Sep 13, 2018
6ce9925
remove binomial and half cauchy
neerajprad Sep 13, 2018
e80ef20
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 14, 2018
68c1168
reinstate Independent constraint
neerajprad Sep 14, 2018
657fc56
add expand methods to more distributions
neerajprad Sep 14, 2018
d44f90b
Fix CUDA tests in test_eig.py
neerajprad Sep 14, 2018
c2c4b72
remove standard gamma patch
neerajprad Sep 14, 2018
62bf019
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 17, 2018
5532c3f
Work-around to allow JIT compiler to infer batch size in iarange (#1392)
neerajprad Sep 19, 2018
021707a
Remove deprecated new_tensor invocation
neerajprad Sep 20, 2018
981d64b
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 20, 2018
a36f25a
Remove deprecated new_tensor invocation
neerajprad Sep 20, 2018
693fcd9
remove .new
neerajprad Sep 20, 2018
3129059
address comments
neerajprad Sep 20, 2018
b48f108
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
a648968
fix test_hessian
neerajprad Sep 20, 2018
c1c9e82
fix more tests
neerajprad Sep 20, 2018
edf000e
remove redundant parens
neerajprad Sep 20, 2018
53c3e07
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
c2d3de8
fix test_elbo_mapdata
neerajprad Sep 20, 2018
bf85894
fix test_conj_gaussian
neerajprad Sep 20, 2018
c6dd8d7
fix test_valid_models
neerajprad Sep 20, 2018
b54081c
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
d4ff53c
fix dist tests
neerajprad Sep 20, 2018
0064b6a
fix test_gaussian_mixtures
neerajprad Sep 20, 2018
94cb96c
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
8ceaaa0
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 25, 2018
710406d
Test fixes for compatibility with PyTorch master
neerajprad Sep 26, 2018
eef40eb
address comments; more fixes
neerajprad Sep 26, 2018
2325b14
more test fixes
neerajprad Sep 26, 2018
a5a457f
uncomment torch_patch
neerajprad Sep 26, 2018
c9078dd
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 26, 2018
6078bc7
Merge branch 'test-fixes' into pytorch-0.5.0
neerajprad Sep 26, 2018
41164bb
ignore jit warnings in hmc
neerajprad Sep 26, 2018
21b32c0
remove default jit compilation in air
neerajprad Sep 26, 2018
865ddab
set args.jit default to false
neerajprad Sep 26, 2018
6ddeb38
ignore jit warnings in hmc tests
neerajprad Sep 26, 2018
43c5e4e
mark failing hmc tests
neerajprad Sep 27, 2018
b5ae590
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 27, 2018
68171b9
test against nightly build
neerajprad Sep 27, 2018
f4db712
fix channel name
neerajprad Sep 27, 2018
fb290e8
downgrade ipython
neerajprad Sep 27, 2018
ed4b360
fix lapack issue
neerajprad Sep 27, 2018
2142589
include mkl
neerajprad Sep 27, 2018
1efd552
addons to .travis
neerajprad Sep 27, 2018
b84122f
add pytorch channel
neerajprad Sep 27, 2018
85342ef
remove pythonpath
neerajprad Sep 27, 2018
464fe68
editable install
neerajprad Sep 27, 2018
475f483
add ld_library_path
neerajprad Sep 27, 2018
07e20d4
conda install pip
neerajprad Sep 27, 2018
a624b57
debug build
neerajprad Sep 27, 2018
053163f
debug - revert to pytorch release
neerajprad Sep 27, 2018
fb62e69
add before install
neerajprad Sep 27, 2018
13be21f
use nightly wheel
neerajprad Sep 27, 2018
31b5d63
Fix incompatible dependency between jupyter-console and ipython
neerajprad Sep 27, 2018
83b3eb4
Merge branch 'fix-ipython-dep' into pytorch-0.5.0
neerajprad Sep 27, 2018
f904d0c
remove torch==0.4.1 from setup
neerajprad Sep 27, 2018
13d7b51
remove torchvision temporarily
neerajprad Sep 27, 2018
0e17d3c
install torchvision without deps
neerajprad Sep 28, 2018
92d6d4d
remove torchvision from setup
neerajprad Sep 28, 2018
72c76fd
update to contextlib2
neerajprad Sep 28, 2018
0e60ce2
fix benchmark tests
neerajprad Sep 28, 2018
b328d12
add xfail markers for failing tests
neerajprad Sep 28, 2018
3161deb
temporarily xfail ubersum_sizes test
neerajprad Sep 28, 2018
e298cb4
fix xfail marker
neerajprad Sep 28, 2018
d13039b
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 28, 2018
feca15d
remove xfail marker from test_enum
neerajprad Sep 28, 2018
a7300e8
add xfail for mixture of diag normals
neerajprad Oct 1, 2018
17f2033
Merge branch 'dev' into pytorch-0.5.0
neerajprad Oct 1, 2018
379ffef
fix mask fill on non contiguous tensor
neerajprad Oct 2, 2018
Files changed
4 changes: 2 additions & 2 deletions .travis.yml
@@ -9,9 +9,9 @@ env:
install:
- pip install -U pip
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl;
pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl;
else
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl;
pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl;
fi
- pip install .[test]
- pip freeze
6 changes: 6 additions & 0 deletions Makefile
@@ -57,6 +57,12 @@ test-cuda: lint FORCE
CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx -n 4 --stage unit
CUDA_TEST=1 pytest -vx -n 4 tests/test_examples.py::test_cuda

test-jit: FORCE
@echo See jit.log
pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log
pytest -v -n auto --tb=short --runxfail tests/infer/mcmc/test_hmc.py tests/infer/mcmc/test_nuts.py \
-k JIT=True | tee -a jit.log

clean: FORCE
git clean -dfx -e pyro-egg.info

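As the recipe shows, `make test-jit` runs the JIT-focused suites (tests/infer/test_jit.py, tests/test_examples.py::test_jit, and the HMC/NUTS tests filtered with -k JIT=True) and tees all output into jit.log for later inspection.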
2 changes: 1 addition & 1 deletion docs/source/primitives.rst
@@ -20,4 +20,4 @@ Primitives
.. autofunction:: pyro.validation_enabled
.. autofunction:: pyro.enable_validation

.. autofunction:: pyro.ops.jit.compile
.. autofunction:: pyro.ops.jit.trace
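(This tracks the rename of the JIT helper: the commit log above includes "Ignore a couple TracerWarnings in pyro.ops.jit.trace", so the autodoc entry now points at pyro.ops.jit.trace rather than the old pyro.ops.jit.compile.)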
4 changes: 3 additions & 1 deletion examples/baseball.py
@@ -218,7 +218,7 @@ def main(args):
baseball_dataset = pd.read_csv(DATA_URL, "\t")
train, _, player_names = train_test_split(baseball_dataset)
at_bats, hits = train[:, 0], train[:, 1]
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit)
logging.info("Original Dataset:")
logging.info(baseball_dataset)

@@ -270,5 +270,7 @@ def main(args):
parser.add_argument("-n", "--num-samples", nargs="?", default=1200, type=int)
parser.add_argument("--warmup-steps", nargs='?', default=300, type=int)
parser.add_argument("--rng_seed", nargs='?', default=0, type=int)
parser.add_argument('--jit', action='store_true', default=False,
help='use PyTorch jit')
args = parser.parse_args()
main(args)
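With this flag wired through, `python examples/baseball.py --jit` hands jit_compile=True to the NUTS kernel, while the store_true default of False keeps the non-jitted path.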
3 changes: 2 additions & 1 deletion examples/eight_schools/mcmc.py
@@ -34,7 +34,7 @@ def conditioned_model(model, sigma, y):


def main(args):
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit)
posterior = MCMC(nuts_kernel, num_samples=args.num_samples, warmup_steps=args.warmup_steps)\
.run(model, data.sigma, data.y)
marginal_mu_tau = EmpiricalMarginal(posterior, sites=["mu", "tau"])\
@@ -54,6 +54,7 @@ def main(args):
help='number of MCMC samples (default: 1000)')
parser.add_argument('--warmup-steps', type=int, default=1000,
help='number of MCMC samples for warmup (default: 1000)')
parser.add_argument('--jit', action='store_true', default=False)
args = parser.parse_args()

main(args)
8 changes: 4 additions & 4 deletions pyro/contrib/gp/models/gplvm.py
@@ -1,14 +1,14 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
from pyro.contrib.gp.util import Parameterized
import pyro.distributions as dist
import pyro.infer as infer
import pyro.optim as optim
from pyro.contrib.gp.util import Parameterized
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -74,7 +74,7 @@ def __init__(self, base_model, name="GPLVM"):

C = self.X_loc.shape[1]
X_scale_tril_shape = self.X_loc.shape + (C,)
Id = torch.eye(C, out=self.X_loc.new_empty(C, C))
Id = eye_like(self.X_loc, C)
X_scale_tril = Id.expand(X_scale_tril_shape)
self.X_scale_tril = Parameter(X_scale_tril)
self.set_constraint("X_scale_tril", constraints.lower_cholesky)
@@ -87,7 +87,7 @@ def model(self):
# sample X from unit multivariate normal distribution
zero_loc = self.X_loc.new_zeros(self.X_loc.shape)
C = self.X_loc.shape[1]
Id = torch.eye(C, out=self.X_loc.new_empty(C, C))
Id = eye_like(self.X_loc, C)
X_name = param_with_module_name(self.name, "X")
X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim()-1))
5 changes: 3 additions & 2 deletions pyro/contrib/gp/models/vgp.py
@@ -8,6 +8,7 @@
import pyro.distributions as dist
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -74,7 +75,7 @@ def __init__(self, X, y, kernel, likelihood, mean_function=None,
self.f_loc = Parameter(f_loc)

f_scale_tril_shape = self.latent_shape + (N, N)
Id = torch.eye(N, out=self.X.new_empty(N, N))
Id = eye_like(self.X, N)
f_scale_tril = Id.expand(f_scale_tril_shape)
self.f_scale_tril = Parameter(f_scale_tril)
self.set_constraint("f_scale_tril", constraints.lower_cholesky)
@@ -96,7 +97,7 @@ def model(self):
f_name = param_with_module_name(self.name, "f")

if self.whiten:
Id = torch.eye(N, out=self.X.new_empty(N, N))
Id = eye_like(self.X, N)
pyro.sample(f_name,
dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim() - 1))
5 changes: 3 additions & 2 deletions pyro/contrib/gp/models/vsgp.py
@@ -9,6 +9,7 @@
import pyro.poutine as poutine
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -98,7 +99,7 @@ def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None,
self.u_loc = Parameter(u_loc)

u_scale_tril_shape = self.latent_shape + (M, M)
Id = torch.eye(M, out=self.Xu.new_empty(M, M))
Id = eye_like(self.Xu, M)
u_scale_tril = Id.expand(u_scale_tril_shape)
self.u_scale_tril = Parameter(u_scale_tril)
self.set_constraint("u_scale_tril", constraints.lower_cholesky)
@@ -120,7 +121,7 @@ def model(self):
zero_loc = Xu.new_zeros(u_loc.shape)
u_name = param_with_module_name(self.name, "u")
if self.whiten:
Id = torch.eye(M, out=Xu.new_empty(M, M))
Id = eye_like(Xu, M)
pyro.sample(u_name,
dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim() - 1))
5 changes: 3 additions & 2 deletions pyro/distributions/lowrank_mvn.py
@@ -7,6 +7,7 @@
from torch.distributions.utils import lazy_property

from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistribution
from pyro.distributions.util import eye_like


def _matrix_triangular_solve_compat(b, A, upper=True):
@@ -84,7 +85,7 @@ def scale_tril(self):
A = self.covariance_matrix_W_term / Dsqrt
At_A = A.t().matmul(A)
N = A.shape[1]
Id = torch.eye(N, N, out=A.new_empty(N, N))
Id = eye_like(A, N)
K = Id + At_A
L = K.potrf(upper=False)
return Dsqrt.unsqueeze(1) * L
@@ -111,7 +112,7 @@ def _compute_logdet_and_mahalanobis(self, D, W, y, trace_term=0):
"""
W_Dinv = W / D
M = W.shape[0]
Id = torch.eye(M, M, out=W.new_empty(M, M))
Id = eye_like(W, M)
K = Id + W_Dinv.matmul(W.t())
L = K.potrf(upper=False)
if y.dim() == 1:
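For context on why only a small Cholesky is needed here: with low-rank term W of shape (M, N) and diagonal D, the covariance is D + WᵀW, and Sylvester's determinant identity gives det(D + WᵀW) = det(I_M + W D⁻¹ Wᵀ) · det(D), so the K = Id + W_Dinv.matmul(W.t()) above is exactly the M×M matrix that gets factored with potrf instead of the full N×N covariance.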
4 changes: 2 additions & 2 deletions pyro/distributions/omt_mvn.py
@@ -6,7 +6,7 @@
from torch.distributions import constraints

from pyro.distributions.torch import MultivariateNormal
from pyro.distributions.util import sum_leftmost
from pyro.distributions.util import eye_like, sum_leftmost


class OMTMultivariateNormal(MultivariateNormal):
@@ -51,7 +51,7 @@ def backward(ctx, grad_output):
g = grad_output
loc_grad = sum_leftmost(grad_output, -1)

identity = torch.eye(dim, out=torch.tensor(g.new_empty(dim, dim)))
identity = eye_like(g, dim)
R_inv = torch.trtrs(identity, L.t(), transpose=False, upper=True)[0]

z_ja = z.unsqueeze(-1)
5 changes: 3 additions & 2 deletions pyro/distributions/torch.py
@@ -4,6 +4,7 @@
from torch.distributions import constraints

from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistributionMixin
from pyro.distributions.util import eye_like


class Bernoulli(torch.distributions.Bernoulli, TorchDistributionMixin):
@@ -39,6 +40,7 @@ def expand(self, batch_shape):


class Categorical(torch.distributions.Categorical, TorchDistributionMixin):

def expand(self, batch_shape):
try:
return super(Categorical, self).expand(batch_shape)
@@ -252,8 +254,7 @@ def expand(self, batch_shape):

def enumerate_support(self, expand=True):
n = self.event_shape[0]
values = self._new((n, n))
torch.eye(n, out=values)
values = eye_like(self._categorical._param, n)
values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = values.expand((n,) + self.batch_shape + (n,))
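A small shape-check sketch of the new enumerate_support (the probabilities below are made up; the shapes follow the code above):

    import torch
    import pyro.distributions as dist

    d = dist.OneHotCategorical(torch.tensor([0.2, 0.3, 0.5]))
    values = d.enumerate_support()      # identity-like: one one-hot row per category
    assert values.shape == (3, 3)       # (n,) + batch_shape + (n,), batch_shape = ()
    assert (values.sum(-1) == 1).all()  # every enumerated value is one-hot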
10 changes: 0 additions & 10 deletions pyro/distributions/torch_patch.py
@@ -45,7 +45,6 @@ def _torch_dirichlet_grad(x, concentration, total):
return unpatched_fn(x, concentration, total)


@_patch('torch.einsum')
def _einsum(equation, operands):
# work around torch.einsum performance issues
# see https://github.com/pytorch/pytorch/issues/10661
@@ -56,15 +55,6 @@ def _einsum(equation, operands):
y, x = operands
return (x.unsqueeze(1) * y).sum(0).transpose(0, 1)

# work around torch.einsum's limitation to 26 letters
symbols = sorted(set(equation) - set(',->'))
rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
equation = ''.join(rename.get(s, s) for s in equation)

# this workaround can be deleted after this issue is fixed in release:
# https://github.com/pytorch/pytorch/issues/7763
operands = [t.clone() for t in operands]

return _einsum._pyro_unpatched(equation, operands)


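For orientation, a minimal sketch of the `_patch` helper this module presumably defines above this hunk (its body is not shown here; the shape is inferred from the `@_patch('torch.einsum')` decorator and the `_einsum._pyro_unpatched` call):

    import torch

    def _patch(target):
        # e.g. target = 'torch.einsum': resolve the owning module and attribute name
        parts = target.split('.')
        assert parts[0] == 'torch'
        module = torch
        for part in parts[1:-1]:
            module = getattr(module, part)
        name = parts[-1]

        def decorator(new_fn):
            # stash the original so the wrapper can delegate via fn._pyro_unpatched
            new_fn._pyro_unpatched = getattr(module, name)
            setattr(module, name, new_fn)
            return new_fn

        return decorator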
37 changes: 14 additions & 23 deletions pyro/distributions/util.py
@@ -3,11 +3,15 @@
import numbers
from contextlib import contextmanager

import torch
import torch.distributions as torch_dist
from torch import logsumexp
from torch.distributions.utils import broadcast_all

_VALIDATION_ENABLED = False

log_sum_exp = logsumexp # DEPRECATED


def copy_docs_from(source_class, full_text=False):
"""
@@ -52,15 +56,23 @@ def is_identically_zero(x):
Check if argument is exactly the number zero. True for the number zero;
false for other numbers; false for :class:`~torch.Tensor`s.
"""
return isinstance(x, numbers.Number) and x == 0
if isinstance(x, numbers.Number):
return x == 0
elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape:
return x.item() == 0
return False


def is_identically_one(x):
"""
Check if argument is exactly the number one. True for the number one;
false for other numbers; false for :class:`~torch.Tensor`s.
"""
return isinstance(x, numbers.Number) and x == 1
if isinstance(x, numbers.Number):
return x == 1
elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape:
return x.item() == 1
return False


def broadcast_shape(*shapes, **kwargs):
@@ -178,27 +190,6 @@ def eye_like(value, m, n=None):
return eye


try:
from torch import logsumexp # for pytorch 0.4.1 and later
except ImportError:
def logsumexp(tensor, dim=-1, keepdim=False):
"""
Numerically stable implementation for the `LogSumExp` operation. The
summing is done along the dimension specified by ``dim``.

:param torch.Tensor tensor: Input tensor.
:param dim: Dimension to be summed out.
:param keepdim: Whether to retain the dimension
that is summed out.
"""
max_val = tensor.max(dim, keepdim=True)[0]
log_sum_exp = max_val + (tensor - max_val).exp().sum(dim=dim, keepdim=True).log()
return log_sum_exp if keepdim else log_sum_exp.squeeze(dim)


log_sum_exp = logsumexp # DEPRECATED


def enable_validation(is_validate):
global _VALIDATION_ENABLED
_VALIDATION_ENABLED = is_validate
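The body of eye_like is collapsed in this view; here is a plausible sketch (a hedged reconstruction, with the rectangular case handled per the "Fix bug in eye_like when m!=n" commit above):

    import torch

    def eye_like(value, m, n=None):
        # identity-like (m, n) matrix matching the dtype and device of `value`,
        # written without torch.eye(..., out=...), which the jit cannot trace
        n = m if n is None else n
        eye = value.new_zeros(m, n)
        # the flat index of element (i, i) is i * (n + 1)
        eye.view(-1)[:min(m, n) * n:n + 1] = 1
        return eye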
34 changes: 34 additions & 0 deletions pyro/infer/mcmc/hmc.py
@@ -50,6 +50,9 @@ class HMC(TraceKernel):
:param int max_iarange_nesting: Optional bound on max number of nested
:func:`pyro.iarange` contexts. This is required if model contains
discrete sample sites that can be enumerated over in parallel.
:param bool jit_compile: Optional parameter denoting whether to use
the PyTorch JIT to trace the log density computation, and use this
optimized executable trace in the integrator.
:param bool experimental_use_einsum: Whether to use an einsum operation
to evaluate log pdf for the model trace. No-op unless the trace has
discrete sample sites. This flag is experimental and will most likely
@@ -83,6 +86,7 @@ def __init__(self,
adapt_step_size=False,
transforms=None,
max_iarange_nesting=float("inf"),
jit_compile=False,
experimental_use_einsum=False):
# Wrap model in `poutine.enum` to enumerate over discrete latent sites.
# No-op if model does not have any discrete latents.
@@ -99,6 +103,7 @@ def __init__(self,
self.trajectory_length = 2 * math.pi # from Stan
self.num_steps = max(1, int(self.trajectory_length / self.step_size))
self.adapt_step_size = adapt_step_size
self._jit_compile = jit_compile
self.use_einsum = experimental_use_einsum
self._target_accept_prob = 0.8 # from Stan

@@ -129,6 +134,8 @@ def _kinetic_energy(self, r):
return 0.5 * sum(x.pow(2).sum() for x in r.values())

def _potential_energy(self, z):
if self._jit_compile:
return self._potential_energy_jit(z)
# Since the model is specified in the constrained space, transform the
# unconstrained R.V.s `z` to the constrained space.
z_constrained = z.copy()
@@ -141,6 +148,32 @@ def _potential_energy(self, z):
potential_energy += transform.log_abs_det_jacobian(z_constrained[name], z[name]).sum()
return potential_energy

def _potential_energy_jit(self, z):
names, vals = zip(*sorted(z.items()))
if self._compiled_potential_fn:
return self._compiled_potential_fn(*vals)

@torch.jit.trace(*vals, optimize=True)
def wrapped(*zi):
z_constrained = list(zi)
# transform to constrained space.
for i, name in enumerate(names):
if name in self.transforms:
transform = self.transforms[name]
z_constrained[i] = transform.inv(z_constrained[i])
z_constrained = dict(zip(names, z_constrained))
trace = self._get_trace(z_constrained)
potential_energy = -self._compute_trace_log_prob(trace)
# adjust by the jacobian for this transformation.
for i, name in enumerate(names):
if name in self.transforms:
transform = self.transforms[name]
potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum()
return potential_energy

self._compiled_potential_fn = wrapped
return self._compiled_potential_fn(*vals)

def _energy(self, z, r):
return self._kinetic_energy(r) + self._potential_energy(z)

@@ -149,6 +182,7 @@ def _reset(self):
self._accept_cnt = 0
self._r_dist = OrderedDict()
self._args = None
self._compiled_potential_fn = None
self._kwargs = None
self._prototype_trace = None
self._adapt_phase = False
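Finally, a minimal usage sketch of the new flag (the toy model and settings here are illustrative; the HMC and MCMC signatures follow this diff and the examples above):

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer.mcmc import HMC, MCMC

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 10.))
        with pyro.iarange("data", data.shape[0]):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    data = torch.randn(100) + 3.
    # jit_compile=True makes _potential_energy defer to the traced
    # _potential_energy_jit, compiling the log density once and reusing it
    kernel = HMC(model, step_size=0.5, num_steps=4, jit_compile=True)
    posterior = MCMC(kernel, num_samples=500, warmup_steps=100).run(data)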