
Update to using pytorch-1.0-rc #1431

Merged: 197 commits, Nov 28, 2018

Commits
4c6d9b7
Fix PyTorch 0.4.1 errors
fritzo Jul 27, 2018
cd945f3
fix tests for pytorch 0.4.1
neerajprad Jul 27, 2018
4499b9c
fix test_mask
neerajprad Jul 27, 2018
f707101
skip JIT tests
neerajprad Jul 27, 2018
4737e10
fix NUTS tests
neerajprad Jul 27, 2018
7dc2bc1
fix gaussian_scale_mixture
neerajprad Jul 27, 2018
b58ef93
Mark test_em.py failures as xfail
fritzo Jul 27, 2018
448a760
Allow inf values in assert_tensors_equal
fritzo Jul 27, 2018
62f21d5
update examples
jpchen Jul 27, 2018
385ef94
Merge branch 'dev' into fix-dist-0.4.1
neerajprad Jul 28, 2018
814e025
remove redundant xfail
neerajprad Jul 28, 2018
e1cd9b4
Update JIT usage to PyTorch 0.4.1 (#1276)
fritzo Jul 30, 2018
a302944
Merge branch 'dev' into fix-dist-0.4.1
neerajprad Jul 31, 2018
8cb89a3
use float in arange
neerajprad Jul 31, 2018
eac0b73
Merge branch 'dev' into fix-dist-0.4.1
fritzo Aug 7, 2018
660b8d1
Fix Categorical.enumerate_support to make JitTraceEnum_ELBO work
fritzo Aug 7, 2018
d583dd1
Refactor test_examples.py to allow xfailing examples
fritzo Aug 7, 2018
fc9a4a7
Add xfailing examples that use --jit
fritzo Aug 7, 2018
15fba47
Fix missing import in test_jit.py
fritzo Aug 7, 2018
91bf748
Enable jit in most SVI examples
fritzo Aug 7, 2018
afc0b70
Merge branch 'dev' into pytorch-0.4.1
fritzo Aug 8, 2018
71779ec
Revert changes to torch_patch.py
fritzo Aug 8, 2018
fcdddcc
Work around jit issues; bayesian_regressian example now jits
fritzo Aug 8, 2018
d260fe0
Fix doctests to pass on Python 2.7
fritzo Aug 8, 2018
7ee642b
Merge branch 'fix-doctest-2.7' into pytorch-0.4.1
fritzo Aug 8, 2018
9f7fd54
Fix arange usage
fritzo Aug 8, 2018
d3cafb1
Only patch Categorical if broadcast_tensors is defined
fritzo Aug 8, 2018
17cb636
Add patches to work around bugs in 0.4.1
fritzo Aug 8, 2018
ad2e391
Merge branch 'dev' into pytorch-0.4.1
fritzo Aug 8, 2018
a251d3d
Fix test failures
fritzo Aug 8, 2018
1e6fb98
flake8
fritzo Aug 8, 2018
80cf76f
Fix typo in skipif markers
fritzo Aug 9, 2018
0008725
Work around bugs in torch unwind backward
fritzo Aug 9, 2018
55003ec
Mark xfailing jit test
fritzo Aug 9, 2018
8ea0461
Update all uses of torch.arange
fritzo Aug 9, 2018
09cadfb
Remove obsolete logsumexp implementation
fritzo Aug 9, 2018
2a93a9a
Patch torch.distributions.Categorical.log_prob
fritzo Aug 9, 2018
e91aa86
Work around lack of jit support for torch.eye(_, out=_)
fritzo Aug 9, 2018
5d71162
Add test-jit target to Makefile
fritzo Aug 9, 2018
0fb6870
Fix bug in eye_like when m!=n
fritzo Aug 9, 2018
6f88bbc
Fix jit errors: torch_scale and variable len(args)
fritzo Aug 9, 2018
2f190e5
Patch multivariate normal __init__ methods to be jittable
fritzo Aug 9, 2018
f7ef56e
Patch torch.log
fritzo Aug 9, 2018
b5ba5f1
Patch torch.Tensor.log
fritzo Aug 9, 2018
3f64101
Patch torch.exp and torch.Tensor.exp
fritzo Aug 9, 2018
59c48e1
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 10, 2018
894fa58
Use JIT traced potential energy computation in HMC (#1299)
neerajprad Aug 14, 2018
f303855
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 28, 2018
bf98735
Merge branch 'dev' into pytorch-0.4.1
neerajprad Aug 31, 2018
9313b0f
add xfailing test
neerajprad Sep 4, 2018
a5bc9bb
Merge branch 'dev' into pytorch-0.4.1
fritzo Sep 11, 2018
3cfba13
Remove obsolete PyTorch patches
fritzo Sep 11, 2018
a677512
Remove patch for Tensor._standard_gamma
fritzo Sep 11, 2018
8f665ba
Fix some jit errors
fritzo Sep 11, 2018
06b0e63
Convert to valid einsum chars in torch_log backend
fritzo Sep 11, 2018
1785b6c
Updating distributions module with PyTorch master (#1377)
neerajprad Sep 11, 2018
29bb3ed
Use native torch.tensordot
fritzo Sep 11, 2018
8c914b8
Remove duplicate implementation of logsumexp
fritzo Sep 11, 2018
1c708b4
Merge branch 'dev' into pytorch-0.5.0
fritzo Sep 11, 2018
35a6965
Ignore jit warnings
fritzo Sep 11, 2018
692b883
Ignore a couple TracerWarnings in pyro.ops.jit.trace
fritzo Sep 11, 2018
0c50243
Fix a tiny test_jit error
fritzo Sep 11, 2018
850432f
Add jit test for OneHotCategorical
fritzo Sep 11, 2018
a03164a
fix JIT errors for HMC
neerajprad Sep 11, 2018
effada7
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 11, 2018
c36a9aa
change assert in torch_log
neerajprad Sep 11, 2018
ecfc995
Work around more jit missing coverage
fritzo Sep 11, 2018
c47e294
Strengthen masked_fill test
fritzo Sep 12, 2018
9ea7fd2
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 12, 2018
c19830d
fix hmc enum test
neerajprad Sep 12, 2018
0c233d3
Fix failing jit tests
fritzo Sep 13, 2018
dceaf9a
Add test for .scatter_() workaround
fritzo Sep 13, 2018
ac122b0
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 13, 2018
e5cd034
add expand for MaskedDistribution
neerajprad Sep 13, 2018
6ce9925
remove binomial and half cauchy
neerajprad Sep 13, 2018
e80ef20
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 14, 2018
68c1168
reinstate Independent constraint
neerajprad Sep 14, 2018
657fc56
add expand methods to more distributions
neerajprad Sep 14, 2018
d44f90b
Fix CUDA tests in test_eig.py
neerajprad Sep 14, 2018
c2c4b72
remove standard gamma patch
neerajprad Sep 14, 2018
62bf019
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 17, 2018
5532c3f
Work-around to allow JIT compiler to infer batch size in iarange (#1392)
neerajprad Sep 19, 2018
021707a
Remove deprecated new_tensor invocation
neerajprad Sep 20, 2018
981d64b
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 20, 2018
a36f25a
Remove deprecated new_tensor invocation
neerajprad Sep 20, 2018
693fcd9
remove .new
neerajprad Sep 20, 2018
3129059
address comments
neerajprad Sep 20, 2018
b48f108
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
a648968
fix test_hessian
neerajprad Sep 20, 2018
c1c9e82
fix more tests
neerajprad Sep 20, 2018
edf000e
remove redundant parens
neerajprad Sep 20, 2018
53c3e07
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
c2d3de8
fix test_elbo_mapdata
neerajprad Sep 20, 2018
bf85894
fix test_conj_gaussian
neerajprad Sep 20, 2018
c6dd8d7
fix test_valid_models
neerajprad Sep 20, 2018
b54081c
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
d4ff53c
fix dist tests
neerajprad Sep 20, 2018
0064b6a
fix test_gaussian_mixtures
neerajprad Sep 20, 2018
94cb96c
Merge branch 'new-tensor' into pytorch-0.5.0
neerajprad Sep 20, 2018
8ceaaa0
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 25, 2018
710406d
Test fixes for compatibility with PyTorch master
neerajprad Sep 26, 2018
eef40eb
address comments; more fixes
neerajprad Sep 26, 2018
2325b14
more test fixes
neerajprad Sep 26, 2018
a5a457f
uncomment torch_patch
neerajprad Sep 26, 2018
c9078dd
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 26, 2018
6078bc7
Merge branch 'test-fixes' into pytorch-0.5.0
neerajprad Sep 26, 2018
41164bb
ignore jit warnings in hmc
neerajprad Sep 26, 2018
21b32c0
remove default jit compilation in air
neerajprad Sep 26, 2018
865ddab
set args.jit default to false
neerajprad Sep 26, 2018
6ddeb38
ignore jit warnings in hmc tests
neerajprad Sep 26, 2018
43c5e4e
mark failing hmc tests
neerajprad Sep 27, 2018
b5ae590
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 27, 2018
68171b9
test against nightly build
neerajprad Sep 27, 2018
f4db712
fix channel name
neerajprad Sep 27, 2018
fb290e8
downgrade ipython
neerajprad Sep 27, 2018
ed4b360
fix lapack issue
neerajprad Sep 27, 2018
2142589
include mkl
neerajprad Sep 27, 2018
1efd552
addons to .travis
neerajprad Sep 27, 2018
b84122f
add pytorch channel
neerajprad Sep 27, 2018
85342ef
remove pythonpath
neerajprad Sep 27, 2018
464fe68
editable install
neerajprad Sep 27, 2018
475f483
add ld_library_path
neerajprad Sep 27, 2018
07e20d4
conda install pip
neerajprad Sep 27, 2018
a624b57
debug build
neerajprad Sep 27, 2018
053163f
debug - revert to pytorch release
neerajprad Sep 27, 2018
fb62e69
add before install
neerajprad Sep 27, 2018
13be21f
use nightly wheel
neerajprad Sep 27, 2018
31b5d63
Fix incompatible dependency between jupyter-console and ipython
neerajprad Sep 27, 2018
83b3eb4
Merge branch 'fix-ipython-dep' into pytorch-0.5.0
neerajprad Sep 27, 2018
f904d0c
remove torch==0.4.1 from setup
neerajprad Sep 27, 2018
13d7b51
remove torchvision temporarily
neerajprad Sep 27, 2018
0e17d3c
install torchvision without deps
neerajprad Sep 28, 2018
92d6d4d
remove torchvision from setup
neerajprad Sep 28, 2018
72c76fd
update to contextlib2
neerajprad Sep 28, 2018
0e60ce2
fix benchmark tests
neerajprad Sep 28, 2018
b328d12
add xfail markers for failing tests
neerajprad Sep 28, 2018
3161deb
temporarily xfail ubersum_sizes test
neerajprad Sep 28, 2018
e298cb4
fix xfail marker
neerajprad Sep 28, 2018
d13039b
Merge branch 'dev' into pytorch-0.5.0
neerajprad Sep 28, 2018
feca15d
remove xfail marker from test_enum
neerajprad Sep 28, 2018
a7300e8
add xfail for mixture of diag normals
neerajprad Oct 1, 2018
17f2033
Merge branch 'dev' into pytorch-0.5.0
neerajprad Oct 1, 2018
379ffef
fix mask fill on non contiguous tensor
neerajprad Oct 2, 2018
4f93b51
Merge branch 'dev' into pytorch-1.0
neerajprad Oct 22, 2018
f948b7c
fix imports
neerajprad Oct 22, 2018
9230599
Fix jit arg error in hmm example (#1445)
fritzo Oct 15, 2018
13eb17e
Merge branch 'pyt-local-1.0' into pytorch-1.0
neerajprad Oct 22, 2018
75ab233
Revert change to broadcast messenger
neerajprad Oct 22, 2018
16ad65d
fix parametrize in test_nuts
neerajprad Oct 22, 2018
776afcd
Fix tests
neerajprad Oct 23, 2018
fdb9d20
stash
neerajprad Oct 23, 2018
f51076c
Fix JIT tests
neerajprad Oct 23, 2018
e0300f2
Merge branch 'dev' into pytorch-1.0
neerajprad Oct 23, 2018
7200fc3
fix test_mapdata
neerajprad Oct 23, 2018
86e3cf3
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 12, 2018
a599587
Change torch.potrf usage to torch.cholesky (#1529)
neerajprad Nov 13, 2018
6be5d06
update test_beta_bernoulli to use pyro.plate
neerajprad Nov 13, 2018
40bb987
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 13, 2018
c228cb6
log example output while running
neerajprad Nov 14, 2018
8943517
fix error in broadcast_all
neerajprad Nov 15, 2018
0b9647e
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 15, 2018
25a40da
Fix test_memory to ignore UserWarning with gc.getobjects()
neerajprad Nov 15, 2018
a8a49b1
add xfail marker to failing jittraceenum test
neerajprad Nov 15, 2018
40965f1
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 16, 2018
92af0f4
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 26, 2018
e55dff8
fix test_svi_enum
neerajprad Nov 26, 2018
8b6d40f
add jit markers for hmc
neerajprad Nov 27, 2018
e5fdb18
skip slow jit tests
neerajprad Nov 27, 2018
19e0720
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 27, 2018
f154fa3
add gp to jit test
neerajprad Nov 27, 2018
68b88b2
remove low rank mvn docs
neerajprad Nov 27, 2018
364e2a1
fix lint
neerajprad Nov 27, 2018
820493e
address comment
neerajprad Nov 27, 2018
16b5149
update travis build
neerajprad Nov 27, 2018
f110c01
Fix build timeout due to dataset download issues (#1571)
neerajprad Nov 27, 2018
f838a34
rtd install pytorch 1.0 (#1572)
jpchen Nov 27, 2018
26b0797
change data directory for gp
neerajprad Nov 27, 2018
b303af0
revert rtd command change
jpchen Nov 27, 2018
d6653ef
use pytorch 0.4 for rtd
jpchen Nov 27, 2018
1da48c1
Cache dataset directory on CI build
neerajprad Nov 27, 2018
ea76e3f
fix default
neerajprad Nov 27, 2018
68c26dc
fix example
neerajprad Nov 27, 2018
7819cea
fix dmm path
neerajprad Nov 27, 2018
3447b55
Merge branch 'cache-data' into pytorch-1.0
neerajprad Nov 27, 2018
9433f67
address lack of exists_ok in makedirs in python 2
neerajprad Nov 27, 2018
cd882db
Merge branch 'cache-data' into pytorch-1.0
neerajprad Nov 27, 2018
bcffd2e
add debug info
neerajprad Nov 28, 2018
7f947b8
fix lint
neerajprad Nov 28, 2018
ed4e113
fix errno
neerajprad Nov 28, 2018
a59321b
skip gp example to see which others fail
neerajprad Nov 28, 2018
d0271ca
fix pytest param
neerajprad Nov 28, 2018
5dfd731
revert changes from #1573
neerajprad Nov 28, 2018
aeadda0
update pytorch build
neerajprad Nov 28, 2018
499185a
revert to nov 27 build
neerajprad Nov 28, 2018
cd654cd
Merge branch 'dev' into pytorch-1.0
neerajprad Nov 28, 2018
7b2a014
Update travis with 11/28 build
neerajprad Nov 28, 2018
59c7304
add typing to setup.py
neerajprad Nov 28, 2018
7 changes: 2 additions & 5 deletions .travis.yml
@@ -12,11 +12,8 @@ cache:

install:
- pip install -U pip
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl;
else
pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl;
fi
- pip install torch_nightly==1.0.0.dev20181128 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install torchvision --no-dependencies
- pip install .[test]
- pip freeze

6 changes: 6 additions & 0 deletions Makefile
@@ -57,6 +57,12 @@ test-cuda: lint FORCE
CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit
CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-jit: FORCE
@echo See jit.log
pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log
pytest -v -n auto --tb=short --runxfail tests/infer/mcmc/test_hmc.py tests/infer/mcmc/test_nuts.py \
-k JIT=True | tee -a jit.log

clean: FORCE
git clean -dfx -e pyro-egg.info

25 changes: 6 additions & 19 deletions README.md
@@ -62,27 +62,14 @@ Make sure that the models come from the same release version of the [Pyro source

For recent features you can install Pyro from source.

To install a compatible CPU version of PyTorch on OSX / Linux, you
could use the PyTorch install helper script.
To install a compatible version of PyTorch, use the PyTorch nightly
[build](https://pytorch.org/). We recommend pinning to the specific
nightly build below that has been well tested.

```
bash scripts/install_pytorch.sh
```

Alternatively, build PyTorch following instructions in the PyTorch
[README](https://github.com/pytorch/pytorch/blob/master/README.md).
```sh
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git checkout 200fb22 # <---- a well-tested commit
```
On Linux:
```sh
python setup.py install
```
On OSX:
```sh
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
build_ver=1.0.0.dev20181127 # <---- a well-tested PyTorch build
pip install torch_nightly==${build_ver} -f \
https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
```

Finally install Pyro using pip or from source as shown below.
9 changes: 5 additions & 4 deletions docs/source/conf.py
@@ -205,7 +205,8 @@ def setup(app):

# @jpchen's hack to get rtd builder to install latest pytorch
if 'READTHEDOCS' in os.environ:
os.system('curl -o install.sh https://raw.githubusercontent.com/uber/pyro/dev/scripts/install_pytorch.sh')
os.system('curl https://raw.githubusercontent.com/uber/pyro/dev/README.md > README.md')
os.system('bash install.sh')
os.system('rm -f install.sh')
os.system('pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl')
# for pytorch 1.0 (currently fails with OOM
# https://readthedocs.org/projects/pyro-ppl/builds/8159615/
# os.system('pip install torch_nightly==1.0.0.dev20181127 -f '
# 'https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html')
22 changes: 0 additions & 22 deletions docs/source/distributions.rst
@@ -56,14 +56,6 @@ AVFMultivariateNormal
:undoc-members:
:show-inheritance:

Binomial
--------

.. autoclass:: pyro.distributions.Binomial
:members:
:undoc-members:
:show-inheritance:

Delta
-----
.. autoclass:: pyro.distributions.Delta
@@ -85,20 +77,6 @@ GaussianScaleMixture
:undoc-members:
:show-inheritance:

HalfCauchy
----------
.. autoclass:: pyro.distributions.HalfCauchy
:members:
:undoc-members:
:show-inheritance:

LowRankMultivariateNormal
-------------------------
.. autoclass:: pyro.distributions.LowRankMultivariateNormal
:members:
:undoc-members:
:show-inheritance:

MaskedMixture
-------------
.. autoclass:: pyro.distributions.MaskedMixture
2 changes: 1 addition & 1 deletion docs/source/primitives.rst
@@ -16,4 +16,4 @@ Primitives
.. autofunction:: pyro.validation_enabled
.. autofunction:: pyro.enable_validation

.. autofunction:: pyro.ops.jit.compile
.. autofunction:: pyro.ops.jit.trace
2 changes: 2 additions & 0 deletions examples/baseball.py
@@ -333,5 +333,7 @@ def main(args):
parser.add_argument("--num-chains", nargs='?', default=4, type=int)
parser.add_argument("--warmup-steps", nargs='?', default=100, type=int)
parser.add_argument("--rng_seed", nargs='?', default=0, type=int)
parser.add_argument('--jit', action='store_true', default=False,
help='use PyTorch jit')
args = parser.parse_args()
main(args)
3 changes: 2 additions & 1 deletion examples/eight_schools/mcmc.py
@@ -34,7 +34,7 @@ def conditioned_model(model, sigma, y):


def main(args):
nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit,)
posterior = MCMC(nuts_kernel,
num_samples=args.num_samples,
warmup_steps=args.warmup_steps,
@@ -58,6 +58,7 @@ def main(args):
help='number of parallel MCMC chains (default: 1)')
parser.add_argument('--warmup-steps', type=int, default=1000,
help='number of MCMC samples for warmup (default: 1000)')
parser.add_argument('--jit', action='store_true', default=False)
args = parser.parse_args()

main(args)
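Several examples in this PR gain a `--jit` flag wired into `jit_compile=args.jit`. A self-contained sketch of that argparse pattern, taken from the flag definitions visible in the diff (the model code itself is elided):

```python
import argparse

# The --jit flag as added across the PR's example scripts:
# disabled by default, enabled by passing the bare flag.
parser = argparse.ArgumentParser()
parser.add_argument('--jit', action='store_true', default=False,
                    help='use PyTorch jit')

args = parser.parse_args(['--jit'])  # simulate `python mcmc.py --jit`
assert args.jit is True

args = parser.parse_args([])         # default run: jit disabled
assert args.jit is False
```

The `action='store_true'` form means the flag carries no value; its mere presence flips the boolean, which matches commit 865ddab ("set args.jit default to false").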
4 changes: 2 additions & 2 deletions examples/hmm.py
@@ -324,7 +324,7 @@ def main(args):
# We'll train on small minibatches.
logging.info('Step\tLoss')
for step in range(args.num_steps):
loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

# We evaluate on the entire training dataset,
@@ -340,7 +340,7 @@
if args.truncate:
lengths.clamp_(max=args.truncate)
num_observations = float(lengths.sum())
test_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False)
logging.info('test loss = {}'.format(test_loss / num_observations))

# We expect models with higher capacity to perform better,
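The hmm.py change from positional `args` to keyword `args=args` likely reflects how a tracing JIT handles arguments: positional arguments are typically treated as tensor inputs to the traced graph, while keyword arguments pass through as static configuration. A hypothetical, framework-free sketch of that convention (`traced_step` is illustrative, not Pyro's API):

```python
def traced_step(*tensor_args, **static_kwargs):
    """Illustrative convention only: positional args are numeric inputs
    a tracer would record; keyword args are static, untraced config."""
    total = sum(tensor_args)                # numeric work on traced inputs
    scale = static_kwargs.get("scale", 1)   # static, non-tensor setting
    return total * scale

# Passing config positionally would wrongly enter the numeric path:
# traced_step(1, 2, {"scale": 2}) would try to sum a dict.
assert traced_step(1, 2, 3, scale=2) == 12
```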
1 change: 0 additions & 1 deletion examples/vae/utils/mnist_cached.py
@@ -9,7 +9,6 @@

from pyro.contrib.examples.util import get_data_directory


# This file contains utilities for caching, transforming and splitting MNIST data
# efficiently. By default, a PyTorch DataLoader will apply the transform every epoch
# we avoid this by caching the data early on in MNISTCached class
2 changes: 1 addition & 1 deletion pyro/contrib/autoguide/__init__.py
@@ -642,7 +642,7 @@ def laplace_approximation(self, *args, **kwargs):
loc = pyro.param("{}_loc".format(self.prefix))
H = hessian(loss, loc.unconstrained())
cov = H.inverse()
scale_tril = cov.potrf(upper=False)
scale_tril = cov.cholesky()

# calculate scale_tril from self.guide()
scale_tril_name = "{}_scale_tril".format(self.prefix)
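The recurring `potrf(upper=False)` → `cholesky()` edits in this PR track PyTorch's rename of its Cholesky routine; both produce a lower-triangular factor L with L·Lᵀ = K. A minimal NumPy illustration of the same factorization (NumPy stands in here purely for illustration; the PR itself calls the torch tensor methods):

```python
import numpy as np

# Small symmetric positive-definite matrix, like the covariance
# factored in laplace_approximation above (cov = H.inverse()).
K = np.array([[4.0, 2.0],
              [2.0, 3.0]])

# np.linalg.cholesky returns the lower-triangular factor, matching
# the old potrf(upper=False) and the lower-by-default cholesky().
L = np.linalg.cholesky(K)

assert np.allclose(L @ L.T, K)      # factor reconstructs K
assert np.allclose(L, np.tril(L))   # and is lower-triangular
```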
25 changes: 21 additions & 4 deletions pyro/contrib/examples/util.py
@@ -3,9 +3,23 @@
import os
import sys

import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from pyro.distributions.torch_patch import patch_dependency


@patch_dependency('torchvision.datasets.MNIST', torchvision)
class _MNIST(getattr(MNIST, '_pyro_unpatched', MNIST)):
urls = [
"https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz",
"https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz",
"https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz",
"https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz",
]


def get_data_loader(dataset_name,
@@ -18,11 +32,14 @@ def get_data_loader(dataset_name,
dataset_transforms = []
trans = transforms.Compose([transforms.ToTensor()] + dataset_transforms)
dataset = getattr(datasets, dataset_name)
print("downloading data")
dset = dataset(root=data_dir,
train=is_training_set,
transform=trans,
download=True)
print("download complete.")
return DataLoader(
dataset(root=data_dir,
train=is_training_set,
transform=trans,
download=True),
dset,
batch_size=batch_size,
shuffle=shuffle
)
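The `@patch_dependency('torchvision.datasets.MNIST', torchvision)` decorator above swaps torchvision's MNIST class for a subclass with mirrored download URLs, keeping a handle to the unpatched original. A hypothetical minimal analogue of such a patching decorator (the names `patch_attr`, `_unpatched`, and `_Registry` are illustrative, not Pyro's API):

```python
def patch_attr(owner, name):
    """Hypothetical analogue of a patch_dependency-style decorator:
    install the decorated object as owner.<name>, stashing the
    original (cf. getattr(MNIST, '_pyro_unpatched', MNIST) above)."""
    def decorator(new_obj):
        new_obj._unpatched = getattr(owner, name, None)  # keep original
        setattr(owner, name, new_obj)                    # swap in place
        return new_obj
    return decorator


class _Registry:
    """Stand-in for a patched module such as torchvision.datasets."""
    class Dataset:
        urls = ["http://example.invalid/original"]


@patch_attr(_Registry, "Dataset")
class _MirroredDataset(_Registry.Dataset):
    # Mirrored download URLs, like the CloudFront MNIST mirrors above.
    urls = ["http://example.invalid/mirror"]


assert _Registry.Dataset is _MirroredDataset
assert _Registry.Dataset._unpatched.urls == ["http://example.invalid/original"]
```

Subclassing the original (rather than replacing it outright) keeps its behavior while overriding only the download URLs, which is the shape of the `_MNIST` patch in this diff.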
8 changes: 4 additions & 4 deletions pyro/contrib/gp/models/gplvm.py
@@ -1,14 +1,14 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
from pyro.contrib.gp.util import Parameterized
import pyro.distributions as dist
import pyro.infer as infer
import pyro.optim as optim
from pyro.contrib.gp.util import Parameterized
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -74,7 +74,7 @@ def __init__(self, base_model, name="GPLVM"):

C = self.X_loc.shape[1]
X_scale_tril_shape = self.X_loc.shape + (C,)
Id = torch.eye(C, out=self.X_loc.new_empty(C, C))
Id = eye_like(self.X_loc, C)
X_scale_tril = Id.expand(X_scale_tril_shape)
self.X_scale_tril = Parameter(X_scale_tril)
self.set_constraint("X_scale_tril", constraints.lower_cholesky)
@@ -87,7 +87,7 @@ def model(self):
# sample X from unit multivariate normal distribution
zero_loc = self.X_loc.new_zeros(self.X_loc.shape)
C = self.X_loc.shape[1]
Id = torch.eye(C, out=self.X_loc.new_empty(C, C))
Id = eye_like(self.X_loc, C)
X_name = param_with_module_name(self.name, "X")
X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim()-1))
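The repeated `torch.eye(C, out=...)` → `eye_like(...)` edits route around missing JIT support for `torch.eye(_, out=_)` (commit e91aa86). A NumPy sketch of the helper's apparent contract — an identity matrix matching a reference array's dtype; the real torch helper would match the device too, and commit 0fb6870 indicates it also handles rectangular m != n shapes:

```python
import numpy as np

def eye_like(value, n, m=None):
    """Sketch only: identity-like matrix with the dtype of `value`.
    (Pyro's torch version would also place it on `value`'s device.)"""
    m = n if m is None else m
    return np.eye(n, m, dtype=value.dtype)

X_loc = np.zeros((10, 3), dtype=np.float32)
Id = eye_like(X_loc, 3)
assert Id.shape == (3, 3) and Id.dtype == np.float32
assert eye_like(X_loc, 2, 4).shape == (2, 4)  # rectangular case
```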
6 changes: 3 additions & 3 deletions pyro/contrib/gp/models/gpr.py
@@ -81,7 +81,7 @@ def model(self):
N = self.X.shape[0]
Kff = self.kernel(self.X)
Kff.view(-1)[::N + 1] += noise # add noise to diagonal
Lff = Kff.potrf(upper=False)
Lff = Kff.cholesky()

zero_loc = self.X.new_zeros(self.X.shape[0])
f_loc = zero_loc + self.mean_function(self.X)
@@ -129,7 +129,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
N = self.X.shape[0]
Kff = self.kernel(self.X).contiguous()
Kff.view(-1)[::N + 1] += noise # add noise to the diagonal
Lff = Kff.potrf(upper=False)
Lff = Kff.cholesky()

y_residual = self.y - self.mean_function(self.X)
loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff,
@@ -185,7 +185,7 @@ def sample_next(xnew, outside_vars):
X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"]

# Compute Cholesky decomposition of kernel matrix
Lff = Kff.potrf(upper=False)
Lff = Kff.cholesky()
y_residual = y - self.mean_function(X)

# Compute conditional mean and variance
6 changes: 3 additions & 3 deletions pyro/contrib/gp/models/sgpr.py
@@ -130,7 +130,7 @@ def model(self):
M = Xu.shape[0]
Kuu = self.kernel(Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.potrf(upper=False)
Luu = Kuu.cholesky()
Kuf = self.kernel(Xu, self.X)
W = Kuf.trtrs(Luu, upper=False)[0].t()

@@ -210,7 +210,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):

Kuu = self.kernel(Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.potrf(upper=False)
Luu = Kuu.cholesky()
Kus = self.kernel(Xu, Xnew)
Kuf = self.kernel(Xu, self.X)

@@ -225,7 +225,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
W_Dinv = W / D
K = W_Dinv.matmul(W.t()).contiguous()
K.view(-1)[::M + 1] += 1 # add identity matrix to K
L = K.potrf(upper=False)
L = K.cholesky()

# get y_residual and convert it into 2D tensor for packing
y_residual = self.y - self.mean_function(self.X)
7 changes: 4 additions & 3 deletions pyro/contrib/gp/models/vgp.py
@@ -8,6 +8,7 @@
import pyro.distributions as dist
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -74,7 +75,7 @@ def __init__(self, X, y, kernel, likelihood, mean_function=None,
self.f_loc = Parameter(f_loc)

f_scale_tril_shape = self.latent_shape + (N, N)
Id = torch.eye(N, out=self.X.new_empty(N, N))
Id = eye_like(self.X, N)
f_scale_tril = Id.expand(f_scale_tril_shape)
self.f_scale_tril = Parameter(f_scale_tril)
self.set_constraint("f_scale_tril", constraints.lower_cholesky)
@@ -90,13 +91,13 @@ def model(self):
N = self.X.shape[0]
Kff = self.kernel(self.X).contiguous()
Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal
Lff = Kff.potrf(upper=False)
Lff = Kff.cholesky()

zero_loc = self.X.new_zeros(f_loc.shape)
f_name = param_with_module_name(self.name, "f")

if self.whiten:
Id = torch.eye(N, out=self.X.new_empty(N, N))
Id = eye_like(self.X, N)
pyro.sample(f_name,
dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim() - 1))
7 changes: 4 additions & 3 deletions pyro/contrib/gp/models/vsgp.py
@@ -9,6 +9,7 @@
import pyro.poutine as poutine
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like
from pyro.params import param_with_module_name


@@ -98,7 +99,7 @@ def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None,
self.u_loc = Parameter(u_loc)

u_scale_tril_shape = self.latent_shape + (M, M)
Id = torch.eye(M, out=self.Xu.new_empty(M, M))
Id = eye_like(self.Xu, M)
u_scale_tril = Id.expand(u_scale_tril_shape)
self.u_scale_tril = Parameter(u_scale_tril)
self.set_constraint("u_scale_tril", constraints.lower_cholesky)
@@ -115,12 +116,12 @@ def model(self):
M = Xu.shape[0]
Kuu = self.kernel(Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.potrf(upper=False)
Luu = Kuu.cholesky()

zero_loc = Xu.new_zeros(u_loc.shape)
u_name = param_with_module_name(self.name, "u")
if self.whiten:
Id = torch.eye(M, out=Xu.new_empty(M, M))
Id = eye_like(Xu, M)
pyro.sample(u_name,
dist.MultivariateNormal(zero_loc, scale_tril=Id)
.independent(zero_loc.dim() - 1))
2 changes: 1 addition & 1 deletion pyro/contrib/gp/util.py
@@ -212,7 +212,7 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa
if Lff is None:
Kff = kernel(X).contiguous()
Kff.view(-1)[::N + 1] += jitter # add jitter to diagonal
Lff = Kff.potrf(upper=False)
Lff = Kff.cholesky()
Kfs = kernel(X, Xnew)

# convert f_loc_shape from latent_shape x N to N x latent_shape