From 585beb9d76f175fc7772bafacb4eb95aadd938e5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 1 Feb 2021 18:01:59 -0500 Subject: [PATCH 01/35] Bump to version 1.5.2 (#2755) --- .travis.yml | 2 +- docs/source/conf.py | 2 +- examples/air/main.py | 2 +- examples/baseball.py | 2 +- examples/contrib/autoname/mixture.py | 2 +- examples/contrib/autoname/scoping_mixture.py | 2 +- examples/contrib/autoname/tree_data.py | 2 +- examples/contrib/cevae/synthetic.py | 2 +- examples/contrib/epidemiology/regional.py | 2 +- examples/contrib/epidemiology/sir.py | 2 +- examples/contrib/forecast/bart.py | 2 +- examples/contrib/funsor/hmm.py | 2 +- examples/contrib/gp/sv-dkl.py | 2 +- examples/contrib/oed/ab_test.py | 2 +- examples/contrib/timeseries/gp_models.py | 2 +- examples/cvae/main.py | 2 +- examples/dmm.py | 2 +- examples/eight_schools/mcmc.py | 2 +- examples/eight_schools/svi.py | 2 +- examples/hmm.py | 2 +- examples/inclined_plane.py | 2 +- examples/lda.py | 2 +- examples/lkj.py | 2 +- examples/minipyro.py | 2 +- examples/neutra.py | 2 +- examples/rsa/generics.py | 2 +- examples/rsa/hyperbole.py | 2 +- examples/rsa/schelling.py | 2 +- examples/rsa/schelling_false.py | 2 +- examples/rsa/semantic_parsing.py | 2 +- examples/scanvi/scanvi.py | 2 +- examples/sir_hmc.py | 2 +- examples/sparse_gamma_def.py | 2 +- examples/sparse_regression.py | 2 +- examples/svi_horovod.py | 2 +- .../toy_mixture_model_discrete_enumeration.py | 2 +- examples/vae/ss_vae_M2.py | 2 +- examples/vae/vae.py | 2 +- examples/vae/vae_comparison.py | 2 +- pyro/__init__.py | 2 +- pyro/ops/tensor_utils.py | 18 ++++++++++++++++-- setup.py | 6 +++--- tests/test_examples.py | 8 +++++--- tutorial/source/air.ipynb | 2 +- tutorial/source/bayesian_regression.ipynb | 2 +- tutorial/source/bayesian_regression_ii.ipynb | 2 +- tutorial/source/bo.ipynb | 2 +- .../source/dirichlet_process_mixture.ipynb | 2 +- tutorial/source/easyguide.ipynb | 2 +- tutorial/source/ekf.ipynb | 2 +- tutorial/source/enumeration.ipynb | 2 +- tutorial/source/epi_intro.ipynb | 2 +- tutorial/source/forecasting_dlm.ipynb | 2 +- tutorial/source/forecasting_i.ipynb | 2 +- tutorial/source/forecasting_ii.ipynb | 2 +- tutorial/source/forecasting_iii.ipynb | 2 +- tutorial/source/gmm.ipynb | 2 +- tutorial/source/gp.ipynb | 2 +- tutorial/source/gplvm.ipynb | 2 +- tutorial/source/jit.ipynb | 2 +- tutorial/source/modules.ipynb | 2 +- tutorial/source/stable.ipynb | 2 +- tutorial/source/svi_part_i.ipynb | 2 +- tutorial/source/svi_part_iii.ipynb | 2 +- tutorial/source/tensor_shapes.ipynb | 2 +- tutorial/source/tracking_1d.ipynb | 2 +- tutorial/source/vae.ipynb | 2 +- 67 files changed, 88 insertions(+), 72 deletions(-) diff --git a/.travis.yml b/.travis.yml index 229b8ad30a..93935a475b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,7 +25,7 @@ install: - pip install -U pip # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - - pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + - pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - pip install .[test] - pip install coveralls - pip freeze diff --git a/docs/source/conf.py b/docs/source/conf.py index 250af8e937..d39a267cb9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -213,5 +213,5 @@ def setup(app): # @jpchen's hack to get rtd builder to install latest pytorch # See similar line in the install section of .travis.yml if 'READTHEDOCS' in os.environ: - os.system('pip install torch==1.6.0+cpu torchvision==0.7.0+cpu ' + os.system('pip install torch==1.7.0+cpu torchvision==0.8.1+cpu ' '-f https://download.pytorch.org/whl/torch_stable.html') diff --git a/examples/air/main.py b/examples/air/main.py index 8cf13b255e..43d30d3a02 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -248,7 +248,7 @@ def per_param_optim_args(module_name, param_name): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS) parser.add_argument('-n', '--num-steps', type=int, default=int(1e8), help='number of optimization steps to take') diff --git a/examples/baseball.py b/examples/baseball.py index 070adc2f10..433b807a7e 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -330,7 +330,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Baseball batting average using HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int) parser.add_argument("--num-chains", nargs='?', default=4, type=int) diff --git a/examples/contrib/autoname/mixture.py b/examples/contrib/autoname/mixture.py index 92ee992aff..1e26f0a3ec 100644 --- a/examples/contrib/autoname/mixture.py +++ b/examples/contrib/autoname/mixture.py @@ -75,7 +75,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=200, type=int) parser.add_argument('--jit', action='store_true') diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 842e9f03c0..363d3ace53 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -67,7 +67,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=200, type=int) args = parser.parse_args() diff --git a/examples/contrib/autoname/tree_data.py b/examples/contrib/autoname/tree_data.py index d2a7209d6b..14811fd758 100644 --- a/examples/contrib/autoname/tree_data.py +++ b/examples/contrib/autoname/tree_data.py @@ -105,7 +105,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=100, type=int) args = parser.parse_args() diff --git a/examples/contrib/cevae/synthetic.py b/examples/contrib/cevae/synthetic.py index f2193030f2..b4e6beb5b7 100644 --- a/examples/contrib/cevae/synthetic.py +++ b/examples/contrib/cevae/synthetic.py @@ -81,7 +81,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Causal Effect Variational Autoencoder") parser.add_argument("--num-data", default=1000, type=int) parser.add_argument("--feature-dim", default=5, type=int) diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index 220e1bee65..a944eae66f 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -144,7 +144,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser( description="Regional compartmental epidemiology modeling using HMC") parser.add_argument("-p", "--population", default=1000, type=int) diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 5c3088f8fc..436075f366 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -295,7 +295,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser( description="Compartmental epidemiology modeling using HMC") parser.add_argument("-p", "--population", default=1000, type=float) diff --git a/examples/contrib/forecast/bart.py b/examples/contrib/forecast/bart.py index ac1d088a97..f4078bc062 100644 --- a/examples/contrib/forecast/bart.py +++ b/examples/contrib/forecast/bart.py @@ -156,7 +156,7 @@ def transform(pred, truth): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example") parser.add_argument("--train-window", default=2160, type=int) parser.add_argument("--test-window", default=336, type=int) diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 0c70e0a113..24dbc1a05a 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -620,7 +620,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales") parser.add_argument("-m", "--model", default="1", type=str, help="one of: {}".format(", ".join(sorted(models.keys())))) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 338ad6a19f..f07c4052e4 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -165,7 +165,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') parser.add_argument('--data-dir', type=str, default=None, metavar='PATH', help='default directory to cache MNIST data') diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index 16842b7c05..522dc44ad3 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -115,7 +115,7 @@ def main(num_vi_steps, num_bo_steps, seed): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="A/B test experiment design using VI") parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int) parser.add_argument('--num-bo-steps', nargs="?", default=5, type=int) diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index eb0075e626..840d5e2c65 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -149,7 +149,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="contrib.timeseries example usage") parser.add_argument("-n", "--num-steps", default=300, type=int) parser.add_argument("-s", "--seed", default=0, type=int) diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 224b4f05af..2056b4d3a6 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -82,7 +82,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-nq', '--num-quadrant-inputs', metavar='N', type=int, diff --git a/examples/dmm.py b/examples/dmm.py index 38117ede79..ce7061e790 100644 --- a/examples/dmm.py +++ b/examples/dmm.py @@ -453,7 +453,7 @@ def do_evaluation(): # parse command-line arguments and execute the main method if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', type=int, default=5000) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index 2a5b2c0bc1..636da1cf61 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -42,7 +42,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='Eight Schools MCMC') parser.add_argument('--num-samples', type=int, default=1000, help='number of MCMC samples (default: 1000)') diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index ded9d216d2..9aaec86a6a 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -75,7 +75,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='Eight Schools SVI') parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)') diff --git a/examples/hmm.py b/examples/hmm.py index 650c802653..7be3762f1d 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -638,7 +638,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales") parser.add_argument("-m", "--model", default="1", type=str, help="one of: {}".format(", ".join(sorted(models.keys())))) diff --git a/examples/inclined_plane.py b/examples/inclined_plane.py index 1df3c1ccb0..6aae4d2e4b 100644 --- a/examples/inclined_plane.py +++ b/examples/inclined_plane.py @@ -123,7 +123,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=500, type=int) args = parser.parse_args() diff --git a/examples/lda.py b/examples/lda.py index dd8b6fc2dc..a852b484d3 100644 --- a/examples/lda.py +++ b/examples/lda.py @@ -138,7 +138,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation") parser.add_argument("-t", "--num-topics", default=8, type=int) parser.add_argument("-w", "--num-words", default=1024, type=int) diff --git a/examples/lkj.py b/examples/lkj.py index 96b2b9cd2c..9895587020 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -49,7 +49,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior") parser.add_argument("--num-samples", nargs="?", default=200, type=int) parser.add_argument("--n", nargs="?", default=500, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index e12775dfda..33ab6eba01 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -65,7 +65,7 @@ def guide(data): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-b", "--backend", default="minipyro") parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index b8d2b9413e..a06292a32d 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -186,7 +186,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer') parser.add_argument('-n', '--num-steps', default=10000, type=int, help='number of SVI steps') diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 1e617316c4..9a13a4e54a 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -157,7 +157,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=10, type=int) args = parser.parse_args() diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index 04d878fa8b..a77b01870a 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -154,7 +154,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=10, type=int) parser.add_argument('--price', default=10000, type=int) diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index ea31af2811..886eb405b6 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -76,7 +76,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=10, type=int) parser.add_argument('--depth', default=2, type=int) diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index 4a5ffcdf98..998e3b70cb 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -89,7 +89,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=10, type=int) parser.add_argument('--depth', default=3, type=int) diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index 0a998c6227..15ffe901aa 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -340,7 +340,7 @@ def is_all_qud(world): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-samples', default=10, type=int) args = parser.parse_args() diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 1a31690be5..eb7e029012 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -351,7 +351,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') # Parse command line arguments parser = argparse.ArgumentParser(description="single-cell ANnotation using Variational Inference") parser.add_argument('-s', '--seed', default=0, type=int, help='rng seed') diff --git a/examples/sir_hmc.py b/examples/sir_hmc.py index b85dbfe66e..d685f5987f 100644 --- a/examples/sir_hmc.py +++ b/examples/sir_hmc.py @@ -573,7 +573,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC") parser.add_argument("-p", "--population", default=10, type=int) parser.add_argument("-m", "--min-observations", default=3, type=int) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index 5390491a82..18dc887ee7 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -234,7 +234,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=1500, type=int, help='number of training epochs') diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index d110f96168..c7eac172f7 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -311,7 +311,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='Krylov KIT') parser.add_argument('--num-data', type=int, default=750) parser.add_argument('--num-steps', type=int, default=1000) diff --git a/examples/svi_horovod.py b/examples/svi_horovod.py index b66304b7d0..4e15199631 100644 --- a/examples/svi_horovod.py +++ b/examples/svi_horovod.py @@ -150,7 +150,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Distributed training via Horovod") parser.add_argument("-o", "--outfile") parser.add_argument("-s", "--size", default=1000000, type=int) diff --git a/examples/toy_mixture_model_discrete_enumeration.py b/examples/toy_mixture_model_discrete_enumeration.py index 8ad76fc1e9..700f9c48ec 100644 --- a/examples/toy_mixture_model_discrete_enumeration.py +++ b/examples/toy_mixture_model_discrete_enumeration.py @@ -127,7 +127,7 @@ def get_true_pred_CPDs(CPD, posterior_param): if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="Toy mixture model") parser.add_argument("-n", "--num-steps", default=4000, type=int) parser.add_argument("-o", "--num-obs", default=10000, type=int) diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index a427edde00..eac00ea24a 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -383,7 +383,7 @@ def main(args): "-sup 3000 -zd 50 -hl 500 -lr 0.00042 -b1 0.95 -bs 200 -log ./tmp.log" if __name__ == "__main__": - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN)) diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 321f3cfc62..9f714d9387 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -201,7 +201,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') # parse command line arguments parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=101, type=int, help='number of training epochs') diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index 534aa9ea87..e95f66d2b3 100644 --- a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -246,7 +246,7 @@ def main(args): if __name__ == '__main__': - assert pyro.__version__.startswith('1.5.1') + assert pyro.__version__.startswith('1.5.2') parser = argparse.ArgumentParser(description='VAE using MNIST dataset') parser.add_argument('-n', '--num-epochs', nargs='?', default=10, type=int) parser.add_argument('--batch_size', nargs='?', default=128, type=int) diff --git a/pyro/__init__.py b/pyro/__init__.py index 1fbe6062fc..90295a7d23 100644 --- a/pyro/__init__.py +++ b/pyro/__init__.py @@ -10,7 +10,7 @@ from pyro.util import set_rng_seed # After changing this, run scripts/update_version.py -version_prefix = '1.5.1' +version_prefix = '1.5.2' # Get the __version__ string from the auto-generated _version.py file, if exists. try: diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 0ac27aac4e..e4eda46fb0 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -10,6 +10,20 @@ _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0) +def as_complex(x): + """ + Similar to :func:`torch.view_as_complex` but copies data in case strides + are not multiples of two. + """ + if any(stride % 2 for stride in x.stride()[:-1]): + # First try to normalize strides. + x = x.squeeze().reshape(x.shape) + if any(stride % 2 for stride in x.stride()[:-1]): + # Fall back to copying data. + x = x.clone() + return torch.view_as_complex(x) + + def block_diag_embed(mat): """ Takes a tensor of shape (..., B, M, N) and returns a block diagonal tensor @@ -278,7 +292,7 @@ def dct(x, dim=-1): coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) M = Y.size(-1) coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1) - X = torch.view_as_complex(coef) * Y + X = as_complex(coef) * Y # NB: if we use the full-length version Y_full = fft(y, n=N), then # the real part of the later half of X will be the flip # of the negative of the imaginary part of the first half @@ -320,7 +334,7 @@ def idct(x, dim=-1): X = torch.stack([x[..., :M], xi], dim=-1) coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1) - Y = torch.view_as_complex(coef) * torch.view_as_complex(X) + Y = as_complex(coef) * as_complex(X) # Step 2 y = irfft(Y, n=N) # Step 3 diff --git a/setup.py b/setup.py index 40ffac4749..d0906bf8bf 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ 'jupyter>=1.0.0', 'graphviz>=0.8', 'matplotlib>=1.3', - 'torchvision>=0.7.0', + 'torchvision>=0.7<0.9', 'visdom>=0.1.4', 'pandas', 'scikit-learn', @@ -88,7 +88,7 @@ 'numpy>=1.7', 'opt_einsum>=2.3.2', 'pyro-api>=0.1.1', - 'torch>=1.6.0', + 'torch>=1.6<1.8', 'tqdm>=4.36', ], extras_require={ @@ -119,7 +119,7 @@ 'horovod': ['horovod[pytorch]>=0.19'], 'funsor': [ # This must be a released version when Pyro is released. - 'funsor[torch]==0.3.0', + 'funsor[torch]==0.4.0', ], }, python_requires='>=3.6', diff --git a/tests/test_examples.py b/tests/test_examples.py index b0d0cb96c8..e7200c3328 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -187,14 +187,16 @@ ] -def xfail_jit(*args): - return pytest.param(*args, marks=[pytest.mark.xfail(reason="not jittable"), +def xfail_jit(*args, **kwargs): + reason = kwargs.pop("reason", "not jittable") + return pytest.param(*args, marks=[pytest.mark.xfail(reason=reason), pytest.mark.skipif('CI' in os.environ, reason='slow test')]) JIT_EXAMPLES = [ 'air/main.py --num-steps=1 --jit', - 'baseball.py --num-samples=200 --warmup-steps=100 --jit', + xfail_jit('baseball.py --num-samples=200 --warmup-steps=100 --jit', + reason='unreproducible RuntimeError on CI'), 'contrib/autoname/mixture.py --num-epochs=1 --jit', 'contrib/cevae/synthetic.py --num-epochs=1 --jit', 'contrib/epidemiology/sir.py --jit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', diff --git a/tutorial/source/air.ipynb b/tutorial/source/air.ipynb index 4d56174cef..0dff592cee 100644 --- a/tutorial/source/air.ipynb +++ b/tutorial/source/air.ipynb @@ -41,7 +41,7 @@ "import numpy as np\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)" ] }, diff --git a/tutorial/source/bayesian_regression.ipynb b/tutorial/source/bayesian_regression.ipynb index 3f629adb5e..b45cab6e71 100644 --- a/tutorial/source/bayesian_regression.ipynb +++ b/tutorial/source/bayesian_regression.ipynb @@ -69,7 +69,7 @@ "\n", "# for CI testing\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(1)\n", "pyro.enable_validation(True)\n", diff --git a/tutorial/source/bayesian_regression_ii.ipynb b/tutorial/source/bayesian_regression_ii.ipynb index 267dc7d811..88b105af9d 100644 --- a/tutorial/source/bayesian_regression_ii.ipynb +++ b/tutorial/source/bayesian_regression_ii.ipynb @@ -44,7 +44,7 @@ "import pyro.optim as optim\n", "\n", "pyro.set_rng_seed(1)\n", - "assert pyro.__version__.startswith('1.5.1')" + "assert pyro.__version__.startswith('1.5.2')" ] }, { diff --git a/tutorial/source/bo.ipynb b/tutorial/source/bo.ipynb index 4d26c795f9..20207b6509 100644 --- a/tutorial/source/bo.ipynb +++ b/tutorial/source/bo.ipynb @@ -54,7 +54,7 @@ "import pyro\n", "import pyro.contrib.gp as gp\n", "\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # can help with debugging\n", "pyro.set_rng_seed(1)" ] diff --git a/tutorial/source/dirichlet_process_mixture.ipynb b/tutorial/source/dirichlet_process_mixture.ipynb index 03fc5ec376..ebb2077bd2 100644 --- a/tutorial/source/dirichlet_process_mixture.ipynb +++ b/tutorial/source/dirichlet_process_mixture.ipynb @@ -76,7 +76,7 @@ "from pyro.infer import Predictive, SVI, Trace_ELBO\n", "from pyro.optim import Adam\n", "\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # can help with debugging\n", "pyro.set_rng_seed(0)" ] diff --git a/tutorial/source/easyguide.ipynb b/tutorial/source/easyguide.ipynb index 800f9a2b3d..d0ce023831 100644 --- a/tutorial/source/easyguide.ipynb +++ b/tutorial/source/easyguide.ipynb @@ -45,7 +45,7 @@ "\n", "pyro.enable_validation(True)\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')" + "assert pyro.__version__.startswith('1.5.2')" ] }, { diff --git a/tutorial/source/ekf.ipynb b/tutorial/source/ekf.ipynb index 7fe9ab97e1..6d6c70cf91 100644 --- a/tutorial/source/ekf.ipynb +++ b/tutorial/source/ekf.ipynb @@ -98,7 +98,7 @@ "from pyro.contrib.tracking.measurements import PositionMeasurement\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)" ] }, diff --git a/tutorial/source/enumeration.ipynb b/tutorial/source/enumeration.ipynb index 9d849d4343..037b8636b6 100644 --- a/tutorial/source/enumeration.ipynb +++ b/tutorial/source/enumeration.ipynb @@ -50,7 +50,7 @@ "from pyro.ops.indexing import Vindex\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation()\n", "pyro.set_rng_seed(0)" ] diff --git a/tutorial/source/epi_intro.ipynb b/tutorial/source/epi_intro.ipynb index c34379ff33..d49efcbb8b 100644 --- a/tutorial/source/epi_intro.ipynb +++ b/tutorial/source/epi_intro.ipynb @@ -58,7 +58,7 @@ "from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "torch.set_default_dtype(torch.double) # Required for MCMC inference.\n", "pyro.enable_validation(True) # Always a good idea.\n", "smoke_test = ('CI' in os.environ)" diff --git a/tutorial/source/forecasting_dlm.ipynb b/tutorial/source/forecasting_dlm.ipynb index bd84ddad5a..14b7d3a163 100644 --- a/tutorial/source/forecasting_dlm.ipynb +++ b/tutorial/source/forecasting_dlm.ipynb @@ -46,7 +46,7 @@ "from pyro.ops.stats import quantile\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(20200928)\n", diff --git a/tutorial/source/forecasting_i.ipynb b/tutorial/source/forecasting_i.ipynb index d01aa418bd..50a5d74ade 100644 --- a/tutorial/source/forecasting_i.ipynb +++ b/tutorial/source/forecasting_i.ipynb @@ -47,7 +47,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(20200221)" ] diff --git a/tutorial/source/forecasting_ii.ipynb b/tutorial/source/forecasting_ii.ipynb index 436e4cca3b..1ed3b69d5e 100644 --- a/tutorial/source/forecasting_ii.ipynb +++ b/tutorial/source/forecasting_ii.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(20200305)" ] diff --git a/tutorial/source/forecasting_iii.ipynb b/tutorial/source/forecasting_iii.ipynb index ec870b92b1..e74b3acd50 100644 --- a/tutorial/source/forecasting_iii.ipynb +++ b/tutorial/source/forecasting_iii.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "pyro.set_rng_seed(20200305)" ] diff --git a/tutorial/source/gmm.ipynb b/tutorial/source/gmm.ipynb index 6ebb07d45f..22cbe88c76 100644 --- a/tutorial/source/gmm.ipynb +++ b/tutorial/source/gmm.ipynb @@ -41,7 +41,7 @@ "from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)" ] }, diff --git a/tutorial/source/gp.ipynb b/tutorial/source/gp.ipynb index 0c0d87106a..503f08058e 100644 --- a/tutorial/source/gp.ipynb +++ b/tutorial/source/gp.ipynb @@ -60,7 +60,7 @@ "import pyro.distributions as dist\n", "\n", "smoke_test = ('CI' in os.environ) # ignore; used to check code integrity in the Pyro repo\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # can help with debugging\n", "pyro.set_rng_seed(0)" ] diff --git a/tutorial/source/gplvm.ipynb b/tutorial/source/gplvm.ipynb index 45decda139..436e253c08 100644 --- a/tutorial/source/gplvm.ipynb +++ b/tutorial/source/gplvm.ipynb @@ -39,7 +39,7 @@ "import pyro.ops.stats as stats\n", "\n", "smoke_test = ('CI' in os.environ) # ignore; used to check code integrity in the Pyro repo\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # can help with debugging\n", "pyro.set_rng_seed(1)" ] diff --git a/tutorial/source/jit.ipynb b/tutorial/source/jit.ipynb index c36570b1a8..54e56b3adc 100644 --- a/tutorial/source/jit.ipynb +++ b/tutorial/source/jit.ipynb @@ -48,7 +48,7 @@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # <---- This is always a good idea!" ] }, diff --git a/tutorial/source/modules.ipynb b/tutorial/source/modules.ipynb index 87582e9006..7bf801900d 100644 --- a/tutorial/source/modules.ipynb +++ b/tutorial/source/modules.ipynb @@ -61,7 +61,7 @@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # <---- This is always a good idea!" ] }, diff --git a/tutorial/source/stable.ipynb b/tutorial/source/stable.ipynb index 039f087d73..9132737d8e 100644 --- a/tutorial/source/stable.ipynb +++ b/tutorial/source/stable.ipynb @@ -62,7 +62,7 @@ "from pyro.ops.tensor_utils import convolve\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "smoke_test = ('CI' in os.environ)" ] diff --git a/tutorial/source/svi_part_i.ipynb b/tutorial/source/svi_part_i.ipynb index 694ceb58dd..f5f9a8802a 100644 --- a/tutorial/source/svi_part_i.ipynb +++ b/tutorial/source/svi_part_i.ipynb @@ -265,7 +265,7 @@ "n_steps = 2 if smoke_test else 2000\n", "\n", "# enable validation (e.g. validate parameters of distributions)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "\n", "# clear the param store in case we're in a REPL\n", diff --git a/tutorial/source/svi_part_iii.ipynb b/tutorial/source/svi_part_iii.ipynb index bc40c8dec6..c39bec0be3 100644 --- a/tutorial/source/svi_part_iii.ipynb +++ b/tutorial/source/svi_part_iii.ipynb @@ -284,7 +284,7 @@ "import sys\n", "\n", "# enable validation (e.g. validate parameters of distributions)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "\n", "# this is for running the notebook in our testing framework\n", diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index 56e290ffcd..6bd283851e 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -52,7 +52,7 @@ "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True) # <---- This is always a good idea!\n", "\n", "# We'll ue this helper to check our models are correct.\n", diff --git a/tutorial/source/tracking_1d.ipynb b/tutorial/source/tracking_1d.ipynb index d809e5302b..e21e3f7ba6 100644 --- a/tutorial/source/tracking_1d.ipynb +++ b/tutorial/source/tracking_1d.ipynb @@ -30,7 +30,7 @@ "from pyro.optim import Adam\n", "\n", "%matplotlib inline\n", - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "smoke_test = ('CI' in os.environ)" ] diff --git a/tutorial/source/vae.ipynb b/tutorial/source/vae.ipynb index cfb6fafeee..4c274933c7 100644 --- a/tutorial/source/vae.ipynb +++ b/tutorial/source/vae.ipynb @@ -115,7 +115,7 @@ "metadata": {}, "outputs": [], "source": [ - "assert pyro.__version__.startswith('1.5.1')\n", + "assert pyro.__version__.startswith('1.5.2')\n", "pyro.enable_validation(True)\n", "pyro.distributions.enable_validation(False)\n", "pyro.set_rng_seed(0)\n", From 09bcbc079a74530bc2b31255b8e48639fb5b0300 Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 27 Apr 2021 17:07:54 +0200 Subject: [PATCH 02/35] Added sine skewed distribution and tests. --- pyro/distributions/__init__.py | 10 +-- pyro/distributions/sine_skewed.py | 41 ++++++++++++ tests/distributions/test_sine_skewed.py | 87 +++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 pyro/distributions/sine_skewed.py create mode 100644 tests/distributions/test_sine_skewed.py diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 2f179928de..347fe09944 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,10 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.distributions.torch_patch # noqa F403 -from pyro.distributions.torch import * # noqa F403 - -# isort: split - from pyro.distributions.affine_beta import AffineBeta from pyro.distributions.avf_mvn import AVFMultivariateNormal from pyro.distributions.coalescent import ( @@ -58,9 +54,11 @@ RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough, ) +from pyro.distributions.sine_skewed import SineSkewed from pyro.distributions.softlaplace import SoftLaplace from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable +from pyro.distributions.torch import * # noqa F403 from pyro.distributions.torch import __all__ as torch_dists from pyro.distributions.torch_distribution import ( ExpandedDistribution, @@ -80,9 +78,10 @@ ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, ) - from . import constraints, kl, transforms +# isort: split + __all__ = [ "AffineBeta", "AVFMultivariateNormal", @@ -128,6 +127,7 @@ "Rejector", "RelaxedBernoulliStraightThrough", "RelaxedOneHotCategoricalStraightThrough", + "SineSkewed", "SoftLaplace", "SpanningTree", "Stable", diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py new file mode 100644 index 0000000000..3e6d452ee3 --- /dev/null +++ b/pyro/distributions/sine_skewed.py @@ -0,0 +1,41 @@ +from math import pi + +import torch +from torch.distributions import Uniform + +from pyro.distributions import constraints + +from .torch_distribution import TorchDistribution + + +class SineSkewed(TorchDistribution): + """ Distribution for breaking pointwise symmetric distribution on the d-dimensional torus. + + ** References: ** + 1. Sine-skewed toroidal distributions and their application in protein bioinformatics + Ameijeiras-Alonso, J., Ley, C. (2019) + """ + arg_constraints = {'skewness': constraints.interval(-1., 1.)} + support = constraints.real + + def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): + assert torch.abs(skewness).sum() <= 1. + assert torch.Size((*base_density.event_shape, *base_density.batch_shape)) == skewness.shape + assert base_density.event_shape[-1] == 2 + self.base_density = base_density + self.skewness = skewness + super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) + + def __repr__(self): + return "" # TODO + + def sample(self, sample_shape=torch.Size()): + bd = self.base_density + ys = bd.sample(sample_shape) + mask = Uniform(0, 1.).sample(sample_shape) < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + + return torch.where(mask.view(*sample_shape, *(1 for _ in bd.event_shape)), ys, -ys + 2 * bd.mean) + + def log_prob(self, value): + bd = self.base_density + return bd.log_prob(value) + torch.log(1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).sum(-1)) diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py new file mode 100644 index 0000000000..c49588890a --- /dev/null +++ b/tests/distributions/test_sine_skewed.py @@ -0,0 +1,87 @@ +from math import pi + +import pytest +import torch + +import pyro +from pyro.distributions import Uniform, Normal +from pyro.distributions.sine_skewed import SineSkewed +from pyro.infer import Trace_ELBO, SVI +from pyro.infer.autoguide import AutoDelta +from pyro.optim import Adam +from tests.common import assert_equal + +BASE_DISTS = [(Uniform, ([-pi, -pi], [pi, pi]))] + + +def _skewness(batch_shape, event_shape): + n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) + skewness = torch.empty((*batch_shape, n)).view(-1, n) + tots = torch.zeros(batch_shape).view(-1) + for i in range(n): + skewness[..., i] = Uniform(0., 1 - tots).sample() + tots += skewness[..., i] + skewness = torch.where(Uniform(0, 1.).sample(skewness.shape) < .5, -skewness, skewness) + if (*batch_shape, *event_shape) == tuple(): + skewness = skewness.reshape((*batch_shape, *event_shape)) + else: + skewness = skewness.view(*batch_shape, *event_shape) + return skewness + + +@pytest.mark.parametrize('dist', BASE_DISTS) +def test_ss_log_prob(dist): + base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) + loc = Normal(0., 1.).sample(base_dist.mean.shape) % (2 * pi) - pi + + base_prob = base_dist.log_prob(loc) + skewness = _skewness(base_dist.batch_shape, base_dist.event_shape) + sine_prob = SineSkewed(base_dist, skewness).log_prob(loc) + assert_equal(base_prob + torch.log(1 + (skewness * torch.sin(loc - base_dist.mean)).sum()), sine_prob) + + +@pytest.mark.parametrize('dist', BASE_DISTS) +def test_ss_sample(dist): + base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) + + skewness_tar = _skewness(base_dist.batch_shape, base_dist.event_shape) + data = SineSkewed(base_dist, skewness_tar).sample((1000,)) + + def model(data, batch_shape, event_shape): + n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) + skewness = torch.empty((*batch_shape, n)).view(-1, n) + tots = torch.zeros(batch_shape).view(-1) + for i in range(n): + skewness[..., i] = pyro.sample(f'skew{i}', Uniform(0., 1 - tots)) + tots += skewness[..., i] + sign = pyro.sample('sign', Uniform(0., torch.ones(skewness.shape)).to_event(len(skewness.shape))) + skewness = torch.where(sign < .5, -skewness, skewness) + + if (*batch_shape, *event_shape) == tuple(): + skewness = skewness.reshape((*batch_shape, *event_shape)) + else: + skewness = skewness.view(*batch_shape, *event_shape) + + with pyro.plate("data", data.size(-len(data.size()))): + pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) + + pyro.clear_param_store() + adam = Adam({"lr": .1}) + guide = AutoDelta(model) + svi = SVI(model, guide, adam, loss=Trace_ELBO()) + + losses = [] + steps = 50 + for step in range(steps): + losses.append(svi.step(data, base_dist.batch_shape, base_dist.event_shape)) + + act_sign = pyro.param('AutoDelta.sign') + act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k]).T + act_skewness = torch.where(act_sign < .5, -act_skewness, act_skewness) + + if (*base_dist.batch_shape, *base_dist.event_shape) == tuple(): + act_skewness = act_skewness.reshape((*base_dist.batch_shape, *base_dist.event_shape)) + else: + act_skewness = act_skewness.view(*base_dist.batch_shape, *base_dist.event_shape) + + assert_equal(act_skewness, skewness_tar, 5e-2) From b7ae4d127423687ca6adc4131a5c862e29ae33fe Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 27 Apr 2021 21:22:33 +0200 Subject: [PATCH 03/35] Added repr. --- pyro/distributions/sine_skewed.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 3e6d452ee3..dbeca88185 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -27,7 +27,12 @@ def __init__(self, base_density: TorchDistribution, skewness, validate_args=None super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) def __repr__(self): - return "" # TODO + param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] + + args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) + return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): bd = self.base_density From 85e352c34bd87db86769db8bb650cb2f3f064c43 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 27 Apr 2021 22:44:55 +0200 Subject: [PATCH 04/35] Fixed shape tests and minor fixes to docstring. --- docs/source/distributions.rst | 7 +++++++ pyro/distributions/__init__.py | 1 + pyro/distributions/sine_skewed.py | 20 +++++++++++++++----- tests/distributions/conftest.py | 8 +++++++- tests/distributions/test_distributions.py | 8 ++++++-- tests/distributions/test_sine_skewed.py | 4 ++-- 6 files changed, 38 insertions(+), 10 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 03f91b9054..aafaafc581 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -316,6 +316,13 @@ Rejector :undoc-members: :show-inheritance: +SineSkewed +---------- +.. autoclass:: pyro.distributions.SineSkewed + :members: + :undoc-members: + :show-inheritance: + SoftLaplace ------------- .. autoclass:: pyro.distributions.SoftLaplace diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 347fe09944..8d030ff73d 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -78,6 +78,7 @@ ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, ) + from . import constraints, kl, transforms # isort: split diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index dbeca88185..872434f806 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -9,18 +9,23 @@ class SineSkewed(TorchDistribution): - """ Distribution for breaking pointwise symmetric distribution on the d-dimensional torus. + """ Distribution for breaking pointwise-symmetry on distributions over the d-dimensional torus. ** References: ** 1. Sine-skewed toroidal distributions and their application in protein bioinformatics Ameijeiras-Alonso, J., Ley, C. (2019) + + :param base_density: base density on the d-dimensional torus; event_shape must be [..., 2] where + ``prod(event_shape[:-1]) == d``. + :param skewness: skewness of the distribution; must have same shape as base_density.event_shape, all values + must be in [-1,1] and ``abs(skewness).sum() <= 1``. """ arg_constraints = {'skewness': constraints.interval(-1., 1.)} - support = constraints.real + support = constraints.independent(constraints.real, 1) def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): - assert torch.abs(skewness).sum() <= 1. - assert torch.Size((*base_density.event_shape, *base_density.batch_shape)) == skewness.shape + assert torch.all(skewness.abs() <= 1) + assert torch.Size(base_density.event_shape) == skewness.shape assert base_density.event_shape[-1] == 2 self.base_density = base_density self.skewness = skewness @@ -38,9 +43,14 @@ def sample(self, sample_shape=torch.Size()): bd = self.base_density ys = bd.sample(sample_shape) mask = Uniform(0, 1.).sample(sample_shape) < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) - return torch.where(mask.view(*sample_shape, *(1 for _ in bd.event_shape)), ys, -ys + 2 * bd.mean) def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) bd = self.base_density return bd.log_prob(value) + torch.log(1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).sum(-1)) + + @classmethod + def infer_shapes(cls, **arg_shapes): + return arg_shapes['base_density'], arg_shapes['skewness'] diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 5bb43011f6..366bc6a933 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from math import pi import numpy as np import pytest @@ -15,7 +16,7 @@ ShapeAugmentedDirichlet, ShapeAugmentedGamma, ) -from tests.distributions.dist_fixture import Fixture +from tests.distributions.dist_fixture import Fixture, tensor_wrap class FoldedNormal(dist.FoldedDistribution): @@ -329,6 +330,11 @@ def __init__(self, rate, *, validate_args=None): {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, ]), + Fixture(pyro_dist=dist.SineSkewed, + examples=[{ + 'base_density': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), + 'skewness': [-pi/4, 0.], 'test_data': [pi/2, -2*pi/3] + }]) ] discrete_dists = [ diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index 86831eeda5..69d4d59921 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -44,8 +44,12 @@ def test_infer_shapes(dist): pytest.xfail(reason="cannot statically compute shape") for idx in range(dist.get_num_test_data()): dist_params = dist.get_dist_params(idx) - arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () - for k, v in dist_params.items()} + if 'SineSkewed' == dist.pyro_dist.__name__: + arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else v.batch_shape + for k, v in dist_params.items()} + else: + arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () + for k, v in dist_params.items()} batch_shape, event_shape = dist.pyro_dist.infer_shapes(**arg_shapes) d = dist.pyro_dist(**dist_params) assert d.batch_shape == batch_shape diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index c49588890a..1b0ba94d23 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -4,9 +4,9 @@ import torch import pyro -from pyro.distributions import Uniform, Normal +from pyro.distributions import Normal, Uniform from pyro.distributions.sine_skewed import SineSkewed -from pyro.infer import Trace_ELBO, SVI +from pyro.infer import SVI, Trace_ELBO from pyro.infer.autoguide import AutoDelta from pyro.optim import Adam from tests.common import assert_equal From 7ee6643e44beead0d3fad914b4cd45a586924f67 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 27 Apr 2021 23:03:45 +0200 Subject: [PATCH 05/35] Fixed lint. --- tests/distributions/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index 69d4d59921..f31c703f42 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -46,10 +46,10 @@ def test_infer_shapes(dist): dist_params = dist.get_dist_params(idx) if 'SineSkewed' == dist.pyro_dist.__name__: arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else v.batch_shape - for k, v in dist_params.items()} + for k, v in dist_params.items()} else: arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () - for k, v in dist_params.items()} + for k, v in dist_params.items()} batch_shape, event_shape = dist.pyro_dist.infer_shapes(**arg_shapes) d = dist.pyro_dist(**dist_params) assert d.batch_shape == batch_shape From 789f550e9d8ce380e94055245c60b62b2ede1054 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 28 Apr 2021 09:10:35 +0200 Subject: [PATCH 06/35] Updated docstring with uniform prior. --- pyro/distributions/sine_skewed.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 872434f806..c901b28c09 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -9,7 +9,31 @@ class SineSkewed(TorchDistribution): - """ Distribution for breaking pointwise-symmetry on distributions over the d-dimensional torus. + """The Sine Skewed distribution is a distribution for breaking pointwise-symmetry on a base-distributions over + the d-dimensional torus. + + This distribution requires a base distribution on the torus. The parameter skewness can be inferred using + :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. The following will produce a uniform prior + over skewness,:: + + def model(data, batch_shape, event_shape): + n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) + skewness = torch.empty((*batch_shape, n)).view(-1, n) + tots = torch.zeros(batch_shape).view(-1) + for i in range(n): + skewness[..., i] = pyro.sample(f'skew{i}', Uniform(0., 1 - tots)) + tots += skewness[..., i] + sign = pyro.sample('sign', Uniform(0., torch.ones(skewness.shape)).to_event(len(skewness.shape))) + skewness = torch.where(sign < .5, -skewness, skewness) + + if (*batch_shape, *event_shape) == tuple(): + skewness = skewness.reshape((*batch_shape, *event_shape)) + else: + skewness = skewness.view(*batch_shape, *event_shape) + + .. note:: The base-distribution must be over a arbitrary dim torus. + + .. note:: ``skewness.abs().sum() <= 1.`` and ``(skewness.abs() <= 1).all()``. ** References: ** 1. Sine-skewed toroidal distributions and their application in protein bioinformatics From be91d9a76e51d210ac03c511ffff5884adddc5bf Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 28 Apr 2021 13:51:11 +0200 Subject: [PATCH 07/35] Fixed skewness shape assertion. --- pyro/distributions/sine_skewed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index c901b28c09..7f1418ac5a 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -49,8 +49,8 @@ def model(data, batch_shape, event_shape): def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): assert torch.all(skewness.abs() <= 1) - assert torch.Size(base_density.event_shape) == skewness.shape assert base_density.event_shape[-1] == 2 + assert skewness.shape[-1] == 2 self.base_density = base_density self.skewness = skewness super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) From 2a200d391ef29a4170f811bf8f58dec6a7d5e5f8 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 30 Apr 2021 06:56:30 +0200 Subject: [PATCH 08/35] ensure `SineSkewed` is on the torus. --- pyro/distributions/sine_skewed.py | 35 +++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 872434f806..dd00acaa5b 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -9,7 +9,31 @@ class SineSkewed(TorchDistribution): - """ Distribution for breaking pointwise-symmetry on distributions over the d-dimensional torus. + """The Sine Skewed distribution is a distribution for breaking pointwise-symmetry on a base-distributions over + the d-dimensional torus. + + This distribution requires a base distribution on the torus. The parameter skewness can be inferred using + :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. The following will produce a uniform prior + over skewness,:: + + def model(data, batch_shape, event_shape): + n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) + skewness = torch.empty((*batch_shape, n)).view(-1, n) + tots = torch.zeros(batch_shape).view(-1) + for i in range(n): + skewness[..., i] = pyro.sample(f'skew{i}', Uniform(0., 1 - tots)) + tots += skewness[..., i] + sign = pyro.sample('sign', Uniform(0., torch.ones(skewness.shape)).to_event(len(skewness.shape))) + skewness = torch.where(sign < .5, -skewness, skewness) + + if (*batch_shape, *event_shape) == tuple(): + skewness = skewness.reshape((*batch_shape, *event_shape)) + else: + skewness = skewness.view(*batch_shape, *event_shape) + + .. note:: The base-distribution must be over a arbitrary dim torus. + + .. note:: ``skewness.abs().sum() <= 1.`` and ``(skewness.abs() <= 1).all()``. ** References: ** 1. Sine-skewed toroidal distributions and their application in protein bioinformatics @@ -25,8 +49,8 @@ class SineSkewed(TorchDistribution): def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): assert torch.all(skewness.abs() <= 1) - assert torch.Size(base_density.event_shape) == skewness.shape assert base_density.event_shape[-1] == 2 + assert skewness.shape[-1] == 2 self.base_density = base_density self.skewness = skewness super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) @@ -42,8 +66,11 @@ def __repr__(self): def sample(self, sample_shape=torch.Size()): bd = self.base_density ys = bd.sample(sample_shape) - mask = Uniform(0, 1.).sample(sample_shape) < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) - return torch.where(mask.view(*sample_shape, *(1 for _ in bd.event_shape)), ys, -ys + 2 * bd.mean) + u = Uniform(0, 1.).sample(sample_shape + self.batch_shape) + mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + mask = mask.view(*sample_shape, *self.batch_shape, *(1 for _ in bd.event_shape)) + samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi + return samples def log_prob(self, value): if self._validate_args: From e7a1a743e37c89a22830e961ef1e980354ab02e8 Mon Sep 17 00:00:00 2001 From: Ola Date: Sat, 1 May 2021 13:49:35 +0200 Subject: [PATCH 09/35] Reverted `infer_shapes` in `sine_skewed` and `# isort: split` in `distributions/__init__.py` --- pyro/distributions/__init__.py | 6 ++++-- pyro/distributions/sine_skewed.py | 4 ---- tests/distributions/test_distributions.py | 10 +++------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 8d030ff73d..1b4cf607e8 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,6 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.distributions.torch_patch # noqa F403 +from pyro.distributions.torch import * + +# isort: split + from pyro.distributions.affine_beta import AffineBeta from pyro.distributions.avf_mvn import AVFMultivariateNormal from pyro.distributions.coalescent import ( @@ -81,8 +85,6 @@ from . import constraints, kl, transforms -# isort: split - __all__ = [ "AffineBeta", "AVFMultivariateNormal", diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index dd00acaa5b..44fb0fc8b6 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -77,7 +77,3 @@ def log_prob(self, value): self._validate_sample(value) bd = self.base_density return bd.log_prob(value) + torch.log(1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).sum(-1)) - - @classmethod - def infer_shapes(cls, **arg_shapes): - return arg_shapes['base_density'], arg_shapes['skewness'] diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index f31c703f42..82a0a6d6db 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -40,16 +40,12 @@ def test_support_shape(dist): def test_infer_shapes(dist): - if "LKJ" in dist.pyro_dist.__name__: + if "LKJ" in dist.pyro_dist.__name__ or "SineSkewed" == dist.pyro_dist.__name__: pytest.xfail(reason="cannot statically compute shape") for idx in range(dist.get_num_test_data()): dist_params = dist.get_dist_params(idx) - if 'SineSkewed' == dist.pyro_dist.__name__: - arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else v.batch_shape - for k, v in dist_params.items()} - else: - arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () - for k, v in dist_params.items()} + arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () + for k, v in dist_params.items()} batch_shape, event_shape = dist.pyro_dist.infer_shapes(**arg_shapes) d = dist.pyro_dist(**dist_params) assert d.batch_shape == batch_shape From e802aa44701a0d11d4f3954964a64068fcff7fb3 Mon Sep 17 00:00:00 2001 From: Ola Date: Sat, 1 May 2021 14:00:46 +0200 Subject: [PATCH 10/35] Sketched `SineSkewed.expand` --- pyro/distributions/sine_skewed.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 44fb0fc8b6..fc0c312228 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -59,8 +59,8 @@ def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): @@ -77,3 +77,14 @@ def log_prob(self, value): self._validate_sample(value) bd = self.base_density return bd.log_prob(value) + torch.log(1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).sum(-1)) + + def expand(self, batch_shape, _instance=None): + batch_shape = torch.Size(batch_shape) + new = self._get_checked_instance(SineSkewed, _instance) + for name in self.arg_constraints: + setattr(new, name, getattr(self, name).expand(batch_shape)) + base_dist = self.base_density.expand(batch_shape, None) + new.base_density = base_dist + super(SineSkewed, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new From d1801b9bffeeb7cff015b84a1583340102ca4fc5 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 2 May 2021 19:36:37 +0200 Subject: [PATCH 11/35] Fixed `SineSkewed.log_prob`. --- pyro/distributions/sine_skewed.py | 20 ++++--- tests/distributions/test_sine_skewed.py | 77 ++++++++++++------------- 2 files changed, 50 insertions(+), 47 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index fc0c312228..fba003da50 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -59,15 +59,17 @@ def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): bd = self.base_density ys = bd.sample(sample_shape) u = Uniform(0, 1.).sample(sample_shape + self.batch_shape) - mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + + mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).view(*(sample_shape + self.batch_shape), + -1).sum(-1) mask = mask.view(*sample_shape, *self.batch_shape, *(1 for _ in bd.event_shape)) samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples @@ -75,16 +77,20 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) + flat_event = torch.tensor(self.event_shape).prod() bd = self.base_density - return bd.log_prob(value) + torch.log(1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).sum(-1)) + bd_prob = bd.log_prob(value) + sine_prob = torch.log( + 1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).reshape((-1, flat_event)).sum(-1)) + return (bd_prob.view((-1)) + sine_prob).view(bd_prob.shape) def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new = self._get_checked_instance(SineSkewed, _instance) - for name in self.arg_constraints: - setattr(new, name, getattr(self, name).expand(batch_shape)) base_dist = self.base_density.expand(batch_shape, None) new.base_density = base_dist - super(SineSkewed, new).__init__(batch_shape, validate_args=False) + for name in self.arg_constraints: + setattr(new, name, getattr(self, name).expand((*batch_shape, *self.event_shape))) + super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=None) new._validate_args = self._validate_args return new diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index 1b0ba94d23..0ea1dcb07f 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -4,10 +4,8 @@ import torch import pyro -from pyro.distributions import Normal, Uniform -from pyro.distributions.sine_skewed import SineSkewed +from pyro.distributions import Normal, SineSkewed, Uniform, constraints from pyro.infer import SVI, Trace_ELBO -from pyro.infer.autoguide import AutoDelta from pyro.optim import Adam from tests.common import assert_equal @@ -16,12 +14,14 @@ def _skewness(batch_shape, event_shape): n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) - skewness = torch.empty((*batch_shape, n)).view(-1, n) - tots = torch.zeros(batch_shape).view(-1) - for i in range(n): - skewness[..., i] = Uniform(0., 1 - tots).sample() - tots += skewness[..., i] - skewness = torch.where(Uniform(0, 1.).sample(skewness.shape) < .5, -skewness, skewness) + skewness = torch.zeros((*batch_shape, n)) + while True: + for i in range(n): + max_ = 1. - skewness.abs().sum(-1) + if torch.any(max_ < 1e-15): + continue + skewness[..., i] = Uniform(-max_, max_).sample() + break if (*batch_shape, *event_shape) == tuple(): skewness = skewness.reshape((*batch_shape, *event_shape)) else: @@ -29,15 +29,28 @@ def _skewness(batch_shape, event_shape): return skewness +@pytest.mark.parametrize('batch_shape', + [(), (1,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) +@pytest.mark.parametrize('event_dim', [0, 1, 2, 3]) @pytest.mark.parametrize('dist', BASE_DISTS) -def test_ss_log_prob(dist): - base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) - loc = Normal(0., 1.).sample(base_dist.mean.shape) % (2 * pi) - pi +def test_ss_multidim_log_prob(event_dim, batch_shape, dist): + if len(batch_shape) >= event_dim and event_dim: + base_dist = dist[0](*(torch.tensor(param).expand(*batch_shape, 2) for param in dist[1])).to_event(event_dim + 1) + assert base_dist.batch_shape == batch_shape[:-event_dim] + else: + base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) + base_dist = base_dist.expand(batch_shape) + assert base_dist.batch_shape == batch_shape + + loc = Normal(0., 1.).sample(base_dist.event_shape) % (2 * pi) - pi base_prob = base_dist.log_prob(loc) skewness = _skewness(base_dist.batch_shape, base_dist.event_shape) - sine_prob = SineSkewed(base_dist, skewness).log_prob(loc) - assert_equal(base_prob + torch.log(1 + (skewness * torch.sin(loc - base_dist.mean)).sum()), sine_prob) + ss = SineSkewed(base_dist, skewness) + assert_equal(base_prob + torch.log( + 1 + (skewness * torch.sin(loc - base_dist.mean)).view(*base_dist.batch_shape, -1).sum(-1)), + ss.log_prob(loc)) + assert_equal(ss.sample().shape, torch.Size((*batch_shape, 2))) @pytest.mark.parametrize('dist', BASE_DISTS) @@ -47,41 +60,25 @@ def test_ss_sample(dist): skewness_tar = _skewness(base_dist.batch_shape, base_dist.event_shape) data = SineSkewed(base_dist, skewness_tar).sample((1000,)) - def model(data, batch_shape, event_shape): - n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) - skewness = torch.empty((*batch_shape, n)).view(-1, n) - tots = torch.zeros(batch_shape).view(-1) - for i in range(n): - skewness[..., i] = pyro.sample(f'skew{i}', Uniform(0., 1 - tots)) - tots += skewness[..., i] - sign = pyro.sample('sign', Uniform(0., torch.ones(skewness.shape)).to_event(len(skewness.shape))) - skewness = torch.where(sign < .5, -skewness, skewness) - - if (*batch_shape, *event_shape) == tuple(): - skewness = skewness.reshape((*batch_shape, *event_shape)) - else: - skewness = skewness.view(*batch_shape, *event_shape) + def model(data, batch_shape): + skew0 = pyro.param('skew0', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1)) + skew1 = pyro.param('skew1', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1)) + skewness = torch.stack((skew0, skew1), dim=-1) with pyro.plate("data", data.size(-len(data.size()))): pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) + def guide(data, batch_shape): + pass + pyro.clear_param_store() adam = Adam({"lr": .1}) - guide = AutoDelta(model) svi = SVI(model, guide, adam, loss=Trace_ELBO()) losses = [] - steps = 50 + steps = 80 for step in range(steps): - losses.append(svi.step(data, base_dist.batch_shape, base_dist.event_shape)) - - act_sign = pyro.param('AutoDelta.sign') - act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k]).T - act_skewness = torch.where(act_sign < .5, -act_skewness, act_skewness) - - if (*base_dist.batch_shape, *base_dist.event_shape) == tuple(): - act_skewness = act_skewness.reshape((*base_dist.batch_shape, *base_dist.event_shape)) - else: - act_skewness = act_skewness.view(*base_dist.batch_shape, *base_dist.event_shape) + losses.append(svi.step(data, base_dist.batch_shape)) + act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k], dim=-1) assert_equal(act_skewness, skewness_tar, 5e-2) From 3b44ebe58449d245e22954c0a18ab7987f20ce01 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 2 May 2021 19:45:55 +0200 Subject: [PATCH 12/35] Added pep exception to `distributions.__init__` --- pyro/distributions/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 1b4cf607e8..c4e2ccf1b1 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.distributions.torch_patch # noqa F403 -from pyro.distributions.torch import * +from pyro.distributions.torch import * # noqa F403 # isort: split From 84ac72e87a167e89bce499b269a56ebe23f50ab2 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 3 May 2021 11:29:30 +0200 Subject: [PATCH 13/35] Fixed `SineSkewed` on cuda. --- pyro/distributions/sine_skewed.py | 9 +++++++-- tests/distributions/test_cuda.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index fba003da50..40b4a857b4 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -51,6 +51,11 @@ def __init__(self, base_density: TorchDistribution, skewness, validate_args=None assert torch.all(skewness.abs() <= 1) assert base_density.event_shape[-1] == 2 assert skewness.shape[-1] == 2 + + if base_density.mean.device != skewness.device: + raise ValueError(f"base_density: {base_density.__class__.__name__} and {self.__class__.__name__} " + f"must be on same device.") + self.base_density = base_density self.skewness = skewness super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) @@ -66,7 +71,7 @@ def __repr__(self): def sample(self, sample_shape=torch.Size()): bd = self.base_density ys = bd.sample(sample_shape) - u = Uniform(0, 1.).sample(sample_shape + self.batch_shape) + u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).view(*(sample_shape + self.batch_shape), -1).sum(-1) @@ -77,7 +82,7 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - flat_event = torch.tensor(self.event_shape).prod() + flat_event = torch.tensor(self.event_shape, device=value.device).prod() bd = self.base_density bd_prob = bd.log_prob(value) sine_prob = torch.log( diff --git a/tests/distributions/test_cuda.py b/tests/distributions/test_cuda.py index 9cb887b467..fdccc66ef9 100644 --- a/tests/distributions/test_cuda.py +++ b/tests/distributions/test_cuda.py @@ -15,6 +15,8 @@ @requires_cuda def test_sample(dist): + if dist.pyro_dist.__name__ == 'SineSkewed': + pytest.xfail(reason="Fixture with distribution param not handled.") for idx in range(len(dist.dist_params)): # Compute CPU value. @@ -77,6 +79,8 @@ def test_rsample(dist): @requires_cuda def test_log_prob(dist): + if dist.pyro_dist.__name__ == 'SineSkewed': + pytest.xfail(reason="Fixture with distribution param not handled.") for idx in range(len(dist.dist_params)): # Compute CPU value. From c92ef62dd0bb18495e5f20e31042e4129bb53619 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 5 May 2021 13:47:25 +0200 Subject: [PATCH 14/35] Restricted `event_dim=2` --- pyro/distributions/sine_skewed.py | 31 ++++++++++++------------- tests/distributions/test_sine_skewed.py | 29 ++++++++++++----------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 40b4a857b4..963a924621 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -31,41 +31,40 @@ def model(data, batch_shape, event_shape): else: skewness = skewness.view(*batch_shape, *event_shape) - .. note:: The base-distribution must be over a arbitrary dim torus. + .. note:: An event in the base-distribution must be on a d-torus so the event_shape must be (d,2) or (2,). - .. note:: ``skewness.abs().sum() <= 1.`` and ``(skewness.abs() <= 1).all()``. + .. note:: For the skewness parameter it must hold that the sum of the absolute value of its weights for an event + must be less than or equal to one. See eq. 2.1 in [1]. ** References: ** 1. Sine-skewed toroidal distributions and their application in protein bioinformatics Ameijeiras-Alonso, J., Ley, C. (2019) - :param base_density: base density on the d-dimensional torus; event_shape must be [..., 2] where - ``prod(event_shape[:-1]) == d``. - :param skewness: skewness of the distribution; must have same shape as base_density.event_shape, all values - must be in [-1,1] and ``abs(skewness).sum() <= 1``. + :param base_density: base density on the d-dimensional torus. + :param skewness: skewness of the distribution. """ arg_constraints = {'skewness': constraints.interval(-1., 1.)} support = constraints.independent(constraints.real, 1) def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): - assert torch.all(skewness.abs() <= 1) - assert base_density.event_shape[-1] == 2 - assert skewness.shape[-1] == 2 - - if base_density.mean.device != skewness.device: - raise ValueError(f"base_density: {base_density.__class__.__name__} and {self.__class__.__name__} " - f"must be on same device.") + assert base_density.event_shape[-1] == 2 and len(base_density.event_shape) <= 2 + assert base_density.shape()[len(base_density.shape()) - len(skewness.shape):] == skewness.shape + assert (skewness.abs().sum(-1 if len(skewness.shape) == 1 else (-2, -1)) <= 1).all() self.base_density = base_density - self.skewness = skewness + self.skewness = skewness.broadcast_to(base_density.shape()) super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) + if self._validate_args and base_density.mean.device != skewness.device: + raise ValueError(f"base_density: {base_density.__class__.__name__} and {self.__class__.__name__} " + f"must be on same device.") + def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index 0ea1dcb07f..e810d47fc3 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -12,26 +12,27 @@ BASE_DISTS = [(Uniform, ([-pi, -pi], [pi, pi]))] -def _skewness(batch_shape, event_shape): - n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) - skewness = torch.zeros((*batch_shape, n)) - while True: - for i in range(n): +def _skewness(event_shape): + skewness = torch.zeros(event_shape.numel()) + done = False + while not done: + for i in range(event_shape.numel()): max_ = 1. - skewness.abs().sum(-1) if torch.any(max_ < 1e-15): - continue - skewness[..., i] = Uniform(-max_, max_).sample() - break - if (*batch_shape, *event_shape) == tuple(): - skewness = skewness.reshape((*batch_shape, *event_shape)) + break + skewness[i] = Uniform(-max_, max_).sample() + done = not torch.any(max_ < 1e-15) + + if event_shape == tuple(): + skewness = skewness.reshape(event_shape) else: - skewness = skewness.view(*batch_shape, *event_shape) + skewness = skewness.view(event_shape) return skewness @pytest.mark.parametrize('batch_shape', [(), (1,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) -@pytest.mark.parametrize('event_dim', [0, 1, 2, 3]) +@pytest.mark.parametrize('event_dim', [0, 1]) @pytest.mark.parametrize('dist', BASE_DISTS) def test_ss_multidim_log_prob(event_dim, batch_shape, dist): if len(batch_shape) >= event_dim and event_dim: @@ -45,7 +46,7 @@ def test_ss_multidim_log_prob(event_dim, batch_shape, dist): loc = Normal(0., 1.).sample(base_dist.event_shape) % (2 * pi) - pi base_prob = base_dist.log_prob(loc) - skewness = _skewness(base_dist.batch_shape, base_dist.event_shape) + skewness = _skewness(base_dist.event_shape) ss = SineSkewed(base_dist, skewness) assert_equal(base_prob + torch.log( 1 + (skewness * torch.sin(loc - base_dist.mean)).view(*base_dist.batch_shape, -1).sum(-1)), @@ -57,7 +58,7 @@ def test_ss_multidim_log_prob(event_dim, batch_shape, dist): def test_ss_sample(dist): base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) - skewness_tar = _skewness(base_dist.batch_shape, base_dist.event_shape) + skewness_tar = _skewness(base_dist.event_shape) data = SineSkewed(base_dist, skewness_tar).sample((1000,)) def model(data, batch_shape): From 906211ad23c004d3de81f0fa2ea9ea0fb16fe6d7 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 5 May 2021 14:14:45 +0200 Subject: [PATCH 15/35] Fixed doc_string and updated tests. --- pyro/distributions/sine_skewed.py | 35 ++++++++++--------------- tests/distributions/test_sine_skewed.py | 34 ++++++++++++++++++------ 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 963a924621..01e5ab2daa 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -9,27 +9,20 @@ class SineSkewed(TorchDistribution): - """The Sine Skewed distribution is a distribution for breaking pointwise-symmetry on a base-distributions over + """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over the d-dimensional torus. - This distribution requires a base distribution on the torus. The parameter skewness can be inferred using + This distribution requires the base distribution on a torus. The parameter skewness can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. The following will produce a uniform prior - over skewness,:: - - def model(data, batch_shape, event_shape): - n = torch.prod(torch.tensor(event_shape), -1, dtype=torch.int) - skewness = torch.empty((*batch_shape, n)).view(-1, n) - tots = torch.zeros(batch_shape).view(-1) - for i in range(n): - skewness[..., i] = pyro.sample(f'skew{i}', Uniform(0., 1 - tots)) - tots += skewness[..., i] - sign = pyro.sample('sign', Uniform(0., torch.ones(skewness.shape)).to_event(len(skewness.shape))) - skewness = torch.where(sign < .5, -skewness, skewness) - - if (*batch_shape, *event_shape) == tuple(): - skewness = skewness.reshape((*batch_shape, *event_shape)) - else: - skewness = skewness.view(*batch_shape, *event_shape) + over skewness for the 1-torus,:: + + def model(...): + ... + skewness_phi = pyro.sample(f'skewness_phi}', Uniform(skewness.abs().sum(), 1 - tots)) + psi_bound = 1 - skewness_phi.abs() + skewness_psi = pyro.sample(f'skewness_psi}', Uniform(-psi_bound, psi_bound) + skewness = torch.stack((skewness_phi, skewness_psi), dim=0) + ... .. note:: An event in the base-distribution must be on a d-torus so the event_shape must be (d,2) or (2,). @@ -40,7 +33,7 @@ def model(data, batch_shape, event_shape): 1. Sine-skewed toroidal distributions and their application in protein bioinformatics Ameijeiras-Alonso, J., Ley, C. (2019) - :param base_density: base density on the d-dimensional torus. + :param base_density: base density on a d-dimensional torus. :param skewness: skewness of the distribution. """ arg_constraints = {'skewness': constraints.interval(-1., 1.)} @@ -81,11 +74,11 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - flat_event = torch.tensor(self.event_shape, device=value.device).prod() bd = self.base_density bd_prob = bd.log_prob(value) sine_prob = torch.log( - 1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).reshape((-1, flat_event)).sum(-1)) + 1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).reshape((-1, self.event_shape.numel())).sum( + -1)) return (bd_prob.view((-1)) + sine_prob).view(bd_prob.shape) def expand(self, batch_shape, _instance=None): diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index e810d47fc3..979716e61a 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -2,6 +2,7 @@ import pytest import torch +from numpy.testing import assert_raises import pyro from pyro.distributions import Normal, SineSkewed, Uniform, constraints @@ -30,28 +31,45 @@ def _skewness(event_shape): return skewness -@pytest.mark.parametrize('batch_shape', +@pytest.mark.parametrize('expand_shape', [(), (1,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) @pytest.mark.parametrize('event_dim', [0, 1]) @pytest.mark.parametrize('dist', BASE_DISTS) -def test_ss_multidim_log_prob(event_dim, batch_shape, dist): - if len(batch_shape) >= event_dim and event_dim: - base_dist = dist[0](*(torch.tensor(param).expand(*batch_shape, 2) for param in dist[1])).to_event(event_dim + 1) - assert base_dist.batch_shape == batch_shape[:-event_dim] +def test_ss_multidim_log_prob(event_dim, expand_shape, dist): + if len(expand_shape) >= event_dim and event_dim: + base_dist = dist[0](*(torch.tensor(param).expand(*expand_shape, 2) for param in dist[1])).to_event(event_dim + 1) + assert base_dist.batch_shape == expand_shape[:-event_dim] else: base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) - base_dist = base_dist.expand(batch_shape) - assert base_dist.batch_shape == batch_shape + base_dist = base_dist.expand(expand_shape) + assert base_dist.batch_shape == expand_shape loc = Normal(0., 1.).sample(base_dist.event_shape) % (2 * pi) - pi base_prob = base_dist.log_prob(loc) skewness = _skewness(base_dist.event_shape) + ss = SineSkewed(base_dist, skewness) assert_equal(base_prob + torch.log( 1 + (skewness * torch.sin(loc - base_dist.mean)).view(*base_dist.batch_shape, -1).sum(-1)), ss.log_prob(loc)) - assert_equal(ss.sample().shape, torch.Size((*batch_shape, 2))) + assert_equal(ss.sample().shape, torch.Size((*expand_shape, 2))) + + +def test_ss_invalid_event_shape(): + base_dist = Uniform(-1, 1).expand((3, 3, 2)).to_event(3) + assert_raises(AssertionError, SineSkewed, base_dist, torch.zeros(base_dist.shape())) + base_dist = Uniform(-1, 1).expand((5,)).to_event(1) + assert_raises(AssertionError, SineSkewed, base_dist, torch.zeros(base_dist.shape())) + + +def test_ss_skewness_too_high(): + base_dist = Uniform(-1, 1).expand((2,)).to_event(1) + assert_raises(AssertionError, SineSkewed, base_dist, torch.ones(base_dist.shape())) + base_dist = Uniform(-1, 1).expand((1, 2,)).to_event(1) + assert_raises(AssertionError, SineSkewed, base_dist, .51 * torch.ones(base_dist.shape())) + base_dist = Uniform(-1, 1).expand((2, 2,)).to_event(1) + assert_raises(AssertionError, SineSkewed, base_dist, .5 * torch.ones(base_dist.shape())) @pytest.mark.parametrize('dist', BASE_DISTS) From e23753127119996c01227605c582ef92e938b2af Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 5 May 2021 14:17:07 +0200 Subject: [PATCH 16/35] fixed linting --- pyro/distributions/sine_skewed.py | 4 ++-- tests/distributions/test_sine_skewed.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 01e5ab2daa..16ce9a1951 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -56,8 +56,8 @@ def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index 979716e61a..ac3e4ab7b5 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -37,7 +37,8 @@ def _skewness(event_shape): @pytest.mark.parametrize('dist', BASE_DISTS) def test_ss_multidim_log_prob(event_dim, expand_shape, dist): if len(expand_shape) >= event_dim and event_dim: - base_dist = dist[0](*(torch.tensor(param).expand(*expand_shape, 2) for param in dist[1])).to_event(event_dim + 1) + base_dist = dist[0](*(torch.tensor(param).expand(*expand_shape, 2) for param in dist[1])) + base_dist = base_dist.to_event(event_dim + 1) assert base_dist.batch_shape == expand_shape[:-event_dim] else: base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) From 5e4020a71e20a196607900fa8e0aaf99f22344d3 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 5 May 2021 14:35:55 +0200 Subject: [PATCH 17/35] fixed arg_constraints --- pyro/distributions/sine_skewed.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 16ce9a1951..851016d287 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -13,20 +13,20 @@ class SineSkewed(TorchDistribution): the d-dimensional torus. This distribution requires the base distribution on a torus. The parameter skewness can be inferred using - :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. The following will produce a uniform prior + :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. For example, the following will produce a uniform prior over skewness for the 1-torus,:: def model(...): ... - skewness_phi = pyro.sample(f'skewness_phi}', Uniform(skewness.abs().sum(), 1 - tots)) + skewness_phi = pyro.sample(f'skewness_phi', Uniform(skewness.abs().sum(), 1 - tots)) psi_bound = 1 - skewness_phi.abs() - skewness_psi = pyro.sample(f'skewness_psi}', Uniform(-psi_bound, psi_bound) + skewness_psi = pyro.sample(f'skewness_psi', Uniform(-psi_bound, psi_bound) skewness = torch.stack((skewness_phi, skewness_psi), dim=0) ... - .. note:: An event in the base-distribution must be on a d-torus so the event_shape must be (d,2) or (2,). + .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d, 2) or (2,). - .. note:: For the skewness parameter it must hold that the sum of the absolute value of its weights for an event + .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event must be less than or equal to one. See eq. 2.1 in [1]. ** References: ** @@ -36,7 +36,8 @@ def model(...): :param base_density: base density on a d-dimensional torus. :param skewness: skewness of the distribution. """ - arg_constraints = {'skewness': constraints.interval(-1., 1.)} + arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)} + support = constraints.independent(constraints.real, 1) def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): @@ -56,8 +57,8 @@ def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) + if self.__dict__[p].numel() == 1 + else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): From bd93a2b6279ee77611f56b3afc81c4d990682b0c Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 5 May 2021 14:41:52 +0200 Subject: [PATCH 18/35] cleaned __repr__ --- pyro/distributions/sine_skewed.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 851016d287..66f82115ee 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -50,16 +50,14 @@ def __init__(self, base_density: TorchDistribution, skewness, validate_args=None super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) if self._validate_args and base_density.mean.device != skewness.device: - raise ValueError(f"base_density: {base_density.__class__.__name__} and {self.__class__.__name__} " + raise ValueError(f"base_density: {base_density.__class__.__name__} and SineSkewed " f"must be on same device.") def __repr__(self): - param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] - - args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].numel() == 1 - else self.__dict__[p].size()) for p in param_names]) - return self.__class__.__name__ + '(' + f'base_density: {self.base_density.__repr__()}, ' + args_string + ')' + args_string = ', '.join(['{}: {}'.format(p, getattr(self, p) + if getattr(self, p).numel() == 1 + else getattr(self, p).size()) for p in self.arg_constraints.keys()]) + return self.__class__.__name__ + '(' + f'base_density: {str(self.base_density)}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): bd = self.base_density From 4646ce63078b388efa54096511175d52dc9f6999 Mon Sep 17 00:00:00 2001 From: Ola Date: Wed, 5 May 2021 23:03:33 +0200 Subject: [PATCH 19/35] Fixed comments. --- pyro/distributions/sine_skewed.py | 38 +++++++++++++++++-------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 66f82115ee..674bedb84c 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -3,7 +3,7 @@ import torch from torch.distributions import Uniform -from pyro.distributions import constraints +from pyro.distributions import constraints, sum_rightmost from .torch_distribution import TorchDistribution @@ -33,40 +33,45 @@ def model(...): 1. Sine-skewed toroidal distributions and their application in protein bioinformatics Ameijeiras-Alonso, J., Ley, C. (2019) - :param base_density: base density on a d-dimensional torus. + :param base_dist: base density on a d-dimensional torus. :param skewness: skewness of the distribution. """ arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)} support = constraints.independent(constraints.real, 1) - def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): - assert base_density.event_shape[-1] == 2 and len(base_density.event_shape) <= 2 - assert base_density.shape()[len(base_density.shape()) - len(skewness.shape):] == skewness.shape + def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): + assert base_dist.event_shape[-1] == 2 and len(base_dist.event_shape) <= 2 assert (skewness.abs().sum(-1 if len(skewness.shape) == 1 else (-2, -1)) <= 1).all() - self.base_density = base_density - self.skewness = skewness.broadcast_to(base_density.shape()) - super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) + self.base_density = base_dist + self.skewness = skewness.broadcast_to(base_dist.shape()) + super().__init__(base_dist.batch_shape, base_dist.event_shape, validate_args=validate_args) - if self._validate_args and base_density.mean.device != skewness.device: - raise ValueError(f"base_density: {base_density.__class__.__name__} and SineSkewed " + if self._validate_args and base_dist.mean.device != skewness.device: + raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " f"must be on same device.") def __repr__(self): args_string = ', '.join(['{}: {}'.format(p, getattr(self, p) - if getattr(self, p).numel() == 1 - else getattr(self, p).size()) for p in self.arg_constraints.keys()]) + if getattr(self, p).numel() == 1 + else getattr(self, p).size()) for p in self.arg_constraints.keys()]) return self.__class__.__name__ + '(' + f'base_density: {str(self.base_density)}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): + """ Sampling method described in section 2.3 of [1]. + + ** References: ** + 1. Sine-skewed toroidal distributions and their application in protein bioinformatics + Ameijeiras-Alonso, J., Ley, C. (2019) + """ bd = self.base_density ys = bd.sample(sample_shape) u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) - mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).view(*(sample_shape + self.batch_shape), - -1).sum(-1) - mask = mask.view(*sample_shape, *self.batch_shape, *(1 for _ in bd.event_shape)) + # Equation in step 3 + mask = u < 1. + sum_rightmost(self.skewness * torch.sin((ys - bd.mean) % (2 * pi)), self.event_shape) + mask = mask.view(*sample_shape, *self.batch_shape, *(1 for _ in self.event_shape)) samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples @@ -76,8 +81,7 @@ def log_prob(self, value): bd = self.base_density bd_prob = bd.log_prob(value) sine_prob = torch.log( - 1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).reshape((-1, self.event_shape.numel())).sum( - -1)) + 1 + sum_rightmost(self.skewness * torch.sin((value - bd.mean) % (2 * pi)), self.event_shape)) return (bd_prob.view((-1)) + sine_prob).view(bd_prob.shape) def expand(self, batch_shape, _instance=None): From ffe50e6ca07873adca4efd45f5f99a3c9969a1ba Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 7 May 2021 14:07:08 +0200 Subject: [PATCH 20/35] Fixed `n_dim=1` and updated `test_sine_skewed`; missing updated fixtures. --- pyro/distributions/sine_skewed.py | 50 ++++++++++---------- tests/distributions/test_sine_skewed.py | 61 ++++++++----------------- 2 files changed, 42 insertions(+), 69 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 66f82115ee..55971d8c35 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -10,11 +10,11 @@ class SineSkewed(TorchDistribution): """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over - the d-dimensional torus. + the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the + 1-torus is a circle and the 2-tours is commonly associated with the donut shape (some may object to this simile). - This distribution requires the base distribution on a torus. The parameter skewness can be inferred using - :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. For example, the following will produce a uniform prior - over skewness for the 1-torus,:: + The skewness parameter can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. + For example, the following will produce a uniform prior over skewness for the 2-torus,:: def model(...): ... @@ -24,7 +24,13 @@ def model(...): skewness = torch.stack((skewness_phi, skewness_psi), dim=0) ... - .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d, 2) or (2,). + In the context of :class:`~pyro.infer.SVI`, this distribution can be freely used as a likelihood, but use as a + latent variables will lead to slow inference for 2 and higher order toruses. This is because the base_dist + cannot be reparameterized). For the 1-torus (circle) with a + :class:`~pyro.distribution.ProjectedNormal` base distribution inference is tractable using ``poutine.reparam`` as + outlined in :class:`~pyro.distribution.ProjectedNormal`. + + .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,). .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event must be less than or equal to one. See eq. 2.1 in [1]. @@ -33,24 +39,22 @@ def model(...): 1. Sine-skewed toroidal distributions and their application in protein bioinformatics Ameijeiras-Alonso, J., Ley, C. (2019) - :param base_density: base density on a d-dimensional torus. + :param base_dist: base density on a d-dimensional torus. :param skewness: skewness of the distribution. """ arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)} support = constraints.independent(constraints.real, 1) - def __init__(self, base_density: TorchDistribution, skewness, validate_args=None): - assert base_density.event_shape[-1] == 2 and len(base_density.event_shape) <= 2 - assert base_density.shape()[len(base_density.shape()) - len(skewness.shape):] == skewness.shape - assert (skewness.abs().sum(-1 if len(skewness.shape) == 1 else (-2, -1)) <= 1).all() + def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): + assert skewness.abs().sum(-1) <= 1., "total skewness weight cannot exceed one." - self.base_density = base_density - self.skewness = skewness.broadcast_to(base_density.shape()) - super().__init__(base_density.batch_shape, base_density.event_shape, validate_args=validate_args) + self.base_density = base_dist + self.skewness = skewness.broadcast_to(base_dist.shape()) + super().__init__(base_dist.batch_shape, base_dist.event_shape, validate_args=validate_args) - if self._validate_args and base_density.mean.device != skewness.device: - raise ValueError(f"base_density: {base_density.__class__.__name__} and SineSkewed " + if self._validate_args and base_dist.mean.device != skewness.device: + raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " f"must be on same device.") def __repr__(self): @@ -64,29 +68,23 @@ def sample(self, sample_shape=torch.Size()): ys = bd.sample(sample_shape) u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) - mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).view(*(sample_shape + self.batch_shape), - -1).sum(-1) - mask = mask.view(*sample_shape, *self.batch_shape, *(1 for _ in bd.event_shape)) + mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + mask = mask[..., None] samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples def log_prob(self, value): if self._validate_args: self._validate_sample(value) - bd = self.base_density - bd_prob = bd.log_prob(value) - sine_prob = torch.log( - 1 + (self.skewness * torch.sin((value - bd.mean) % (2 * pi))).reshape((-1, self.event_shape.numel())).sum( - -1)) - return (bd_prob.view((-1)) + sine_prob).view(bd_prob.shape) + skew_prob = torch.log(1 + (self.skewness * torch.sin((value - self.base_density.mean) % (2 * pi))).sum(-1)) + return self.base_density.log_prob(value) + skew_prob def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new = self._get_checked_instance(SineSkewed, _instance) base_dist = self.base_density.expand(batch_shape, None) new.base_density = base_dist - for name in self.arg_constraints: - setattr(new, name, getattr(self, name).expand((*batch_shape, *self.event_shape))) + new.skewness = self.skewness.expand(batch_shape) super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=None) new._validate_args = self._validate_args return new diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index ac3e4ab7b5..82071e50b7 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -2,21 +2,20 @@ import pytest import torch -from numpy.testing import assert_raises import pyro -from pyro.distributions import Normal, SineSkewed, Uniform, constraints +from pyro.distributions import Normal, SineSkewed, Uniform, VonMises, constraints from pyro.infer import SVI, Trace_ELBO from pyro.optim import Adam from tests.common import assert_equal -BASE_DISTS = [(Uniform, ([-pi, -pi], [pi, pi]))] +BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))] -def _skewness(event_shape): +def _skewness(event_shape, max_sum=1.): skewness = torch.zeros(event_shape.numel()) done = False - while not done: + while not done and skewness.abs().sum(-1) > max_sum: for i in range(event_shape.numel()): max_ = 1. - skewness.abs().sum(-1) if torch.any(max_ < 1e-15): @@ -32,59 +31,35 @@ def _skewness(event_shape): @pytest.mark.parametrize('expand_shape', - [(), (1,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) -@pytest.mark.parametrize('event_dim', [0, 1]) + [(1,), (2,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) @pytest.mark.parametrize('dist', BASE_DISTS) -def test_ss_multidim_log_prob(event_dim, expand_shape, dist): - if len(expand_shape) >= event_dim and event_dim: - base_dist = dist[0](*(torch.tensor(param).expand(*expand_shape, 2) for param in dist[1])) - base_dist = base_dist.to_event(event_dim + 1) - assert base_dist.batch_shape == expand_shape[:-event_dim] - else: - base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) - base_dist = base_dist.expand(expand_shape) - assert base_dist.batch_shape == expand_shape +def test_ss_multidim_log_prob(expand_shape, dist): + base_dist = dist[0](*(torch.tensor(param).expand(expand_shape) for param in dist[1])).to_event(1) - loc = Normal(0., 1.).sample(base_dist.event_shape) % (2 * pi) - pi + loc = base_dist.sample((10,)) + Normal(0., 1e-3).sample() base_prob = base_dist.log_prob(loc) skewness = _skewness(base_dist.event_shape) ss = SineSkewed(base_dist, skewness) - assert_equal(base_prob + torch.log( - 1 + (skewness * torch.sin(loc - base_dist.mean)).view(*base_dist.batch_shape, -1).sum(-1)), - ss.log_prob(loc)) - assert_equal(ss.sample().shape, torch.Size((*expand_shape, 2))) - - -def test_ss_invalid_event_shape(): - base_dist = Uniform(-1, 1).expand((3, 3, 2)).to_event(3) - assert_raises(AssertionError, SineSkewed, base_dist, torch.zeros(base_dist.shape())) - base_dist = Uniform(-1, 1).expand((5,)).to_event(1) - assert_raises(AssertionError, SineSkewed, base_dist, torch.zeros(base_dist.shape())) - - -def test_ss_skewness_too_high(): - base_dist = Uniform(-1, 1).expand((2,)).to_event(1) - assert_raises(AssertionError, SineSkewed, base_dist, torch.ones(base_dist.shape())) - base_dist = Uniform(-1, 1).expand((1, 2,)).to_event(1) - assert_raises(AssertionError, SineSkewed, base_dist, .51 * torch.ones(base_dist.shape())) - base_dist = Uniform(-1, 1).expand((2, 2,)).to_event(1) - assert_raises(AssertionError, SineSkewed, base_dist, .5 * torch.ones(base_dist.shape())) + assert_equal(base_prob.shape, ss.log_prob(loc).shape) + assert_equal(ss.sample().shape, torch.Size(expand_shape)) @pytest.mark.parametrize('dist', BASE_DISTS) -def test_ss_sample(dist): - base_dist = dist[0](*(torch.tensor(param) for param in dist[1])).to_event(1) +@pytest.mark.parametrize('dim', [1, 2]) +def test_ss_mle(dim, dist): + base_dist = dist[0](*(torch.tensor(param).expand((dim,)) for param in dist[1])).to_event(1) skewness_tar = _skewness(base_dist.event_shape) data = SineSkewed(base_dist, skewness_tar).sample((1000,)) def model(data, batch_shape): - skew0 = pyro.param('skew0', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1)) - skew1 = pyro.param('skew1', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1)) + skews = [] + for i in range(dim): + skews.append(pyro.param(f'skew{i}', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1))) - skewness = torch.stack((skew0, skew1), dim=-1) + skewness = torch.stack(skews, dim=-1).squeeze() with pyro.plate("data", data.size(-len(data.size()))): pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) @@ -101,4 +76,4 @@ def guide(data, batch_shape): losses.append(svi.step(data, base_dist.batch_shape)) act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k], dim=-1) - assert_equal(act_skewness, skewness_tar, 5e-2) + assert_equal(act_skewness, skewness_tar, 1e-1) From d935f7447ce9f29ee69c757b59dc57b626151189 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 7 May 2021 22:49:47 +0200 Subject: [PATCH 21/35] Added fixture. --- pyro/distributions/sine_skewed.py | 27 +++++++++++---------- tests/distributions/conftest.py | 32 +++++++++++++------------ tests/distributions/test_sine_skewed.py | 6 ++--- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 55971d8c35..1ce4f790de 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -1,6 +1,7 @@ from math import pi import torch +from torch import broadcast_shapes from torch.distributions import Uniform from pyro.distributions import constraints @@ -47,11 +48,13 @@ def model(...): support = constraints.independent(constraints.real, 1) def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): - assert skewness.abs().sum(-1) <= 1., "total skewness weight cannot exceed one." + if (skewness.abs().sum(-1) > 1.).any(): + raise Warning("Total skewness weight shouldn't exceed one.", UserWarning) - self.base_density = base_dist - self.skewness = skewness.broadcast_to(base_dist.shape()) - super().__init__(base_dist.batch_shape, base_dist.event_shape, validate_args=validate_args) + batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1]) + self.skewness = skewness.broadcast_to(batch_shape + base_dist.event_shape) + self.base_dist = base_dist.expand(batch_shape) + super().__init__(batch_shape, base_dist.event_shape, validate_args=validate_args) if self._validate_args and base_dist.mean.device != skewness.device: raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " @@ -61,10 +64,10 @@ def __repr__(self): args_string = ', '.join(['{}: {}'.format(p, getattr(self, p) if getattr(self, p).numel() == 1 else getattr(self, p).size()) for p in self.arg_constraints.keys()]) - return self.__class__.__name__ + '(' + f'base_density: {str(self.base_density)}, ' + args_string + ')' + return self.__class__.__name__ + '(' + f'base_density: {str(self.base_dist)}, ' + args_string + ')' def sample(self, sample_shape=torch.Size()): - bd = self.base_density + bd = self.base_dist ys = bd.sample(sample_shape) u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) @@ -76,15 +79,15 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - skew_prob = torch.log(1 + (self.skewness * torch.sin((value - self.base_density.mean) % (2 * pi))).sum(-1)) - return self.base_density.log_prob(value) + skew_prob + skew_prob = torch.log(1 + (self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1)) + return self.base_dist.log_prob(value) + skew_prob def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new = self._get_checked_instance(SineSkewed, _instance) - base_dist = self.base_density.expand(batch_shape, None) - new.base_density = base_dist - new.skewness = self.skewness.expand(batch_shape) - super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=None) + base_dist = self.base_dist.expand(batch_shape) + new.base_dist = base_dist + new.skewness = self.skewness.expand(batch_shape+(-1,)) + super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 366bc6a933..581e9f0629 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -188,7 +188,7 @@ def __init__(self, rate, *, validate_args=None): ], # This hack seems to be the best option right now, as 'scale' is not handled well by get_scipy_batch_logpdf scipy_arg_fn=lambda loc, covariance_matrix=None: - ((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}), + ((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}), prec=0.01, min_samples=500000), Fixture(pyro_dist=dist.LowRankMultivariateNormal, @@ -198,7 +198,7 @@ def __init__(self, rate, *, validate_args=None): 'test_data': [[2.0, 1.0], [9.0, 3.4]]}, ], scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None: - ((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}), + ((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}), prec=0.01, min_samples=500000), Fixture(pyro_dist=FoldedNormal, @@ -281,12 +281,12 @@ def __init__(self, rate, *, validate_args=None): Fixture(pyro_dist=dist.LKJ, examples=[ {'dim': 3, 'concentration': 1., 'test_data': - [[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]], - [[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]], - [[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]], - [[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]], - [[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]}, - ]), + [[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]], + [[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]], + [[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]], + [[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]], + [[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]}, + ]), Fixture(pyro_dist=dist.LKJCholesky, examples=[ { @@ -306,19 +306,19 @@ def __init__(self, rate, *, validate_args=None): examples=[ {'stability': [1.5], 'skew': 0.1, 'test_data': [-10.]}, {'stability': [1.5], 'skew': 0.1, 'scale': 2.0, 'loc': -2.0, 'test_data': [10.]}, - ]), + ]), Fixture(pyro_dist=dist.MultivariateStudentT, examples=[ {'df': 1.5, 'loc': [0.2, 0.3], 'scale_tril': [[0.8, 0.0], [1.3, 0.4]], 'test_data': [-3., 2]}, - ]), + ]), Fixture(pyro_dist=dist.ProjectedNormal, examples=[ {'concentration': [0., 0.], 'test_data': [1., 0.]}, {'concentration': [2., 3.], 'test_data': [0., 1.]}, {'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]}, {'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]}, - ]), + ]), Fixture(pyro_dist=dist.SoftLaplace, examples=[ {'loc': [2.0], 'scale': [4.0], @@ -329,12 +329,14 @@ def __init__(self, rate, *, validate_args=None): 'test_data': [[[2.0]]]}, {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, - ]), + ]), Fixture(pyro_dist=dist.SineSkewed, examples=[{ - 'base_density': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), - 'skewness': [-pi/4, 0.], 'test_data': [pi/2, -2*pi/3] - }]) + 'base_dist': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), + 'skewness': [-pi / 4, .1], 'test_data': [pi / 2, -2 * pi / 3]}, + {'base_dist': dist.VonMises(*tensor_wrap([0.], [1.])).to_event(1), + 'skewness': [[.342355], [0.]], 'test_data': [[-.4], [.1]]}, + ]) ] discrete_dists = [ diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index 82071e50b7..46913903f3 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -12,10 +12,10 @@ BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))] -def _skewness(event_shape, max_sum=1.): +def _skewness(event_shape): skewness = torch.zeros(event_shape.numel()) done = False - while not done and skewness.abs().sum(-1) > max_sum: + while not done: for i in range(event_shape.numel()): max_ = 1. - skewness.abs().sum(-1) if torch.any(max_ < 1e-15): @@ -59,7 +59,7 @@ def model(data, batch_shape): for i in range(dim): skews.append(pyro.param(f'skew{i}', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1))) - skewness = torch.stack(skews, dim=-1).squeeze() + skewness = torch.stack(skews, dim=-1) with pyro.plate("data", data.size(-len(data.size()))): pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) From 9ca8a4503f1f07b20e1366b49bb46c9f9bf9ae45 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 9 May 2021 16:08:16 +0200 Subject: [PATCH 22/35] Fixed tests. --- pyro/distributions/sine_skewed.py | 11 +++++------ tests/distributions/conftest.py | 10 ++++++---- tests/distributions/test_sine_skewed.py | 2 +- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 1ce4f790de..5e7d7155ae 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -27,9 +27,7 @@ def model(...): In the context of :class:`~pyro.infer.SVI`, this distribution can be freely used as a likelihood, but use as a latent variables will lead to slow inference for 2 and higher order toruses. This is because the base_dist - cannot be reparameterized). For the 1-torus (circle) with a - :class:`~pyro.distribution.ProjectedNormal` base distribution inference is tractable using ``poutine.reparam`` as - outlined in :class:`~pyro.distribution.ProjectedNormal`. + cannot be reparameterized. .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,). @@ -52,9 +50,10 @@ def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): raise Warning("Total skewness weight shouldn't exceed one.", UserWarning) batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1]) - self.skewness = skewness.broadcast_to(batch_shape + base_dist.event_shape) + event_shape = skewness.shape[-1:] + self.skewness = skewness.broadcast_to(batch_shape + event_shape) self.base_dist = base_dist.expand(batch_shape) - super().__init__(batch_shape, base_dist.event_shape, validate_args=validate_args) + super().__init__(batch_shape, event_shape, validate_args=validate_args) if self._validate_args and base_dist.mean.device != skewness.device: raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " @@ -87,7 +86,7 @@ def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(SineSkewed, _instance) base_dist = self.base_dist.expand(batch_shape) new.base_dist = base_dist - new.skewness = self.skewness.expand(batch_shape+(-1,)) + new.skewness = self.skewness.expand(batch_shape + (-1,)) super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 581e9f0629..510ad17a04 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -331,11 +331,13 @@ def __init__(self, rate, *, validate_args=None): 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, ]), Fixture(pyro_dist=dist.SineSkewed, - examples=[{ - 'base_dist': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), - 'skewness': [-pi / 4, .1], 'test_data': [pi / 2, -2 * pi / 3]}, + examples=[ {'base_dist': dist.VonMises(*tensor_wrap([0.], [1.])).to_event(1), - 'skewness': [[.342355], [0.]], 'test_data': [[-.4], [.1]]}, + 'skewness': [.342355], 'test_data': [.1]}, + {'base_dist': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), + 'skewness': [-pi / 4, .1], 'test_data': [pi / 2, -2 * pi / 3]}, + {'base_dist': dist.VonMises(*tensor_wrap([0., -1.234], [1., 10.])).to_event(1), + 'skewness': [[.342355, -.0001], [.91, 0.09]], 'test_data': [[.1, -3.2], [-2., 0.]]}, ]) ] diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index 46913903f3..0dae84783e 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -57,7 +57,7 @@ def test_ss_mle(dim, dist): def model(data, batch_shape): skews = [] for i in range(dim): - skews.append(pyro.param(f'skew{i}', torch.zeros(batch_shape), constraint=constraints.interval(-1, 1))) + skews.append(pyro.param(f'skew{i}', .5 * torch.ones(batch_shape), constraint=constraints.interval(-1, 1))) skewness = torch.stack(skews, dim=-1) with pyro.plate("data", data.size(-len(data.size()))): From dd461fd51018306cc7859f7acaa4851fcd31d968 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 9 May 2021 17:55:15 +0200 Subject: [PATCH 23/35] removed deprecated add_stylesheet --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index eadbdd10f3..de12acacd2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -205,7 +205,7 @@ def skip(app, what, name, obj, skip, options): def setup(app): - app.add_stylesheet('css/pyro.css') + app.add_css_file('css/pyro.css') # app.connect("autodoc-skip-member", skip) From 5cfee34149e4431978e157d2e4d4c3c6ac4da4ec Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 9 May 2021 18:07:41 +0200 Subject: [PATCH 24/35] reverted to `add_stylesheet` --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index de12acacd2..eadbdd10f3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -205,7 +205,7 @@ def skip(app, what, name, obj, skip, options): def setup(app): - app.add_css_file('css/pyro.css') + app.add_stylesheet('css/pyro.css') # app.connect("autodoc-skip-member", skip) From 6d79eb3d378120e5f93de8b21e006a1f88dbd15f Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 10 May 2021 10:38:08 +0200 Subject: [PATCH 25/35] Removed raise from sine_skewed.py --- pyro/distributions/sine_skewed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 5e7d7155ae..04e47a0f5e 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -1,3 +1,4 @@ +import warnings from math import pi import torch @@ -47,7 +48,7 @@ def model(...): def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): if (skewness.abs().sum(-1) > 1.).any(): - raise Warning("Total skewness weight shouldn't exceed one.", UserWarning) + warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning) batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1]) event_shape = skewness.shape[-1:] From ff26ce91b537dcf9ae3bba4fe2cacbc9f7cd2651 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 10 May 2021 10:51:03 +0200 Subject: [PATCH 26/35] Added equation references. --- pyro/distributions/sine_skewed.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 04e47a0f5e..c3801229ed 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -71,6 +71,7 @@ def sample(self, sample_shape=torch.Size()): ys = bd.sample(sample_shape) u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) + # Section 2.3 step 3 in [1] mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) mask = mask[..., None] samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi @@ -79,6 +80,8 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._validate_args: self._validate_sample(value) + + # Eq. 2.1 in [1] skew_prob = torch.log(1 + (self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1)) return self.base_dist.log_prob(value) + skew_prob From 4427e372412044b5e05bb09197194aacb0f15daf Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 10 May 2021 11:01:50 +0200 Subject: [PATCH 27/35] Fixed sampling bound in `SineSkewed`. --- pyro/distributions/sine_skewed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index c3801229ed..1a7c4eb6d7 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -72,7 +72,7 @@ def sample(self, sample_shape=torch.Size()): u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) # Section 2.3 step 3 in [1] - mask = u < 1. + (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + mask = u < .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) mask = mask[..., None] samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples From a7bf5fe611b61d2011a143f3957597ad6e93aa30 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 10 May 2021 17:45:29 +0200 Subject: [PATCH 28/35] Fixed prior on `SineSkewed` to avoid `AffineTransform`. --- pyro/distributions/sine_skewed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index c3801229ed..7a04c23039 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -20,10 +20,10 @@ class SineSkewed(TorchDistribution): def model(...): ... - skewness_phi = pyro.sample(f'skewness_phi', Uniform(skewness.abs().sum(), 1 - tots)) + skew_phi = pyro.sample(f'skew_phi', Uniform(-1., 1.)) psi_bound = 1 - skewness_phi.abs() - skewness_psi = pyro.sample(f'skewness_psi', Uniform(-psi_bound, psi_bound) - skewness = torch.stack((skewness_phi, skewness_psi), dim=0) + skew_psi = pyro.sample(f'skew_psi', Uniform(-1, 1.)) + skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=0) ... In the context of :class:`~pyro.infer.SVI`, this distribution can be freely used as a likelihood, but use as a From c9ede43412514b2c9788502861c0c6dbf2113cee Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 10 May 2021 17:51:03 +0200 Subject: [PATCH 29/35] Merged origin. --- pyro/distributions/sine_skewed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 47f2c9e0d4..d7e7f8a3c1 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -72,7 +72,7 @@ def sample(self, sample_shape=torch.Size()): u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) # Section 2.3 step 3 in [1] - mask = u < .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) mask = mask[..., None] samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples From fac58645dc90e61185fe43dc65cf8e78d83b6aa4 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 2 Jun 2021 12:34:34 +0200 Subject: [PATCH 30/35] removed import all pyro distributions --- pyro/distributions/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index c4e2ccf1b1..c7bcc3b5ce 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -62,7 +62,6 @@ from pyro.distributions.softlaplace import SoftLaplace from pyro.distributions.spanning_tree import SpanningTree from pyro.distributions.stable import Stable -from pyro.distributions.torch import * # noqa F403 from pyro.distributions.torch import __all__ as torch_dists from pyro.distributions.torch_distribution import ( ExpandedDistribution, From 9fc36ff683e3892677abf08cecf57b3c16fda3f8 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 2 Jun 2021 14:00:33 +0200 Subject: [PATCH 31/35] Fixed tests for SineSkewed with wrapper class. --- tests/distributions/conftest.py | 37 ++++++++++++++++++----- tests/distributions/test_distributions.py | 2 +- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index b408280132..7d17485c63 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -39,6 +39,18 @@ def __init__(self, rate, *, validate_args=None): super().__init__(rate, is_sparse=True, validate_args=validate_args) +class SineSkewedUniform(dist.SineSkewed): + def __init__(self, lower, upper, skewness, *args, **kwargs): + base_dist = dist.Uniform(lower, upper).to_event(lower.ndim) + super().__init__(base_dist, skewness, *args, **kwargs) + + +class SineSkewedVonMises(dist.SineSkewed): + def __init__(self, von_loc, von_conc, skewness): + base_dist = dist.VonMises(von_loc, von_conc).to_event(von_loc.ndim) + super().__init__(base_dist, skewness) + + continuous_dists = [ Fixture(pyro_dist=dist.Uniform, scipy_dist=sp.uniform, @@ -342,14 +354,23 @@ def __init__(self, rate, *, validate_args=None): {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, ]), - Fixture(pyro_dist=dist.SineSkewed, - examples=[ - {'base_dist': dist.VonMises(*tensor_wrap([0.], [1.])).to_event(1), - 'skewness': [.342355], 'test_data': [.1]}, - {'base_dist': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1), - 'skewness': [-pi / 4, .1], 'test_data': [pi / 2, -2 * pi / 3]}, - {'base_dist': dist.VonMises(*tensor_wrap([0., -1.234], [1., 10.])).to_event(1), - 'skewness': [[.342355, -.0001], [.91, 0.09]], 'test_data': [[.1, -3.2], [-2., 0.]]}, + Fixture(pyro_dist=SineSkewedUniform, + examples=[ + {'lower': [-pi, -pi], + 'upper':[pi, pi], + 'skewness': [-pi / 4, .1], + 'test_data': [pi / 2, -2 * pi / 3]} + ]), + Fixture(pyro_dist=SineSkewedVonMises, + examples=[ + {'von_loc': [0.], + 'von_conc': [1.], + 'skewness': [.342355], + 'test_data': [.1]}, + {'von_loc': [0., -1.234], + 'von_conc': [1., 10.], + 'skewness': [[.342355, -.0001], [.91, 0.09]], + 'test_data': [[.1, -3.2], [-2., 0.]]} ]), Fixture(pyro_dist=dist.AsymmetricLaplace, examples=[ diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index e962f79254..f68047827a 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -40,7 +40,7 @@ def test_support_shape(dist): def test_infer_shapes(dist): - if "LKJ" in dist.pyro_dist.__name__ or "SineSkewed" == dist.pyro_dist.__name__: + if "LKJ" in dist.pyro_dist.__name__ or "SineSkewed" in dist.pyro_dist.__name__: pytest.xfail(reason="cannot statically compute shape") for idx in range(dist.get_num_test_data()): dist_params = dist.get_dist_params(idx) From 9b78a7ae78725a4df9c755fbd81807357367158e Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 2 Jun 2021 14:06:14 +0200 Subject: [PATCH 32/35] Removed unused import from conftest.py --- tests/distributions/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 7d17485c63..35de110756 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -16,7 +16,7 @@ ShapeAugmentedDirichlet, ShapeAugmentedGamma, ) -from tests.distributions.dist_fixture import Fixture, tensor_wrap +from tests.distributions.dist_fixture import Fixture class FoldedNormal(dist.FoldedDistribution): From 0b5af735b50bceef2a037d2c97735c7abff3a065 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 2 Jun 2021 14:30:21 +0200 Subject: [PATCH 33/35] Removed xfail int test_cuda for `SineSkewed` --- tests/distributions/test_cuda.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/distributions/test_cuda.py b/tests/distributions/test_cuda.py index fdccc66ef9..9cb887b467 100644 --- a/tests/distributions/test_cuda.py +++ b/tests/distributions/test_cuda.py @@ -15,8 +15,6 @@ @requires_cuda def test_sample(dist): - if dist.pyro_dist.__name__ == 'SineSkewed': - pytest.xfail(reason="Fixture with distribution param not handled.") for idx in range(len(dist.dist_params)): # Compute CPU value. @@ -79,8 +77,6 @@ def test_rsample(dist): @requires_cuda def test_log_prob(dist): - if dist.pyro_dist.__name__ == 'SineSkewed': - pytest.xfail(reason="Fixture with distribution param not handled.") for idx in range(len(dist.dist_params)): # Compute CPU value. From 4d44aa08a9ae714accc7b52e94ae2d372546c5ee Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 4 Jun 2021 10:23:21 +0200 Subject: [PATCH 34/35] Fixed DocString example. --- pyro/distributions/sine_skewed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 1d103d4da8..eb049009d6 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -41,10 +41,10 @@ def model(obs): with pyro.plate('obs_plate'): sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc, - phi_concentration=1000 * phi_conc, - psi_concentration=1000 * psi_conc, - weighted_correlation=corr_scale) - return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs) + phi_concentration=1000 * phi_conc, + psi_concentration=1000 * psi_conc, + weighted_correlation=corr_scale) + return pyro.sample(' phi_psi', SineSkewed(sine, skewness), obs=obs) To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of @@ -109,7 +109,7 @@ def __repr__(self): def sample(self, sample_shape=torch.Size()): bd = self.base_dist ys = bd.sample(sample_shape) - u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape) + u = Uniform(0., self.skewness.new_ones(())).sample(sample_shape + self.batch_shape) # Section 2.3 step 3 in [1] mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) From 3143b4560c7a041236ab8f21eabb58b23c5ef1ce Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 7 Jun 2021 09:04:13 +0200 Subject: [PATCH 35/35] Fixed psi_phi name in docstring. --- pyro/distributions/sine_skewed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index eb049009d6..d17cb1821c 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -44,7 +44,7 @@ def model(obs): phi_concentration=1000 * phi_conc, psi_concentration=1000 * psi_conc, weighted_correlation=corr_scale) - return pyro.sample(' phi_psi', SineSkewed(sine, skewness), obs=obs) + return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs) To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of