From b7ba915b851d51c61d805da741bf7a74fcf9319d Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 2 Jul 2021 05:35:59 -0600
Subject: [PATCH] Apply black code formatter (#2891)

* Add black to Makefile and setup.cfg
* Tweak files in preparation for black
* Update .github/workflows/ci.yml
* Apply black code formatter
* Require black>=21.4
* Require black>=21.4b0
---
 .github/workflows/ci.yml | 4 +-
 Makefile | 2 +
 docs/source/conf.py | 99 +-
 examples/air/air.py | 255 +-
 examples/air/main.py | 375 ++-
 examples/air/modules.py | 16 +-
 examples/air/viz.py | 20 +-
 examples/baseball.py | 207 +-
 examples/capture_recapture/cjs.py | 198 +-
 examples/contrib/autoname/mixture.py | 16 +-
 examples/contrib/autoname/scoping_mixture.py | 45 +-
 examples/contrib/autoname/tree_data.py | 14 +-
 examples/contrib/cevae/synthetic.py | 38 +-
 examples/contrib/epidemiology/regional.py | 99 +-
 examples/contrib/epidemiology/sir.py | 164 +-
 examples/contrib/forecast/bart.py | 66 +-
 examples/contrib/funsor/hmm.py | 446 ++-
 examples/contrib/gp/sv-dkl.py | 169 +-
 examples/contrib/mue/FactorMuE.py | 444 ++-
 examples/contrib/mue/ProfileHMM.py | 293 +-
 examples/contrib/oed/ab_test.py | 47 +-
 examples/contrib/oed/gp_bayes_opt.py | 25 +-
 examples/contrib/timeseries/gp_models.py | 112 +-
 examples/cvae/baseline.py | 38 +-
 examples/cvae/cvae.py | 47 +-
 examples/cvae/main.py | 96 +-
 examples/cvae/mnist.py | 42 +-
 examples/cvae/util.py | 83 +-
 examples/dmm.py | 313 +-
 examples/eight_schools/data.py | 4 +-
 examples/eight_schools/mcmc.py | 50 +-
 examples/eight_schools/svi.py | 44 +-
 examples/einsum.py | 97 +-
 examples/hmm.py | 403 ++-
 examples/inclined_plane.py | 60 +-
 examples/lda.py | 78 +-
 examples/lkj.py | 22 +-
 examples/minipyro.py | 10 +-
 examples/mixed_hmm/experiment.py | 60 +-
 examples/mixed_hmm/model.py | 201 +-
 examples/mixed_hmm/seal_data.py | 14 +-
 examples/neutra.py | 211 +-
 examples/rsa/generics.py | 61 +-
 examples/rsa/hyperbole.py | 108 +-
 examples/rsa/schelling.py | 22 +-
 examples/rsa/schelling_false.py | 30 +-
 examples/rsa/search_inference.py | 54 +-
 examples/rsa/semantic_parsing.py | 90 +-
 examples/scanvi/data.py | 87 +-
 examples/scanvi/scanvi.py | 177 +-
 examples/sir_hmc.py | 295 +-
 examples/smcfilter.py | 33 +-
 examples/sparse_gamma_def.py | 156 +-
 examples/sparse_regression.py | 155 +-
 examples/svi_horovod.py | 16 +-
 .../toy_mixture_model_discrete_enumeration.py | 97 +-
 examples/vae/ss_vae_M2.py | 272 +-
 examples/vae/utils/custom_mlp.py | 59 +-
 examples/vae/utils/mnist_cached.py | 73 +-
 examples/vae/utils/vae_plots.py | 23 +-
 examples/vae/vae.py | 81 +-
 examples/vae/vae_comparison.py | 108 +-
 profiler/distributions.py | 152 +-
 profiler/hmm.py | 20 +-
 profiler/profiling_utils.py | 36 +-
 pyro/__init__.py | 2 +-
 pyro/contrib/__init__.py | 1 +
 pyro/contrib/autoguide.py | 8 +-
 pyro/contrib/autoname/named.py | 32 +-
 pyro/contrib/autoname/scoping.py | 3 +
 pyro/contrib/bnn/hidden_layer.py | 29 +-
 pyro/contrib/cevae/__init__.py | 127 +-
 pyro/contrib/conjugate/infer.py | 51 +-
 pyro/contrib/easyguide/easyguide.py | 75 +-
 pyro/contrib/epidemiology/compartmental.py | 341 +-
 pyro/contrib/epidemiology/distributions.py | 39 +-
 pyro/contrib/epidemiology/models.py | 475 +--
 pyro/contrib/epidemiology/util.py | 250 +-
 pyro/contrib/examples/bart.py | 22 +-
 pyro/contrib/examples/multi_mnist.py | 24 +-
 .../examples/polyphonic_data_loader.py | 64 +-
 pyro/contrib/examples/util.py | 32 +-
 pyro/contrib/forecast/evaluate.py | 65 +-
 pyro/contrib/forecast/forecaster.py | 137 +-
 pyro/contrib/forecast/util.py | 60 +-
pyro/contrib/funsor/__init__.py | 19 +- .../contrib/funsor/handlers/enum_messenger.py | 132 +- .../funsor/handlers/named_messenger.py | 40 +- .../funsor/handlers/plate_messenger.py | 76 +- pyro/contrib/funsor/handlers/primitives.py | 10 +- .../funsor/handlers/replay_messenger.py | 9 +- pyro/contrib/funsor/handlers/runtime.py | 87 +- .../funsor/handlers/trace_messenger.py | 34 +- pyro/contrib/funsor/infer/discrete.py | 13 +- pyro/contrib/funsor/infer/elbo.py | 17 +- pyro/contrib/funsor/infer/trace_elbo.py | 22 +- pyro/contrib/funsor/infer/traceenum_elbo.py | 147 +- pyro/contrib/funsor/infer/tracetmc_elbo.py | 19 +- pyro/contrib/gp/kernels/__init__.py | 15 +- pyro/contrib/gp/kernels/brownian.py | 10 +- pyro/contrib/gp/kernels/coregionalize.py | 22 +- pyro/contrib/gp/kernels/dot_product.py | 12 +- pyro/contrib/gp/kernels/isotropic.py | 32 +- pyro/contrib/gp/kernels/kernel.py | 47 +- pyro/contrib/gp/kernels/periodic.py | 12 +- pyro/contrib/gp/kernels/static.py | 6 +- pyro/contrib/gp/likelihoods/__init__.py | 15 +- pyro/contrib/gp/likelihoods/binary.py | 7 +- pyro/contrib/gp/likelihoods/gaussian.py | 5 +- pyro/contrib/gp/likelihoods/likelihood.py | 1 + pyro/contrib/gp/likelihoods/multi_class.py | 24 +- pyro/contrib/gp/likelihoods/poisson.py | 7 +- pyro/contrib/gp/models/gplvm.py | 11 +- pyro/contrib/gp/models/gpr.py | 48 +- pyro/contrib/gp/models/model.py | 33 +- pyro/contrib/gp/models/sgpr.py | 39 +- pyro/contrib/gp/models/vgp.py | 55 +- pyro/contrib/gp/models/vsgp.py | 72 +- pyro/contrib/gp/parameterized.py | 26 +- pyro/contrib/gp/util.py | 28 +- pyro/contrib/minipyro.py | 59 +- pyro/contrib/mue/dataloaders.py | 91 +- pyro/contrib/mue/missingdatahmm.py | 64 +- pyro/contrib/mue/models.py | 594 ++-- pyro/contrib/mue/statearrangers.py | 118 +- pyro/contrib/oed/__init__.py | 5 +- pyro/contrib/oed/eig.py | 514 ++- pyro/contrib/oed/glmm/glmm.py | 286 +- pyro/contrib/oed/glmm/guides.py | 102 +- pyro/contrib/oed/search.py | 4 +- pyro/contrib/oed/util.py | 21 +- .../contrib/randomvariable/random_variable.py | 43 +- pyro/contrib/timeseries/base.py | 1 + pyro/contrib/timeseries/gp.py | 263 +- pyro/contrib/timeseries/lgssm.py | 58 +- pyro/contrib/timeseries/lgssmgp.py | 165 +- pyro/contrib/tracking/assignment.py | 160 +- pyro/contrib/tracking/distributions.py | 41 +- pyro/contrib/tracking/dynamic_models.py | 188 +- .../tracking/extended_kalman_filter.py | 107 +- pyro/contrib/tracking/hashing.py | 22 +- pyro/contrib/tracking/measurements.py | 71 +- pyro/contrib/util.py | 8 +- pyro/distributions/asymmetriclaplace.py | 48 +- pyro/distributions/avf_mvn.py | 26 +- pyro/distributions/coalescent.py | 69 +- pyro/distributions/conditional.py | 15 +- pyro/distributions/conjugate.py | 83 +- pyro/distributions/constraints.py | 51 +- pyro/distributions/delta.py | 21 +- pyro/distributions/diag_normal_mixture.py | 110 +- .../diag_normal_mixture_shared_cov.py | 102 +- pyro/distributions/distribution.py | 19 +- pyro/distributions/empirical.py | 77 +- pyro/distributions/extended.py | 18 +- pyro/distributions/folded.py | 3 +- pyro/distributions/gaussian_scale_mixture.py | 106 +- pyro/distributions/hmm.py | 384 ++- pyro/distributions/improper_uniform.py | 3 +- pyro/distributions/inverse_gamma.py | 12 +- pyro/distributions/kl.py | 10 +- pyro/distributions/lkj.py | 13 +- pyro/distributions/logistic.py | 9 +- pyro/distributions/mixture.py | 51 +- pyro/distributions/multivariate_studentt.py | 61 +- pyro/distributions/omt_mvn.py | 18 +- pyro/distributions/one_one_matching.py | 5 +- pyro/distributions/one_two_matching.py | 17 +- 
pyro/distributions/polya_gamma.py | 24 +- pyro/distributions/projected_normal.py | 29 +- pyro/distributions/rejector.py | 6 +- .../distributions/relaxed_straight_through.py | 6 +- pyro/distributions/score_parts.py | 5 +- .../distributions/sine_bivariate_von_mises.py | 163 +- pyro/distributions/sine_skewed.py | 58 +- pyro/distributions/spanning_tree.py | 84 +- pyro/distributions/stable.py | 40 +- pyro/distributions/testing/gof.py | 17 +- pyro/distributions/testing/naive_dirichlet.py | 6 +- .../testing/rejection_exponential.py | 3 +- pyro/distributions/testing/rejection_gamma.py | 64 +- pyro/distributions/torch.py | 96 +- pyro/distributions/torch_distribution.py | 85 +- pyro/distributions/torch_patch.py | 32 +- pyro/distributions/torch_transform.py | 1 + pyro/distributions/transforms/__init__.py | 140 +- .../transforms/affine_autoregressive.py | 37 +- .../transforms/affine_coupling.py | 93 +- pyro/distributions/transforms/basic.py | 9 +- pyro/distributions/transforms/batchnorm.py | 15 +- .../transforms/block_autoregressive.py | 114 +- pyro/distributions/transforms/cholesky.py | 35 +- .../transforms/discrete_cosine.py | 12 +- .../transforms/generalized_channel_permute.py | 26 +- pyro/distributions/transforms/haar.py | 10 +- pyro/distributions/transforms/householder.py | 36 +- .../transforms/lower_cholesky_affine.py | 20 +- .../transforms/matrix_exponential.py | 49 +- .../transforms/neural_autoregressive.py | 33 +- pyro/distributions/transforms/normalize.py | 1 + pyro/distributions/transforms/ordered.py | 1 + pyro/distributions/transforms/permute.py | 16 +- pyro/distributions/transforms/planar.py | 38 +- pyro/distributions/transforms/polynomial.py | 32 +- pyro/distributions/transforms/radial.py | 36 +- pyro/distributions/transforms/softplus.py | 5 +- pyro/distributions/transforms/spline.py | 242 +- .../transforms/spline_autoregressive.py | 32 +- .../transforms/spline_coupling.py | 39 +- pyro/distributions/transforms/sylvester.py | 21 +- pyro/distributions/unit.py | 3 +- pyro/distributions/util.py | 53 +- pyro/distributions/von_mises_3d.py | 21 +- pyro/distributions/zero_inflated.py | 65 +- pyro/infer/abstract_infer.py | 118 +- pyro/infer/autoguide/__init__.py | 44 +- pyro/infer/autoguide/guides.py | 252 +- pyro/infer/autoguide/initialization.py | 13 +- pyro/infer/autoguide/utils.py | 8 +- pyro/infer/csis.py | 63 +- pyro/infer/discrete.py | 95 +- pyro/infer/elbo.py | 74 +- pyro/infer/energy_distance.py | 54 +- pyro/infer/enum.py | 88 +- pyro/infer/importance.py | 79 +- pyro/infer/mcmc/adaptation.py | 104 +- pyro/infer/mcmc/api.py | 297 +- pyro/infer/mcmc/hmc.py | 145 +- pyro/infer/mcmc/logger.py | 37 +- pyro/infer/mcmc/mcmc_kernel.py | 1 - pyro/infer/mcmc/nuts.py | 265 +- pyro/infer/mcmc/util.py | 317 +- pyro/infer/predictive.py | 193 +- pyro/infer/renyi_elbo.py | 85 +- pyro/infer/reparam/conjugate.py | 7 +- pyro/infer/reparam/discrete_cosine.py | 9 +- pyro/infer/reparam/haar.py | 9 +- pyro/infer/reparam/hmm.py | 73 +- pyro/infer/reparam/loc_scale.py | 9 +- pyro/infer/reparam/neutra.py | 31 +- pyro/infer/reparam/reparam.py | 1 + pyro/infer/reparam/split.py | 3 +- pyro/infer/reparam/stable.py | 48 +- pyro/infer/reparam/studentt.py | 6 +- pyro/infer/reparam/unit_jacobian.py | 30 +- pyro/infer/rws.py | 100 +- pyro/infer/smcfilter.py | 54 +- pyro/infer/svgd.py | 54 +- pyro/infer/svi.py | 66 +- pyro/infer/trace_elbo.py | 90 +- pyro/infer/trace_mean_field_elbo.py | 80 +- pyro/infer/trace_mmd.py | 98 +- pyro/infer/trace_tail_adaptive_elbo.py | 21 +- pyro/infer/traceenum_elbo.py | 219 +- 
pyro/infer/tracegraph_elbo.py | 134 +- pyro/infer/tracetmc_elbo.py | 61 +- pyro/infer/util.py | 49 +- pyro/logger.py | 2 +- pyro/nn/auto_reg_nn.py | 114 +- pyro/nn/dense_nn.py | 30 +- pyro/nn/module.py | 91 +- pyro/ops/arrowhead.py | 10 +- pyro/ops/contract.py | 164 +- pyro/ops/dual_averaging.py | 4 +- pyro/ops/einsum/__init__.py | 8 +- pyro/ops/einsum/adjoint.py | 15 +- pyro/ops/einsum/torch_log.py | 15 +- pyro/ops/einsum/torch_map.py | 12 +- pyro/ops/einsum/torch_marginal.py | 12 +- pyro/ops/einsum/torch_sample.py | 14 +- pyro/ops/einsum/util.py | 3 +- pyro/ops/gamma_gaussian.py | 120 +- pyro/ops/gaussian.py | 117 +- pyro/ops/hessian.py | 4 +- pyro/ops/indexing.py | 2 + pyro/ops/integrator.py | 19 +- pyro/ops/jit.py | 53 +- pyro/ops/linalg.py | 27 +- pyro/ops/newton.py | 86 +- pyro/ops/packed.py | 63 +- pyro/ops/rings.py | 78 +- pyro/ops/special.py | 23 +- pyro/ops/ssm_gp.py | 55 +- pyro/ops/stats.py | 53 +- pyro/ops/tensor_utils.py | 59 +- pyro/ops/welford.py | 39 +- pyro/optim/adagrad_rmsprop.py | 30 +- pyro/optim/clipped_adam.py | 55 +- pyro/optim/dct_adam.py | 84 +- pyro/optim/horovod.py | 6 +- pyro/optim/lr_scheduler.py | 17 +- pyro/optim/multi.py | 18 +- pyro/optim/optim.py | 38 +- pyro/optim/pytorch_optimizers.py | 25 +- pyro/params/param_store.py | 29 +- pyro/poutine/block_messenger.py | 45 +- pyro/poutine/broadcast_messenger.py | 29 +- pyro/poutine/collapse_messenger.py | 15 +- pyro/poutine/condition_messenger.py | 1 + pyro/poutine/do_messenger.py | 31 +- pyro/poutine/enum_messenger.py | 56 +- pyro/poutine/escape_messenger.py | 2 + pyro/poutine/handlers.py | 68 +- pyro/poutine/indep_messenger.py | 12 +- pyro/poutine/infer_config_messenger.py | 1 + pyro/poutine/lift_messenger.py | 10 +- pyro/poutine/markov_messenger.py | 11 +- pyro/poutine/mask_messenger.py | 9 +- pyro/poutine/messenger.py | 6 +- pyro/poutine/plate_messenger.py | 9 +- pyro/poutine/reparam_messenger.py | 16 +- pyro/poutine/replay_messenger.py | 8 +- pyro/poutine/runtime.py | 29 +- pyro/poutine/scale_messenger.py | 7 +- pyro/poutine/seed_messenger.py | 1 + pyro/poutine/subsample_messenger.py | 76 +- pyro/poutine/trace_messenger.py | 31 +- pyro/poutine/trace_struct.py | 176 +- pyro/poutine/uncondition_messenger.py | 1 + pyro/poutine/util.py | 18 +- pyro/primitives.py | 62 +- pyro/util.py | 363 ++- scripts/update_headers.py | 5 +- scripts/update_version.py | 3 +- setup.cfg | 2 +- setup.py | 155 +- tests/__init__.py | 2 +- tests/common.py | 86 +- tests/conftest.py | 47 +- tests/contrib/autoguide/test_inference.py | 180 +- .../autoguide/test_mean_field_entropy.py | 19 +- tests/contrib/autoname/test_named.py | 18 +- tests/contrib/autoname/test_scoping.py | 62 +- tests/contrib/bnn/test_hidden_layer.py | 33 +- tests/contrib/cevae/test_cevae.py | 7 +- tests/contrib/easyguide/test_easyguide.py | 99 +- .../epidemiology/test_distributions.py | 144 +- tests/contrib/epidemiology/test_models.py | 377 ++- tests/contrib/epidemiology/test_util.py | 4 +- tests/contrib/forecast/test_evaluate.py | 70 +- tests/contrib/forecast/test_forecaster.py | 105 +- tests/contrib/forecast/test_util.py | 32 +- tests/contrib/funsor/test_enum_funsor.py | 1118 ++++--- tests/contrib/funsor/test_infer_discrete.py | 228 +- tests/contrib/funsor/test_named_handlers.py | 69 +- tests/contrib/funsor/test_pyroapi_funsor.py | 1 + tests/contrib/funsor/test_tmc.py | 159 +- .../contrib/funsor/test_valid_models_enum.py | 228 +- .../contrib/funsor/test_valid_models_plate.py | 65 +- .../test_valid_models_sequential_plate.py | 54 +- 
.../contrib/funsor/test_vectorized_markov.py | 640 ++-- tests/contrib/gp/test_conditional.py | 65 +- tests/contrib/gp/test_kernels.py | 97 +- tests/contrib/gp/test_likelihoods.py | 87 +- tests/contrib/gp/test_models.py | 162 +- tests/contrib/gp/test_parameterized.py | 42 +- tests/contrib/mue/test_dataloaders.py | 47 +- tests/contrib/mue/test_missingdatahmm.py | 167 +- tests/contrib/mue/test_models.py | 98 +- tests/contrib/mue/test_statearrangers.py | 193 +- tests/contrib/oed/test_ewma.py | 18 +- tests/contrib/oed/test_finite_spaces_eig.py | 196 +- tests/contrib/oed/test_glmm.py | 171 +- tests/contrib/oed/test_linear_models_eig.py | 207 +- tests/contrib/oed/test_xexpx.py | 13 +- .../randomvariable/test_random_variable.py | 12 +- tests/contrib/test_minipyro.py | 47 +- tests/contrib/test_util.py | 48 +- tests/contrib/timeseries/test_gp.py | 126 +- tests/contrib/timeseries/test_lgssm.py | 78 +- tests/contrib/tracking/test_assignment.py | 212 +- tests/contrib/tracking/test_distributions.py | 12 +- tests/contrib/tracking/test_dynamic_models.py | 48 +- tests/contrib/tracking/test_ekf.py | 23 +- tests/contrib/tracking/test_em.py | 155 +- tests/contrib/tracking/test_hashing.py | 82 +- tests/contrib/tracking/test_measurements.py | 13 +- tests/distributions/conftest.py | 1539 +++++---- tests/distributions/dist_fixture.py | 59 +- tests/distributions/test_binomial.py | 28 +- tests/distributions/test_categorical.py | 24 +- tests/distributions/test_coalescent.py | 14 +- tests/distributions/test_conjugate.py | 75 +- tests/distributions/test_conjugate_update.py | 2 +- tests/distributions/test_constraints.py | 9 +- tests/distributions/test_cuda.py | 11 +- tests/distributions/test_delta.py | 39 +- tests/distributions/test_distributions.py | 148 +- tests/distributions/test_empirical.py | 77 +- tests/distributions/test_extended.py | 20 +- tests/distributions/test_gaussian_mixtures.py | 125 +- tests/distributions/test_haar.py | 2 +- tests/distributions/test_hmm.py | 771 +++-- tests/distributions/test_ig.py | 14 +- tests/distributions/test_improper_uniform.py | 14 +- tests/distributions/test_independent.py | 90 +- tests/distributions/test_kl.py | 31 +- tests/distributions/test_lkj.py | 33 +- tests/distributions/test_mask.py | 60 +- tests/distributions/test_mixture.py | 67 +- tests/distributions/test_mvn.py | 33 +- tests/distributions/test_mvt.py | 102 +- tests/distributions/test_omt_mvn.py | 85 +- .../distributions/test_one_hot_categorical.py | 30 +- tests/distributions/test_one_one_matching.py | 22 +- tests/distributions/test_one_two_matching.py | 31 +- tests/distributions/test_ordered_logistic.py | 2 +- tests/distributions/test_pickle.py | 42 +- tests/distributions/test_polya_gamma.py | 8 +- tests/distributions/test_rejector.py | 78 +- .../test_relaxed_straight_through.py | 66 +- tests/distributions/test_reshape.py | 56 +- .../test_sine_bivariate_von_mises.py | 71 +- tests/distributions/test_sine_skewed.py | 53 +- tests/distributions/test_spanning_tree.py | 62 +- tests/distributions/test_stable.py | 29 +- tests/distributions/test_tensor_type.py | 10 +- tests/distributions/test_torch_patch.py | 2 +- tests/distributions/test_transforms.py | 145 +- tests/distributions/test_unit.py | 6 +- tests/distributions/test_util.py | 149 +- tests/distributions/test_von_mises.py | 88 +- tests/distributions/test_zero_inflated.py | 16 +- tests/doctest_fixtures.py | 24 +- tests/infer/mcmc/test_adaptation.py | 42 +- tests/infer/mcmc/test_hmc.py | 200 +- tests/infer/mcmc/test_mcmc_api.py | 276 +- tests/infer/mcmc/test_mcmc_util.py 
| 70 +- tests/infer/mcmc/test_nuts.py | 282 +- tests/infer/mcmc/test_valid_models.py | 323 +- tests/infer/reparam/test_conjugate.py | 53 +- tests/infer/reparam/test_discrete_cosine.py | 81 +- tests/infer/reparam/test_haar.py | 71 +- tests/infer/reparam/test_hmm.py | 110 +- tests/infer/reparam/test_loc_scale.py | 8 +- tests/infer/reparam/test_neutra.py | 59 +- tests/infer/reparam/test_softmax.py | 5 +- tests/infer/reparam/test_split.py | 57 +- tests/infer/reparam/test_stable.py | 18 +- tests/infer/reparam/test_structured.py | 10 +- tests/infer/reparam/test_studentt.py | 7 +- tests/infer/reparam/test_transform.py | 11 +- tests/infer/reparam/test_unit_jacobian.py | 4 +- tests/infer/test_abstract_infer.py | 37 +- tests/infer/test_autoguide.py | 743 +++-- tests/infer/test_compute_downstream_costs.py | 551 ++-- tests/infer/test_conjugate_gradients.py | 25 +- tests/infer/test_csis.py | 25 +- tests/infer/test_discrete.py | 251 +- tests/infer/test_elbo_mapdata.py | 79 +- tests/infer/test_enum.py | 2812 +++++++++++------ tests/infer/test_gradient.py | 182 +- tests/infer/test_inference.py | 582 ++-- tests/infer/test_initialization.py | 6 +- tests/infer/test_jit.py | 329 +- tests/infer/test_multi_sample_elbos.py | 6 +- tests/infer/test_predictive.py | 49 +- tests/infer/test_sampling.py | 45 +- tests/infer/test_smcfilter.py | 138 +- tests/infer/test_svgd.py | 74 +- tests/infer/test_tmc.py | 235 +- tests/infer/test_util.py | 38 +- tests/infer/test_valid_models.py | 1321 ++++---- .../test_conjugate_gaussian_models.py | 400 ++- .../integration_tests/test_tracegraph_elbo.py | 392 ++- tests/nn/test_autoregressive.py | 51 +- tests/nn/test_module.py | 135 +- tests/ops/einsum/test_adjoint.py | 110 +- tests/ops/einsum/test_torch_log.py | 74 +- tests/ops/test_arrowhead.py | 15 +- tests/ops/test_contract.py | 548 ++-- tests/ops/test_gamma_gaussian.py | 180 +- tests/ops/test_gaussian.py | 156 +- tests/ops/test_indexing.py | 197 +- tests/ops/test_integrator.py | 167 +- tests/ops/test_jit.py | 23 +- tests/ops/test_linalg.py | 27 +- tests/ops/test_newton.py | 62 +- tests/ops/test_packed.py | 14 +- tests/ops/test_special.py | 54 +- tests/ops/test_ssm_gp.py | 12 +- tests/ops/test_stats.py | 81 +- tests/ops/test_tensor_utils.py | 74 +- tests/ops/test_welford.py | 32 +- tests/optim/test_multi.py | 53 +- tests/optim/test_optim.py | 184 +- tests/params/test_module.py | 15 +- tests/params/test_param.py | 99 +- tests/perf/test_benchmark.py | 91 +- tests/poutine/test_counterfactual.py | 56 +- tests/poutine/test_mapdata.py | 82 +- tests/poutine/test_nesting.py | 1 - tests/poutine/test_poutines.py | 453 +-- tests/poutine/test_properties.py | 91 +- tests/poutine/test_trace_struct.py | 14 +- tests/pyroapi/test_pyroapi.py | 2 +- tests/test_examples.py | 574 ++-- tests/test_generic.py | 42 +- tests/test_primitives.py | 9 +- tests/test_util.py | 22 +- tutorial/source/cleannb.py | 4 +- tutorial/source/conf.py | 76 +- tutorial/source/search_inference.py | 54 +- 503 files changed, 30393 insertions(+), 17202 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 45da605482..879c1759bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,8 +31,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip wheel setuptools - pip install flake8 isort>=5.0 mypy nbstripout nbformat - - name: Lint with flake8 + pip install flake8 black isort>=5.0 mypy nbstripout nbformat + - name: Lint run: | make lint docs: diff --git a/Makefile b/Makefile index 9500b5f833..0f9197fa6c 
100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,7 @@ tutorial: FORCE lint: FORCE flake8 + black --check . isort --check . python scripts/update_headers.py --check mypy pyro @@ -29,6 +30,7 @@ license: FORCE python scripts/update_headers.py format: license FORCE + black . isort . version: FORCE diff --git a/docs/source/conf.py b/docs/source/conf.py index 78463717a4..1a984731e9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- General configuration ------------------------------------------------ @@ -38,15 +38,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.intersphinx', # - 'sphinx.ext.todo', # - 'sphinx.ext.mathjax', # - 'sphinx.ext.ifconfig', # - 'sphinx.ext.viewcode', # - 'sphinx.ext.githubpages', # - 'sphinx.ext.graphviz', # - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', + "sphinx.ext.intersphinx", # + "sphinx.ext.todo", # + "sphinx.ext.mathjax", # + "sphinx.ext.ifconfig", # + "sphinx.ext.viewcode", # + "sphinx.ext.githubpages", # + "sphinx.ext.graphviz", # + "sphinx.ext.autodoc", + "sphinx.ext.doctest", ] # Disable documentation inheritance so as to avoid inheriting @@ -56,31 +56,32 @@ autodoc_inherit_docstrings = False # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Pyro' -copyright = u'2017-2018, Uber Technologies, Inc' -author = u'Uber AI Labs' +project = u"Pyro" +copyright = u"2017-2018, Uber Technologies, Inc" +author = u"Uber AI Labs" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. -version = '' +version = "" -if 'READTHEDOCS' not in os.environ: +if "READTHEDOCS" not in os.environ: # if developing locally, use pyro.__version__ as version from pyro import __version__ # noqaE402 + version = __version__ # release version @@ -99,7 +100,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -110,10 +111,10 @@ # -- Options for HTML output ---------------------------------------------- # logo -html_logo = '_static/img/pyro_logo_wide.png' +html_logo = "_static/img/pyro_logo_wide.png" # logo -html_favicon = '_static/img/favicon/favicon.ico' +html_favicon = "_static/img/favicon/favicon.ico" # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. @@ -126,20 +127,20 @@ # documentation. html_theme_options = { - 'navigation_depth': 3, - 'logo_only': True, + "navigation_depth": 3, + "logo_only": True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_style = 'css/pyro.css' +html_static_path = ["_static"] +html_style = "css/pyro.css" # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'Pyrodoc' +htmlhelp_basename = "Pyrodoc" # -- Options for LaTeX output --------------------------------------------- @@ -147,15 +148,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -165,14 +163,14 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Pyro.tex', u'Pyro Documentation', u'Uber AI Labs', 'manual'), + (master_doc, "Pyro.tex", u"Pyro Documentation", u"Uber AI Labs", "manual"), ] # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, 'pyro', u'Pyro Documentation', [author], 1)] +man_pages = [(master_doc, "pyro", u"Pyro Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -180,19 +178,26 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Pyro', u'Pyro Documentation', author, 'Pyro', - 'Deep Universal Probabilistic Programming.', 'Miscellaneous'), + ( + master_doc, + "Pyro", + u"Pyro Documentation", + author, + "Pyro", + "Deep Universal Probabilistic Programming.", + "Miscellaneous", + ), ] # Example configuration for intersphinx: refer to the Python standard library. 
intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'torch': ('https://pytorch.org/docs/master/', None), - 'funsor': ('http://funsor.pyro.ai/en/stable/', None), - 'opt_einsum': ('https://optimized-einsum.readthedocs.io/en/stable/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), - 'Bio': ('https://biopython.readthedocs.io/en/latest/', None), - 'horovod': ('https://horovod.readthedocs.io/en/stable/', None), + "python": ("https://docs.python.org/3/", None), + "torch": ("https://pytorch.org/docs/master/", None), + "funsor": ("http://funsor.pyro.ai/en/stable/", None), + "opt_einsum": ("https://optimized-einsum.readthedocs.io/en/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), + "Bio": ("https://biopython.readthedocs.io/en/latest/", None), + "horovod": ("https://horovod.readthedocs.io/en/stable/", None), } # document class constructors (__init__ methods): @@ -205,13 +210,17 @@ def skip(app, what, name, obj, skip, options): def setup(app): - app.add_css_file('css/pyro.css') + app.add_css_file("css/pyro.css") + + # app.connect("autodoc-skip-member", skip) # @jpchen's hack to get rtd builder to install latest pytorch # See similar line in the install section of .travis.yml -if 'READTHEDOCS' in os.environ: - os.system('pip install numpy') - os.system('pip install torch==1.9.0+cpu torchvision==0.10.0+cpu ' - '-f https://download.pytorch.org/whl/torch_stable.html') +if "READTHEDOCS" in os.environ: + os.system("pip install numpy") + os.system( + "pip install torch==1.9.0+cpu torchvision==0.10.0+cpu " + "-f https://download.pytorch.org/whl/torch_stable.html" + ) diff --git a/examples/air/air.py b/examples/air/air.py index 985e2853bf..f0c0e07118 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -25,34 +25,38 @@ def default_z_pres_prior_p(t): return 0.5 -ModelState = namedtuple('ModelState', ['x', 'z_pres', 'z_where']) -GuideState = namedtuple('GuideState', ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what']) +ModelState = namedtuple("ModelState", ["x", "z_pres", "z_where"]) +GuideState = namedtuple( + "GuideState", ["h", "c", "bl_h", "bl_c", "z_pres", "z_where", "z_what"] +) class AIR(nn.Module): - def __init__(self, - num_steps, - x_size, - window_size, - z_what_size, - rnn_hidden_size, - encoder_net=[], - decoder_net=[], - predict_net=[], - embed_net=None, - bl_predict_net=[], - non_linearity='ReLU', - decoder_output_bias=None, - decoder_output_use_sigmoid=False, - use_masking=True, - use_baselines=True, - baseline_scalar=None, - scale_prior_mean=3.0, - scale_prior_sd=0.1, - pos_prior_mean=0.0, - pos_prior_sd=1.0, - likelihood_sd=0.3, - use_cuda=False): + def __init__( + self, + num_steps, + x_size, + window_size, + z_what_size, + rnn_hidden_size, + encoder_net=[], + decoder_net=[], + predict_net=[], + embed_net=None, + bl_predict_net=[], + non_linearity="ReLU", + decoder_output_bias=None, + decoder_output_use_sigmoid=False, + use_masking=True, + use_baselines=True, + baseline_scalar=None, + scale_prior_mean=3.0, + scale_prior_sd=0.1, + pos_prior_mean=0.0, + pos_prior_sd=1.0, + likelihood_sd=0.3, + use_cuda=False, + ): super().__init__() @@ -66,7 +70,7 @@ def __init__(self, self.baseline_scalar = baseline_scalar self.likelihood_sd = likelihood_sd self.use_cuda = use_cuda - prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.) 
+ prototype = torch.tensor(0.0).cuda() if use_cuda else torch.tensor(0.0) self.options = dict(dtype=prototype.dtype, device=prototype.device) self.z_pres_size = 1 @@ -76,10 +80,12 @@ def __init__(self, # optimization.) self.z_where_loc_prior = nn.Parameter( torch.FloatTensor([scale_prior_mean, pos_prior_mean, pos_prior_mean]), - requires_grad=False) + requires_grad=False, + ) self.z_where_scale_prior = nn.Parameter( torch.FloatTensor([scale_prior_sd, pos_prior_sd, pos_prior_sd]), - requires_grad=False) + requires_grad=False, + ) # Create nn modules. rnn_input_size = x_size ** 2 if embed_net is None else embed_net[-1] @@ -88,14 +94,26 @@ def __init__(self, self.rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size) self.encode = Encoder(window_size ** 2, encoder_net, z_what_size, nl) - self.decode = Decoder(window_size ** 2, decoder_net, z_what_size, - decoder_output_bias, decoder_output_use_sigmoid, nl) - self.predict = Predict(rnn_hidden_size, predict_net, self.z_pres_size, self.z_where_size, nl) - self.embed = Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True) + self.decode = Decoder( + window_size ** 2, + decoder_net, + z_what_size, + decoder_output_bias, + decoder_output_use_sigmoid, + nl, + ) + self.predict = Predict( + rnn_hidden_size, predict_net, self.z_pres_size, self.z_where_size, nl + ) + self.embed = ( + Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True) + ) self.bl_rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size) self.bl_predict = MLP(rnn_hidden_size, bl_predict_net + [1], nl) - self.bl_embed = Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True) + self.bl_embed = ( + Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True) + ) # Create parameters. self.h_init = nn.Parameter(torch.zeros(1, rnn_hidden_size)) @@ -113,7 +131,8 @@ def prior(self, n, **kwargs): state = ModelState( x=torch.zeros(n, self.x_size, self.x_size, **self.options), z_pres=torch.ones(n, self.z_pres_size, **self.options), - z_where=None) + z_where=None, + ) z_pres = [] z_where = [] @@ -128,9 +147,10 @@ def prior(self, n, **kwargs): def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): # Sample presence indicators. - z_pres = pyro.sample('z_pres_{}'.format(t), - dist.Bernoulli(z_pres_prior_p(t) * prev.z_pres) - .to_event(1)) + z_pres = pyro.sample( + "z_pres_{}".format(t), + dist.Bernoulli(z_pres_prior_p(t) * prev.z_pres).to_event(1), + ) # If zero is sampled for a data point, then no more objects # will be added to its output image. We can't @@ -139,18 +159,26 @@ def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): sample_mask = z_pres if self.use_masking else torch.tensor(1.0) # Sample attention window position. - z_where = pyro.sample('z_where_{}'.format(t), - dist.Normal(self.z_where_loc_prior.expand(n, self.z_where_size), - self.z_where_scale_prior.expand(n, self.z_where_size)) - .mask(sample_mask) - .to_event(1)) + z_where = pyro.sample( + "z_where_{}".format(t), + dist.Normal( + self.z_where_loc_prior.expand(n, self.z_where_size), + self.z_where_scale_prior.expand(n, self.z_where_size), + ) + .mask(sample_mask) + .to_event(1), + ) # Sample latent code for contents of the attention window. 
- z_what = pyro.sample('z_what_{}'.format(t), - dist.Normal(torch.zeros(n, self.z_what_size, **self.options), - torch.ones(n, self.z_what_size, **self.options)) - .mask(sample_mask) - .to_event(1)) + z_what = pyro.sample( + "z_what_{}".format(t), + dist.Normal( + torch.zeros(n, self.z_what_size, **self.options), + torch.ones(n, self.z_what_size, **self.options), + ) + .mask(sample_mask) + .to_event(1), + ) # Map latent code to pixel space. y_att = self.decode(z_what) @@ -167,42 +195,50 @@ def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): def model(self, data, batch_size, **kwargs): pyro.module("decode", self.decode) - with pyro.plate('data', data.size(0), device=data.device) as ix: + with pyro.plate("data", data.size(0), device=data.device) as ix: batch = data[ix] n = batch.size(0) (z_where, z_pres), x = self.prior(n, **kwargs) - pyro.sample('obs', - dist.Normal(x.view(n, -1), - (self.likelihood_sd * torch.ones(n, self.x_size ** 2, **self.options))) - .to_event(1), - obs=batch.view(n, -1)) + pyro.sample( + "obs", + dist.Normal( + x.view(n, -1), + ( + self.likelihood_sd + * torch.ones(n, self.x_size ** 2, **self.options) + ), + ).to_event(1), + obs=batch.view(n, -1), + ) def guide(self, data, batch_size, **kwargs): - pyro.module('rnn', self.rnn), - pyro.module('predict', self.predict), - pyro.module('encode', self.encode), - pyro.module('embed', self.embed), - pyro.module('bl_rnn', self.bl_rnn), - pyro.module('bl_predict', self.bl_predict), - pyro.module('bl_embed', self.bl_embed) - - pyro.param('h_init', self.h_init) - pyro.param('c_init', self.c_init) - pyro.param('z_where_init', self.z_where_init) - pyro.param('z_what_init', self.z_what_init) - pyro.param('bl_h_init', self.bl_h_init) - pyro.param('bl_c_init', self.bl_c_init) - - with pyro.plate('data', data.size(0), subsample_size=batch_size, device=data.device) as ix: + pyro.module("rnn", self.rnn), + pyro.module("predict", self.predict), + pyro.module("encode", self.encode), + pyro.module("embed", self.embed), + pyro.module("bl_rnn", self.bl_rnn), + pyro.module("bl_predict", self.bl_predict), + pyro.module("bl_embed", self.bl_embed) + + pyro.param("h_init", self.h_init) + pyro.param("c_init", self.c_init) + pyro.param("z_where_init", self.z_where_init) + pyro.param("z_what_init", self.z_what_init) + pyro.param("bl_h_init", self.bl_h_init) + pyro.param("bl_c_init", self.bl_c_init) + + with pyro.plate( + "data", data.size(0), subsample_size=batch_size, device=data.device + ) as ix: batch = data[ix] n = batch.size(0) # Embed inputs. flattened_batch = batch.view(n, -1) inputs = { - 'raw': batch, - 'embed': self.embed(flattened_batch), - 'bl_embed': self.bl_embed(flattened_batch) + "raw": batch, + "embed": self.embed(flattened_batch), + "bl_embed": self.bl_embed(flattened_batch), } # Initial state. 
@@ -213,7 +249,8 @@ def guide(self, data, batch_size, **kwargs): bl_c=batch_expand(self.bl_c_init, n), z_pres=torch.ones(n, self.z_pres_size, **self.options), z_where=batch_expand(self.z_where_init, n), - z_what=batch_expand(self.z_what_init, n)) + z_what=batch_expand(self.z_what_init, n), + ) z_pres = [] z_where = [] @@ -227,7 +264,9 @@ def guide(self, data, batch_size, **kwargs): def guide_step(self, t, n, prev, inputs): - rnn_input = torch.cat((inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1) + rnn_input = torch.cat( + (inputs["embed"], prev.z_where, prev.z_what, prev.z_pres), 1 + ) h, c = self.rnn(rnn_input, (prev.h, prev.c)) z_pres_p, z_where_loc, z_where_scale = self.predict(h) @@ -235,31 +274,45 @@ def guide_step(self, t, n, prev, inputs): infer_dict, bl_h, bl_c = self.baseline_step(prev, inputs) # Sample presence. - z_pres = pyro.sample('z_pres_{}'.format(t), - dist.Bernoulli(z_pres_p * prev.z_pres).to_event(1), - infer=infer_dict) + z_pres = pyro.sample( + "z_pres_{}".format(t), + dist.Bernoulli(z_pres_p * prev.z_pres).to_event(1), + infer=infer_dict, + ) sample_mask = z_pres if self.use_masking else torch.tensor(1.0) - z_where = pyro.sample('z_where_{}'.format(t), - dist.Normal(z_where_loc + self.z_where_loc_prior, - z_where_scale * self.z_where_scale_prior) - .mask(sample_mask) - .to_event(1)) + z_where = pyro.sample( + "z_where_{}".format(t), + dist.Normal( + z_where_loc + self.z_where_loc_prior, + z_where_scale * self.z_where_scale_prior, + ) + .mask(sample_mask) + .to_event(1), + ) # Figure 2 of [1] shows x_att depending on z_where and h, # rather than z_where and x as here, but I think this is # correct. - x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw']) + x_att = image_to_window(z_where, self.window_size, self.x_size, inputs["raw"]) # Encode attention windows. z_what_loc, z_what_scale = self.encode(x_att) - z_what = pyro.sample('z_what_{}'.format(t), - dist.Normal(z_what_loc, z_what_scale) - .mask(sample_mask) - .to_event(1)) - return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what) + z_what = pyro.sample( + "z_what_{}".format(t), + dist.Normal(z_what_loc, z_what_scale).mask(sample_mask).to_event(1), + ) + return GuideState( + h=h, + c=c, + bl_h=bl_h, + bl_c=bl_c, + z_pres=z_pres, + z_where=z_where, + z_what=z_what, + ) def baseline_step(self, prev, inputs): if not self.use_baselines: @@ -267,10 +320,15 @@ def baseline_step(self, prev, inputs): # Prevent gradients flowing back from baseline loss to # inference net by detaching from graph here. - rnn_input = torch.cat((inputs['bl_embed'], - prev.z_where.detach(), - prev.z_what.detach(), - prev.z_pres.detach()), 1) + rnn_input = torch.cat( + ( + inputs["bl_embed"], + prev.z_where.detach(), + prev.z_what.detach(), + prev.z_pres.detach(), + ), + 1, + ) bl_h, bl_c = self.bl_rnn(rnn_input, (prev.bl_h, prev.bl_c)) bl_value = self.bl_predict(bl_h) @@ -327,7 +385,7 @@ def z_where_inv(z_where): def window_to_image(z_where, window_size, image_size, windows): n = windows.size(0) - assert windows.size(1) == window_size ** 2, 'Size mismatch.' + assert windows.size(1) == window_size ** 2, "Size mismatch." 
theta = expand_z_where(z_where) grid = F.affine_grid(theta, torch.Size((n, 1, image_size, image_size))) out = F.grid_sample(windows.view(n, 1, window_size, window_size), grid) @@ -336,7 +394,7 @@ def window_to_image(z_where, window_size, image_size, windows): def image_to_window(z_where, window_size, image_size, images): n = images.size(0) - assert images.size(1) == images.size(2) == image_size, 'Size mismatch.' + assert images.size(1) == images.size(2) == image_size, "Size mismatch." theta_inv = expand_z_where(z_where_inv(z_where)) grid = F.affine_grid(theta_inv, torch.Size((n, 1, window_size, window_size))) out = F.grid_sample(images.view(n, 1, image_size, image_size), grid) @@ -354,6 +412,9 @@ def batch_expand(t, n): # a single tensor, with size: # [batch_size, num_steps, z_where_size + z_pres_size] def latents_to_tensor(z): - return torch.stack([ - torch.cat((z_where.cpu().data, z_pres.cpu().data), 1) - for z_where, z_pres in zip(*z)]).transpose(0, 1) + return torch.stack( + [ + torch.cat((z_where.cpu().data, z_pres.cpu().data), 1) + for z_where, z_pres in zip(*z) + ] + ).transpose(0, 1) diff --git a/examples/air/main.py b/examples/air/main.py index 6a8736e12c..53f0484c40 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -31,8 +31,8 @@ def count_accuracy(X, true_counts, air, batch_size): - assert X.size(0) == true_counts.size(0), 'Size mismatch.' - assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.' + assert X.size(0) == true_counts.size(0), "Size mismatch." + assert X.size(0) % batch_size == 0, "Input size must be multiple of batch_size." counts = torch.LongTensor(3, 4).zero_() error_latents = [] error_indicators = [] @@ -43,8 +43,8 @@ def count_vec_to_mat(vec, max_index): return out for i in range(X.size(0) // batch_size): - X_batch = X[i * batch_size:(i + 1) * batch_size] - true_counts_batch = true_counts[i * batch_size:(i + 1) * batch_size] + X_batch = X[i * batch_size : (i + 1) * batch_size] + true_counts_batch = true_counts[i * batch_size : (i + 1) * batch_size] z_where, z_pres = air.guide(X_batch, batch_size) inferred_counts = sum(z.cpu() for z in z_pres).squeeze().data true_counts_m = count_vec_to_mat(true_counts_batch, 2) @@ -52,7 +52,9 @@ def count_vec_to_mat(vec, max_index): counts += torch.mm(true_counts_m.t(), inferred_counts_m) error_ind = 1 - (true_counts_batch == inferred_counts) error_ix = error_ind.nonzero(as_tuple=False).squeeze() - error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix)) + error_latents.append( + latents_to_tensor((z_where, z_pres)).index_select(0, error_ix) + ) error_indicators.append(error_ind) acc = counts.diag().sum().float() / X.size(0) @@ -67,10 +69,10 @@ def count_vec_to_mat(vec, max_index): # between p(steps=n) and p(steps=n+1). def make_prior(k): assert 0 < k <= 1 - u = 1 / (1 + k + k**2 + k**3) + u = 1 / (1 + k + k ** 2 + k ** 3) p0 = 1 - u p1 = 1 - (k * u) / p0 - p2 = 1 - (k**2 * u) / (p0 * p1) + p2 = 1 - (k ** 2 * u) / (p0 * p1) trial_probs = [p0, p1, p2] # dist = [1 - p0, p0 * (1 - p1), p0 * p1 * (1 - p2), p0 * p1 * p2] # print(dist) @@ -129,7 +131,7 @@ def main(**kwargs): args = argparse.Namespace(**kwargs) - if 'save' in args: + if "save" in args: if os.path.exists(args.save): raise RuntimeError('Output file "{}" already exists.'.format(args.save)) @@ -143,36 +145,45 @@ def main(**kwargs): # Build a function to compute z_pres prior probabilities. 
if args.z_pres_prior_raw: + def base_z_pres_prior_p(t): return args.z_pres_prior + else: base_z_pres_prior_p = make_prior(args.z_pres_prior) # Wrap with logic to apply any annealing. def z_pres_prior_p(opt_step, time_step): p = base_z_pres_prior_p(time_step) - if args.anneal_prior == 'none': + if args.anneal_prior == "none": return p else: decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior] - return decay(p, args.anneal_prior_to, args.anneal_prior_begin, - args.anneal_prior_duration, opt_step) - - model_arg_keys = ['window_size', - 'rnn_hidden_size', - 'decoder_output_bias', - 'decoder_output_use_sigmoid', - 'baseline_scalar', - 'encoder_net', - 'decoder_net', - 'predict_net', - 'embed_net', - 'bl_predict_net', - 'non_linearity', - 'pos_prior_mean', - 'pos_prior_sd', - 'scale_prior_mean', - 'scale_prior_sd'] + return decay( + p, + args.anneal_prior_to, + args.anneal_prior_begin, + args.anneal_prior_duration, + opt_step, + ) + + model_arg_keys = [ + "window_size", + "rnn_hidden_size", + "decoder_output_bias", + "decoder_output_use_sigmoid", + "baseline_scalar", + "encoder_net", + "decoder_net", + "predict_net", + "embed_net", + "bl_predict_net", + "non_linearity", + "pos_prior_mean", + "pos_prior_sd", + "scale_prior_mean", + "scale_prior_sd", + ] model_args = {key: getattr(args, key) for key in model_arg_keys if key in args} air = AIR( num_steps=args.model_steps, @@ -188,8 +199,8 @@ def z_pres_prior_p(opt_step, time_step): print(air) print(args) - if 'load' in args: - print('Loading parameters...') + if "load" in args: + print("Loading parameters...") air.load_state_dict(torch.load(args.load)) # Viz sample from prior. @@ -199,11 +210,15 @@ def z_pres_prior_p(opt_step, time_step): vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z)))) def isBaselineParam(param_name): - return 'bl_' in param_name + return "bl_" in param_name def per_param_optim_args(param_name): - lr = args.baseline_learning_rate if isBaselineParam(param_name) else args.learning_rate - return {'lr': lr} + lr = ( + args.baseline_learning_rate + if isBaselineParam(param_name) + else args.learning_rate + ) + return {"lr": lr} adam = optim.Adam(per_param_optim_args) elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO() @@ -215,14 +230,19 @@ def per_param_optim_args(param_name): for i in range(1, args.num_steps + 1): - loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)) + loss = svi.step( + X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i) + ) if args.progress_every > 0 and i % args.progress_every == 0: - print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format( - i, - (i * args.batch_size) / X_size, - (time.time() - t0) / 3600, - loss / X_size)) + print( + "i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}".format( + i, + (i * args.batch_size) / X_size, + (time.time() - t0) / 3600, + loss / X_size, + ) + ) if args.viz and i % args.viz_every == 0: trace = poutine.trace(air.guide).get_trace(examples_to_viz, None) @@ -237,98 +257,199 @@ def per_param_optim_args(param_name): if args.eval_every > 0 and i % args.eval_every == 0: # Measure accuracy on subset of training data. 
acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000) - print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist())) + print( + "i={}, accuracy={}, counts={}".format(i, acc, counts.numpy().tolist()) + ) if args.viz and error_ix.size(0) > 0: - vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])), - opts=dict(caption='errors ({})'.format(i))) + vis.images( + draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])), + opts=dict(caption="errors ({})".format(i)), + ) - if 'save' in args and i % args.save_every == 0: - print('Saving parameters...') + if "save" in args and i % args.save_every == 0: + print("Saving parameters...") torch.save(air.state_dict(), args.save) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS) - parser.add_argument('-n', '--num-steps', type=int, default=int(1e8), - help='number of optimization steps to take') - parser.add_argument('-b', '--batch-size', type=int, default=64, - help='batch size') - parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, - help='learning rate') - parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3, - help='baseline learning rate') - parser.add_argument('--progress-every', type=int, default=1, - help='number of steps between writing progress to stdout') - parser.add_argument('--eval-every', type=int, default=0, - help='number of steps between evaluations') - parser.add_argument('--baseline-scalar', type=float, - help='scale the output of the baseline nets by this value') - parser.add_argument('--no-baselines', action='store_true', default=False, - help='do not use data dependent baselines') - parser.add_argument('--encoder-net', type=int, nargs='+', default=[200], - help='encoder net hidden layer sizes') - parser.add_argument('--decoder-net', type=int, nargs='+', default=[200], - help='decoder net hidden layer sizes') - parser.add_argument('--predict-net', type=int, nargs='+', - help='predict net hidden layer sizes') - parser.add_argument('--embed-net', type=int, nargs='+', - help='embed net architecture') - parser.add_argument('--bl-predict-net', type=int, nargs='+', - help='baseline predict net hidden layer sizes') - parser.add_argument('--non-linearity', type=str, - help='non linearity to use throughout') - parser.add_argument('--viz', action='store_true', default=False, - help='generate vizualizations during optimization') - parser.add_argument('--viz-every', type=int, default=100, - help='number of steps between vizualizations') - parser.add_argument('--visdom-env', default='main', - help='visdom enviroment name') - parser.add_argument('--load', type=str, - help='load previously saved parameters') - parser.add_argument('--save', type=str, - help='save parameters to specified file') - parser.add_argument('--save-every', type=int, default=1e4, - help='number of steps between parameter saves') - parser.add_argument('--cuda', action='store_true', default=False, - help='use cuda') - parser.add_argument('--jit', action='store_true', default=False, - help='use PyTorch jit') - parser.add_argument('-t', '--model-steps', type=int, default=3, - help='number of time steps') - parser.add_argument('--rnn-hidden-size', type=int, default=256, - help='rnn hidden size') - parser.add_argument('--encoder-latent-size', type=int, default=50, - help='attention window encoder/decoder latent space size') - 
parser.add_argument('--decoder-output-bias', type=float, - help='bias added to decoder output (prior to applying non-linearity)') - parser.add_argument('--decoder-output-use-sigmoid', action='store_true', - help='apply sigmoid function to output of decoder network') - parser.add_argument('--window-size', type=int, default=28, - help='attention window size') - parser.add_argument('--z-pres-prior', type=float, default=0.5, - help='prior success probability for z_pres') - parser.add_argument('--z-pres-prior-raw', action='store_true', default=False, - help='use --z-pres-prior directly as success prob instead of a geometric like prior') - parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none', - help='anneal z_pres prior during optimization') - parser.add_argument('--anneal-prior-to', type=float, default=1e-7, - help='target z_pres prior prob') - parser.add_argument('--anneal-prior-begin', type=int, default=0, - help='number of steps to wait before beginning to anneal the prior') - parser.add_argument('--anneal-prior-duration', type=int, default=100000, - help='number of steps over which to anneal the prior') - parser.add_argument('--pos-prior-mean', type=float, - help='mean of the window position prior') - parser.add_argument('--pos-prior-sd', type=float, - help='std. dev. of the window position prior') - parser.add_argument('--scale-prior-mean', type=float, - help='mean of the window scale prior') - parser.add_argument('--scale-prior-sd', type=float, - help='std. dev. of the window scale prior') - parser.add_argument('--no-masking', action='store_true', default=False, - help='do not mask out the costs of unused choices') - parser.add_argument('--seed', type=int, help='random seed', default=None) - parser.add_argument('-v', '--verbose', action='store_true', default=False, - help='write hyper parameters and network architecture to stdout') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="Pyro AIR example", argument_default=argparse.SUPPRESS + ) + parser.add_argument( + "-n", + "--num-steps", + type=int, + default=int(1e8), + help="number of optimization steps to take", + ) + parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size") + parser.add_argument( + "-lr", "--learning-rate", type=float, default=1e-4, help="learning rate" + ) + parser.add_argument( + "-blr", + "--baseline-learning-rate", + type=float, + default=1e-3, + help="baseline learning rate", + ) + parser.add_argument( + "--progress-every", + type=int, + default=1, + help="number of steps between writing progress to stdout", + ) + parser.add_argument( + "--eval-every", type=int, default=0, help="number of steps between evaluations" + ) + parser.add_argument( + "--baseline-scalar", + type=float, + help="scale the output of the baseline nets by this value", + ) + parser.add_argument( + "--no-baselines", + action="store_true", + default=False, + help="do not use data dependent baselines", + ) + parser.add_argument( + "--encoder-net", + type=int, + nargs="+", + default=[200], + help="encoder net hidden layer sizes", + ) + parser.add_argument( + "--decoder-net", + type=int, + nargs="+", + default=[200], + help="decoder net hidden layer sizes", + ) + parser.add_argument( + "--predict-net", type=int, nargs="+", help="predict net hidden layer sizes" + ) + parser.add_argument( + "--embed-net", type=int, nargs="+", help="embed net architecture" + ) + parser.add_argument( + "--bl-predict-net", + type=int, + 
nargs="+", + help="baseline predict net hidden layer sizes", + ) + parser.add_argument( + "--non-linearity", type=str, help="non linearity to use throughout" + ) + parser.add_argument( + "--viz", + action="store_true", + default=False, + help="generate vizualizations during optimization", + ) + parser.add_argument( + "--viz-every", + type=int, + default=100, + help="number of steps between vizualizations", + ) + parser.add_argument("--visdom-env", default="main", help="visdom enviroment name") + parser.add_argument("--load", type=str, help="load previously saved parameters") + parser.add_argument("--save", type=str, help="save parameters to specified file") + parser.add_argument( + "--save-every", + type=int, + default=1e4, + help="number of steps between parameter saves", + ) + parser.add_argument("--cuda", action="store_true", default=False, help="use cuda") + parser.add_argument( + "--jit", action="store_true", default=False, help="use PyTorch jit" + ) + parser.add_argument( + "-t", "--model-steps", type=int, default=3, help="number of time steps" + ) + parser.add_argument( + "--rnn-hidden-size", type=int, default=256, help="rnn hidden size" + ) + parser.add_argument( + "--encoder-latent-size", + type=int, + default=50, + help="attention window encoder/decoder latent space size", + ) + parser.add_argument( + "--decoder-output-bias", + type=float, + help="bias added to decoder output (prior to applying non-linearity)", + ) + parser.add_argument( + "--decoder-output-use-sigmoid", + action="store_true", + help="apply sigmoid function to output of decoder network", + ) + parser.add_argument( + "--window-size", type=int, default=28, help="attention window size" + ) + parser.add_argument( + "--z-pres-prior", + type=float, + default=0.5, + help="prior success probability for z_pres", + ) + parser.add_argument( + "--z-pres-prior-raw", + action="store_true", + default=False, + help="use --z-pres-prior directly as success prob instead of a geometric like prior", + ) + parser.add_argument( + "--anneal-prior", + choices="none lin exp".split(), + default="none", + help="anneal z_pres prior during optimization", + ) + parser.add_argument( + "--anneal-prior-to", type=float, default=1e-7, help="target z_pres prior prob" + ) + parser.add_argument( + "--anneal-prior-begin", + type=int, + default=0, + help="number of steps to wait before beginning to anneal the prior", + ) + parser.add_argument( + "--anneal-prior-duration", + type=int, + default=100000, + help="number of steps over which to anneal the prior", + ) + parser.add_argument( + "--pos-prior-mean", type=float, help="mean of the window position prior" + ) + parser.add_argument( + "--pos-prior-sd", type=float, help="std. dev. of the window position prior" + ) + parser.add_argument( + "--scale-prior-mean", type=float, help="mean of the window scale prior" + ) + parser.add_argument( + "--scale-prior-sd", type=float, help="std. dev. 
of the window scale prior" + ) + parser.add_argument( + "--no-masking", + action="store_true", + default=False, + help="do not mask out the costs of unused choices", + ) + parser.add_argument("--seed", type=int, help="random seed", default=None) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="write hyper parameters and network architecture to stdout", + ) main(**vars(parser.parse_args())) diff --git a/examples/air/modules.py b/examples/air/modules.py index 084ff91d46..bb57c52cd4 100644 --- a/examples/air/modules.py +++ b/examples/air/modules.py @@ -18,7 +18,7 @@ def __init__(self, x_size, h_sizes, z_size, non_linear_layer): def forward(self, x): a = self.mlp(x) - return a[:, 0:self.z_size], softplus(a[:, self.z_size:]) + return a[:, 0 : self.z_size], softplus(a[:, self.z_size :]) # Takes a latent code, z_what, to pixel intensities. @@ -42,7 +42,9 @@ def forward(self, z): # [Linear (256 -> 256), ReLU (), Linear (256 -> 1), ReLU ()] # etc. class MLP(nn.Module): - def __init__(self, in_size, out_sizes, non_linear_layer, output_non_linearity=False): + def __init__( + self, in_size, out_sizes, non_linear_layer, output_non_linearity=False + ): super().__init__() assert len(out_sizes) >= 1 layers = [] @@ -63,7 +65,9 @@ def forward(self, x): # Takes the guide RNN hidden state to parameters of the guide # distributions over z_where and z_pres. class Predict(nn.Module): - def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer): + def __init__( + self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer + ): super().__init__() self.z_pres_size = z_pres_size self.z_where_size = z_where_size @@ -72,9 +76,9 @@ def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_la def forward(self, h): out = self.mlp(h) - z_pres_p = torch.sigmoid(out[:, 0:self.z_pres_size]) - z_where_loc = out[:, self.z_pres_size:self.z_pres_size + self.z_where_size] - z_where_scale = softplus(out[:, (self.z_pres_size + self.z_where_size):]) + z_pres_p = torch.sigmoid(out[:, 0 : self.z_pres_size]) + z_where_loc = out[:, self.z_pres_size : self.z_pres_size + self.z_where_size] + z_where_scale = softplus(out[:, (self.z_pres_size + self.z_where_size) :]) return z_pres_p, z_where_loc, z_where_scale diff --git a/examples/air/viz.py b/examples/air/viz.py index ca40237832..938003515a 100644 --- a/examples/air/viz.py +++ b/examples/air/viz.py @@ -13,8 +13,8 @@ def bounding_box(z_where, x_size): enough to be usable.""" w = x_size / z_where.s h = x_size / z_where.s - xtrans = -z_where.x / z_where.s * x_size / 2. - ytrans = -z_where.y / z_where.s * x_size / 2. 
+    xtrans = -z_where.x / z_where.s * x_size / 2.0
+    ytrans = -z_where.y / z_where.s * x_size / 2.0
     x = (x_size - w) / 2 + xtrans  # origin is top left
     y = (x_size - h) / 2 + ytrans
     return (x, y), w, h
@@ -22,13 +22,17 @@ def bounding_box(z_where, x_size):
 
 def arr2img(arr):
     # arr is expected to be a 2d array of floats in [0,1]
-    return Image.frombuffer('L', arr.shape, (arr * 255).astype(np.uint8).tostring(), 'raw', 'L', 0, 1)
+    return Image.frombuffer(
+        "L", arr.shape, (arr * 255).astype(np.uint8).tostring(), "raw", "L", 0, 1
+    )
 
 
 def img2arr(img):
     # assumes color image
     # returns an array suitable for sending to visdom
-    return np.array(img.getdata(), np.uint8).reshape(img.size + (3,)).transpose((2, 0, 1))
+    return (
+        np.array(img.getdata(), np.uint8).reshape(img.size + (3,)).transpose((2, 0, 1))
+    )
 
 
 def colors(k):
@@ -40,7 +44,7 @@ def draw_one(imgarr, z_arr):
     # misleading, as it incorrectly suggests objects occlude one
     # another.
     clipped = np.clip(imgarr.detach().cpu().numpy(), 0, 1)
-    img = arr2img(clipped).convert('RGB')
+    img = arr2img(clipped).convert("RGB")
     draw = ImageDraw.Draw(img)
     for k, z in enumerate(z_arr):
         # It would be better to use z_pres to change the opacity of
@@ -52,8 +56,8 @@ def draw_one(imgarr, z_arr):
         color = tuple(map(lambda c: int(c * z.pres), colors(k)))
         draw.rectangle([x, y, x + w, y + h], outline=color)
     is_relaxed = any(z.pres != math.floor(z.pres) for z in z_arr)
-    fmtstr = '{:.1f}' if is_relaxed else '{:.0f}'
-    draw.text((0, 0), fmtstr.format(sum(z.pres for z in z_arr)), fill='white')
+    fmtstr = "{:.1f}" if is_relaxed else "{:.0f}"
+    draw.text((0, 0), fmtstr.format(sum(z.pres for z in z_arr)), fill="white")
     return img2arr(img)
 
 
@@ -63,7 +67,7 @@ def draw_many(imgarrs, z_arr):
     return [draw_one(imgarr, z) for (imgarr, z) in zip(imgarrs.cpu(), z_arr)]
 
 
-z_obj = namedtuple('z', 's,x,y,pres')
+z_obj = namedtuple("z", "s,x,y,pres")
 
 
 # Map a tensor of latents (as produced by latents_to_tensor) to a list
diff --git a/examples/baseball.py b/examples/baseball.py
index b2222fd7b7..52db4ef11a 100644
--- a/examples/baseball.py
+++ b/examples/baseball.py
@@ -51,7 +51,7 @@
     path lengths in Hamiltonian Monte Carlo", (https://arxiv.org/abs/1111.4246)
 """
 
-logging.basicConfig(format='%(message)s', level=logging.INFO)
+logging.basicConfig(format="%(message)s", level=logging.INFO)
 
 DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt"
 
@@ -106,7 +106,9 @@ def partially_pooled(at_bats, hits):
     """
     num_players = at_bats.shape[0]
     m = pyro.sample("m", Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)))
-    kappa = pyro.sample("kappa", Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5)))
+    kappa = pyro.sample(
+        "kappa", Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5))
+    )
     with pyro.plate("num_players", num_players):
         phi_prior = Beta(m * kappa, (1 - m) * kappa)
         phi = pyro.sample("phi", phi_prior)
@@ -136,7 +138,14 @@ def partially_pooled_with_logit(at_bats, hits):
 # ===================================
 
 
-def get_summary_table(posterior, sites, player_names, transforms={}, diagnostics=False, group_by_chain=False):
+def get_summary_table(
+    posterior,
+    sites,
+    player_names,
+    transforms={},
+    diagnostics=False,
+    group_by_chain=False,
+):
     """
     Return summarized statistics for each of the ``sites`` in the traces
     corresponding to the approximate posterior.
@@ -149,7 +158,9 @@ def get_summary_table(posterior, sites, player_names, transforms={}, diagnostics if site_name in transforms: marginal_site = transforms[site_name](marginal_site) - site_summary = summary({site_name: marginal_site}, prob=0.5, group_by_chain=group_by_chain)[site_name] + site_summary = summary( + {site_name: marginal_site}, prob=0.5, group_by_chain=group_by_chain + )[site_name] if site_summary["mean"].shape: site_df = pd.DataFrame(site_summary, index=player_names) else: @@ -167,11 +178,19 @@ def train_test_split(pd_dataframe): Validation data - Full season at-bats and hits for each player. """ device = torch.Tensor().device - train_data = torch.tensor(pd_dataframe[["At-Bats", "Hits"]].values, dtype=torch.float, device=device) - test_data = torch.tensor(pd_dataframe[["SeasonAt-Bats", "SeasonHits"]].values, dtype=torch.float, device=device) + train_data = torch.tensor( + pd_dataframe[["At-Bats", "Hits"]].values, dtype=torch.float, device=device + ) + test_data = torch.tensor( + pd_dataframe[["SeasonAt-Bats", "SeasonHits"]].values, + dtype=torch.float, + device=device, + ) first_name = pd_dataframe["FirstName"].values last_name = pd_dataframe["LastName"].values - player_names = [" ".join([first, last]) for first, last in zip(first_name, last_name)] + player_names = [ + " ".join([first, last]) for first, last in zip(first_name, last_name) + ] return train_data, test_data, player_names @@ -193,19 +212,21 @@ def sample_posterior_predictive(model, posterior_samples, baseball_dataset): logging.info("-----------------------------") # set hits=None to convert it from observation node to sample node train_predict = Predictive(model, posterior_samples)(at_bats, None) - train_summary = get_summary_table(train_predict, - sites=["obs"], - player_names=player_names)["obs"] + train_summary = get_summary_table( + train_predict, sites=["obs"], player_names=player_names + )["obs"] train_summary = train_summary.assign(ActualHits=baseball_dataset[["Hits"]].values) logging.info(train_summary) logging.info("\nHit Rate - Season Predictions") logging.info("-----------------------------") with ignore_experimental_warning(): test_predict = Predictive(model, posterior_samples)(at_bats_season, None) - test_summary = get_summary_table(test_predict, - sites=["obs"], - player_names=player_names)["obs"] - test_summary = test_summary.assign(ActualHits=baseball_dataset[["SeasonHits"]].values) + test_summary = get_summary_table( + test_predict, sites=["obs"], player_names=player_names + )["obs"] + test_summary = test_summary.assign( + ActualHits=baseball_dataset[["SeasonHits"]].values + ) logging.info(test_summary) @@ -216,7 +237,9 @@ def evaluate_pointwise_pred_density(model, posterior_samples, baseball_dataset): """ _, test, player_names = train_test_split(baseball_dataset) at_bats_season, hits_season = test[:, 0], test[:, 1] - trace = Predictive(model, posterior_samples).get_vectorized_trace(at_bats_season, hits_season) + trace = Predictive(model, posterior_samples).get_vectorized_trace( + at_bats_season, hits_season + ) # Use LogSumExp trick to evaluate $log(1/num_samples \sum_i p(new_data | \theta^{i})) $, # where $\theta^{i}$ are parameter samples from the model's posterior. trace.compute_log_prob() @@ -238,46 +261,64 @@ def main(args): # (1) Full Pooling Model # In this model, we illustrate how to use MCMC with general potential_fn. 
init_params, potential_fn, transforms, _ = initialize_model( - fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains, - jit_compile=args.jit, skip_jit_warnings=True) + fully_pooled, + model_args=(at_bats, hits), + num_chains=args.num_chains, + jit_compile=args.jit, + skip_jit_warnings=True, + ) nuts_kernel = NUTS(potential_fn=potential_fn) - mcmc = MCMC(nuts_kernel, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps, - num_chains=args.num_chains, - initial_params=init_params, - transforms=transforms) + mcmc = MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + initial_params=init_params, + transforms=transforms, + ) mcmc.run(at_bats, hits) samples_fully_pooled = mcmc.get_samples() logging.info("\nModel: Fully Pooled") logging.info("===================") logging.info("\nphi:") - logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True), - sites=["phi"], - player_names=player_names, - diagnostics=True, - group_by_chain=True)["phi"]) + logging.info( + get_summary_table( + mcmc.get_samples(group_by_chain=True), + sites=["phi"], + player_names=player_names, + diagnostics=True, + group_by_chain=True, + )["phi"] + ) num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values())) logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences)) sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset) - evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset) + evaluate_pointwise_pred_density( + fully_pooled, samples_fully_pooled, baseball_dataset + ) # (2) No Pooling Model nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True) - mcmc = MCMC(nuts_kernel, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps, - num_chains=args.num_chains) + mcmc = MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + ) mcmc.run(at_bats, hits) samples_not_pooled = mcmc.get_samples() logging.info("\nModel: Not Pooled") logging.info("=================") logging.info("\nphi:") - logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True), - sites=["phi"], - player_names=player_names, - diagnostics=True, - group_by_chain=True)["phi"]) + logging.info( + get_summary_table( + mcmc.get_samples(group_by_chain=True), + sites=["phi"], + player_names=player_names, + diagnostics=True, + group_by_chain=True, + )["phi"] + ) num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values())) logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences)) sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset) @@ -285,61 +326,83 @@ def main(args): # (3) Partially Pooled Model nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True) - mcmc = MCMC(nuts_kernel, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps, - num_chains=args.num_chains) + mcmc = MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + ) mcmc.run(at_bats, hits) samples_partially_pooled = mcmc.get_samples() logging.info("\nModel: Partially Pooled") logging.info("=======================") logging.info("\nphi:") - logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True), - sites=["phi"], - player_names=player_names, - diagnostics=True, - group_by_chain=True)["phi"]) + logging.info( + get_summary_table( + 
mcmc.get_samples(group_by_chain=True), + sites=["phi"], + player_names=player_names, + diagnostics=True, + group_by_chain=True, + )["phi"] + ) num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values())) logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences)) - sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset) - evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset) + sample_posterior_predictive( + partially_pooled, samples_partially_pooled, baseball_dataset + ) + evaluate_pointwise_pred_density( + partially_pooled, samples_partially_pooled, baseball_dataset + ) # (4) Partially Pooled with Logit Model - nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True) - mcmc = MCMC(nuts_kernel, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps, - num_chains=args.num_chains) + nuts_kernel = NUTS( + partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True + ) + mcmc = MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + ) mcmc.run(at_bats, hits) samples_partially_pooled_logit = mcmc.get_samples() logging.info("\nModel: Partially Pooled with Logit") logging.info("==================================") logging.info("\nSigmoid(alpha):") - logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True), - sites=["alpha"], - player_names=player_names, - transforms={"alpha": torch.sigmoid}, - diagnostics=True, - group_by_chain=True)["alpha"]) + logging.info( + get_summary_table( + mcmc.get_samples(group_by_chain=True), + sites=["alpha"], + player_names=player_names, + transforms={"alpha": torch.sigmoid}, + diagnostics=True, + group_by_chain=True, + )["alpha"] + ) num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values())) logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences)) - sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit, - baseball_dataset) - evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit, - baseball_dataset) + sample_posterior_predictive( + partially_pooled_with_logit, samples_partially_pooled_logit, baseball_dataset + ) + evaluate_pointwise_pred_density( + partially_pooled_with_logit, samples_partially_pooled_logit, baseball_dataset + ) if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Baseball batting average using HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=200, type=int) - parser.add_argument("--num-chains", nargs='?', default=4, type=int) - parser.add_argument("--warmup-steps", nargs='?', default=100, type=int) - parser.add_argument("--rng_seed", nargs='?', default=0, type=int) - parser.add_argument("--jit", action="store_true", default=False, - help="use PyTorch jit") - parser.add_argument("--cuda", action="store_true", default=False, - help="run this example in GPU") + parser.add_argument("--num-chains", nargs="?", default=4, type=int) + parser.add_argument("--warmup-steps", nargs="?", default=100, type=int) + parser.add_argument("--rng_seed", nargs="?", default=0, type=int) + parser.add_argument( + "--jit", action="store_true", default=False, help="use PyTorch jit" + ) + parser.add_argument( + "--cuda", action="store_true", default=False, help="run this example in GPU" + ) args 
= parser.parse_args() # work around the error "CUDA error: initialization error" diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index fa868899d5..a0923f2500 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -63,14 +63,20 @@ def model_1(capture_history, sex): first_capture_mask = torch.zeros(N).bool() for t in pyro.markov(range(T)): with poutine.mask(mask=first_capture_mask): - mu_z_t = first_capture_mask.float() * phi * z + (1 - first_capture_mask.float()) + mu_z_t = first_capture_mask.float() * phi * z + ( + 1 - first_capture_mask.float() + ) # we use parallel enumeration to exactly sum out # the discrete states z_t. - z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t), - infer={"enumerate": "parallel"}) + z = pyro.sample( + "z_{}".format(t), + dist.Bernoulli(mu_z_t), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z - pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t), - obs=capture_history[:, t]) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(mu_y_t), obs=capture_history[:, t] + ) first_capture_mask |= capture_history[:, t].bool() @@ -91,17 +97,24 @@ def model_2(capture_history, sex): for t in pyro.markov(range(T)): # note that phi_t needs to be outside the plate, since # phi_t is shared across all N individuals - phi_t = pyro.sample("phi_{}".format(t), dist.Uniform(0.0, 1.0)) if t > 0 \ - else 1.0 + phi_t = ( + pyro.sample("phi_{}".format(t), dist.Uniform(0.0, 1.0)) if t > 0 else 1.0 + ) with animals_plate, poutine.mask(mask=first_capture_mask): - mu_z_t = first_capture_mask.float() * phi_t * z + (1 - first_capture_mask.float()) + mu_z_t = first_capture_mask.float() * phi_t * z + ( + 1 - first_capture_mask.float() + ) # we use parallel enumeration to exactly sum out # the discrete states z_t. 
- z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t), - infer={"enumerate": "parallel"}) + z = pyro.sample( + "z_{}".format(t), + dist.Bernoulli(mu_z_t), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z - pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t), - obs=capture_history[:, t]) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(mu_y_t), obs=capture_history[:, t] + ) first_capture_mask |= capture_history[:, t].bool() @@ -115,8 +128,11 @@ def model_2(capture_history, sex): def model_3(capture_history, sex): def logit(p): return torch.log(p) - torch.log1p(-p) + N, T = capture_history.shape - phi_mean = pyro.sample("phi_mean", dist.Uniform(0.0, 1.0)) # mean survival probability + phi_mean = pyro.sample( + "phi_mean", dist.Uniform(0.0, 1.0) + ) # mean survival probability phi_logit_mean = logit(phi_mean) # controls temporal variability of survival probability phi_sigma = pyro.sample("phi_sigma", dist.Uniform(0.0, 10.0)) @@ -127,19 +143,29 @@ def logit(p): # we create the plate once, outside of the loop over t animals_plate = pyro.plate("animals", N, dim=-1) for t in pyro.markov(range(T)): - phi_logit_t = pyro.sample("phi_logit_{}".format(t), - dist.Normal(phi_logit_mean, phi_sigma)) if t > 0 \ - else torch.tensor(0.0) + phi_logit_t = ( + pyro.sample( + "phi_logit_{}".format(t), dist.Normal(phi_logit_mean, phi_sigma) + ) + if t > 0 + else torch.tensor(0.0) + ) phi_t = torch.sigmoid(phi_logit_t) with animals_plate, poutine.mask(mask=first_capture_mask): - mu_z_t = first_capture_mask.float() * phi_t * z + (1 - first_capture_mask.float()) + mu_z_t = first_capture_mask.float() * phi_t * z + ( + 1 - first_capture_mask.float() + ) # we use parallel enumeration to exactly sum out # the discrete states z_t. - z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t), - infer={"enumerate": "parallel"}) + z = pyro.sample( + "z_{}".format(t), + dist.Bernoulli(mu_z_t), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z - pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t), - obs=capture_history[:, t]) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(mu_y_t), obs=capture_history[:, t] + ) first_capture_mask |= capture_history[:, t].bool() @@ -166,14 +192,20 @@ def model_4(capture_history, sex): first_capture_mask = torch.zeros(N).bool() for t in pyro.markov(range(T)): with poutine.mask(mask=first_capture_mask): - mu_z_t = first_capture_mask.float() * phi * z + (1 - first_capture_mask.float()) + mu_z_t = first_capture_mask.float() * phi * z + ( + 1 - first_capture_mask.float() + ) # we use parallel enumeration to exactly sum out # the discrete states z_t. 
- z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t), - infer={"enumerate": "parallel"}) + z = pyro.sample( + "z_{}".format(t), + dist.Bernoulli(mu_z_t), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z - pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t), - obs=capture_history[:, t]) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(mu_y_t), obs=capture_history[:, t] + ) first_capture_mask |= capture_history[:, t].bool() @@ -203,24 +235,35 @@ def model_5(capture_history, sex): # we create the plate once, outside of the loop over t animals_plate = pyro.plate("animals", N, dim=-1) for t in pyro.markov(range(T)): - phi_gamma_t = pyro.sample("phi_gamma_{}".format(t), dist.Normal(0.0, 10.0)) if t > 0 \ - else 0.0 + phi_gamma_t = ( + pyro.sample("phi_gamma_{}".format(t), dist.Normal(0.0, 10.0)) + if t > 0 + else 0.0 + ) phi_t = torch.sigmoid(phi_beta + phi_gamma_t) with animals_plate, poutine.mask(mask=first_capture_mask): - mu_z_t = first_capture_mask.float() * phi_t * z + (1 - first_capture_mask.float()) + mu_z_t = first_capture_mask.float() * phi_t * z + ( + 1 - first_capture_mask.float() + ) # we use parallel enumeration to exactly sum out # the discrete states z_t. - z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t), - infer={"enumerate": "parallel"}) + z = pyro.sample( + "z_{}".format(t), + dist.Bernoulli(mu_z_t), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z - pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t), - obs=capture_history[:, t]) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(mu_y_t), obs=capture_history[:, t] + ) first_capture_mask |= capture_history[:, t].bool() -models = {name[len('model_'):]: model - for name, model in globals().items() - if name.startswith('model_')} +models = { + name[len("model_") :]: model + for name, model in globals().items() + if name.startswith("model_") +} def main(args): @@ -229,24 +272,36 @@ def main(args): # load data if args.dataset == "dipper": - capture_history_file = os.path.dirname(os.path.abspath(__file__)) + '/dipper_capture_history.csv' + capture_history_file = ( + os.path.dirname(os.path.abspath(__file__)) + "/dipper_capture_history.csv" + ) elif args.dataset == "vole": - capture_history_file = os.path.dirname(os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv' + capture_history_file = ( + os.path.dirname(os.path.abspath(__file__)) + + "/meadow_voles_capture_history.csv" + ) else: - raise ValueError("Available datasets are \'dipper\' and \'vole\'.") + raise ValueError("Available datasets are 'dipper' and 'vole'.") - capture_history = torch.tensor(np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:] + capture_history = torch.tensor( + np.genfromtxt(capture_history_file, delimiter=",") + ).float()[:, 1:] N, T = capture_history.shape - print("Loaded {} capture history for {} individuals collected over {} time periods.".format( - args.dataset, N, T)) + print( + "Loaded {} capture history for {} individuals collected over {} time periods.".format( + args.dataset, N, T + ) + ) if args.dataset == "dipper" and args.model in ["4", "5"]: - sex_file = os.path.dirname(os.path.abspath(__file__)) + '/dipper_sex.csv' - sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1] + sex_file = os.path.dirname(os.path.abspath(__file__)) + "/dipper_sex.csv" + sex = torch.tensor(np.genfromtxt(sex_file, delimiter=",")).float()[:, 1] print("Loaded dipper sex data.") elif args.dataset == "vole" and args.model in ["4", "5"]: - raise ValueError("Cannot run model_{} on meadow voles data, 
since we lack sex " - "information for these animals.".format(args.model)) + raise ValueError( + "Cannot run model_{} on meadow voles data, since we lack sex " + "information for these animals.".format(args.model) + ) else: sex = None @@ -256,7 +311,7 @@ def main(args): # in the models to AutoDiagonalNormal (all of which begin with 'phi' # or 'rho') def expose_fn(msg): - return msg["name"][0:3] in ['phi', 'rho'] + return msg["name"][0:3] in ["phi", "rho"] # we use a mean field diagonal normal variational distributions (i.e. guide) # for the continuous latent variables. @@ -264,20 +319,29 @@ def expose_fn(msg): # since we enumerate the discrete random variables, # we need to use TraceEnum_ELBO or TraceTMC_ELBO. - optim = Adam({'lr': args.learning_rate}) + optim = Adam({"lr": args.learning_rate}) if args.tmc: elbo = TraceTMC_ELBO(max_plate_nesting=1) tmc_model = poutine.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} if msg["infer"].get("enumerate", None) == "parallel" else {}) # noqa: E501 + lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {}, + ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: - elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20, vectorize_particles=True) + elbo = TraceEnum_ELBO( + max_plate_nesting=1, num_particles=20, vectorize_particles=True + ) svi = SVI(model, guide, optim, elbo) losses = [] - print("Beginning training of model_{} with Stochastic Variational Inference.".format(args.model)) + print( + "Beginning training of model_{} with Stochastic Variational Inference.".format( + args.model + ) + ) for step in range(args.num_steps): loss = svi.step(capture_history, sex) @@ -286,23 +350,35 @@ def expose_fn(msg): print("[iteration %03d] loss: %.3f" % (step, np.mean(losses[-20:]))) # evaluate final trained model - elbo_eval = TraceEnum_ELBO(max_plate_nesting=1, num_particles=2000, vectorize_particles=True) + elbo_eval = TraceEnum_ELBO( + max_plate_nesting=1, num_particles=2000, vectorize_particles=True + ) svi_eval = SVI(model, guide, optim, elbo_eval) print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex)) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="CJS capture-recapture model for ecological data") - parser.add_argument("-m", "--model", default="1", type=str, - help="one of: {}".format(", ".join(sorted(models.keys())))) +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="CJS capture-recapture model for ecological data" + ) + parser.add_argument( + "-m", + "--model", + default="1", + type=str, + help="one of: {}".format(", ".join(sorted(models.keys()))), + ) parser.add_argument("-d", "--dataset", default="dipper", type=str) parser.add_argument("-n", "--num-steps", default=400, type=int) parser.add_argument("-lr", "--learning-rate", default=0.002, type=float) - parser.add_argument("--tmc", action='store_true', - help="Use Tensor Monte Carlo instead of exact enumeration " - "to estimate the marginal likelihood. You probably don't want to do this, " - "except to see that TMC makes Monte Carlo gradient estimation feasible " - "even with very large numbers of non-reparametrized variables.") + parser.add_argument( + "--tmc", + action="store_true", + help="Use Tensor Monte Carlo instead of exact enumeration " + "to estimate the marginal likelihood. 
You probably don't want to do this, " + "except to see that TMC makes Monte Carlo gradient estimation feasible " + "even with very large numbers of non-reparametrized variables.", + ) parser.add_argument("--tmc-num-samples", default=10, type=int) args = parser.parse_args() main(args) diff --git a/examples/contrib/autoname/mixture.py b/examples/contrib/autoname/mixture.py index da1484c33c..6690713d33 100644 --- a/examples/contrib/autoname/mixture.py +++ b/examples/contrib/autoname/mixture.py @@ -60,23 +60,23 @@ def main(args): data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0]) k = 2 - print('Step\tLoss') + print("Step\tLoss") loss = 0.0 for step in range(args.num_epochs): if step and step % 10 == 0: - print('{}\t{:0.5g}'.format(step, loss)) + print("{}\t{:0.5g}".format(step, loss)) loss = 0.0 loss += inference.step(data, k=k) - print('Parameters:') + print("Parameters:") for name, value in sorted(pyro.get_param_store().items()): - print('{} = {}'.format(name, value.detach().cpu().numpy())) + print("{} = {}".format(name, value.detach().cpu().numpy())) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', default=200, type=int) - parser.add_argument('--jit', action='store_true') + parser.add_argument("-n", "--num-epochs", default=200, type=int) + parser.add_argument("--jit", action="store_true") args = parser.parse_args() main(args) diff --git a/examples/contrib/autoname/scoping_mixture.py b/examples/contrib/autoname/scoping_mixture.py index 4ff5ba34a4..a64b86f8c0 100644 --- a/examples/contrib/autoname/scoping_mixture.py +++ b/examples/contrib/autoname/scoping_mixture.py @@ -15,31 +15,35 @@ def model(K, data): # Global parameters. 
- weights = pyro.param('weights', torch.ones(K) / K, constraint=constraints.simplex) - locs = pyro.param('locs', 10 * torch.randn(K)) - scale = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive) + weights = pyro.param("weights", torch.ones(K) / K, constraint=constraints.simplex) + locs = pyro.param("locs", 10 * torch.randn(K)) + scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive) - with pyro.plate('data'): + with pyro.plate("data"): return local_model(weights, locs, scale, data) @scope(prefix="local") def local_model(weights, locs, scale, data): - assignment = pyro.sample('assignment', - dist.Categorical(weights).expand_by([len(data)])) - return pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data) + assignment = pyro.sample( + "assignment", dist.Categorical(weights).expand_by([len(data)]) + ) + return pyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data) def guide(K, data): - assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K, - constraint=constraints.unit_interval) - with pyro.plate('data'): + assignment_probs = pyro.param( + "assignment_probs", + torch.ones(len(data), K) / K, + constraint=constraints.unit_interval, + ) + with pyro.plate("data"): return local_guide(assignment_probs) @scope(prefix="local") def local_guide(probs): - return pyro.sample('assignment', dist.Categorical(probs)) + return pyro.sample("assignment", dist.Categorical(probs)) def main(args): @@ -48,26 +52,27 @@ def main(args): K = 2 data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0]) - optim = pyro.optim.Adam({'lr': 0.1}) - inference = SVI(model, config_enumerate(guide), optim, - loss=TraceEnum_ELBO(max_plate_nesting=1)) + optim = pyro.optim.Adam({"lr": 0.1}) + inference = SVI( + model, config_enumerate(guide), optim, loss=TraceEnum_ELBO(max_plate_nesting=1) + ) - print('Step\tLoss') + print("Step\tLoss") loss = 0.0 for step in range(args.num_epochs): if step and step % 10 == 0: - print('{}\t{:0.5g}'.format(step, loss)) + print("{}\t{:0.5g}".format(step, loss)) loss = 0.0 loss += inference.step(K, data) - print('Parameters:') + print("Parameters:") for name, value in sorted(pyro.get_param_store().items()): - print('{} = {}'.format(name, value.detach().cpu().numpy())) + print("{} = {}".format(name, value.detach().cpu().numpy())) if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', default=200, type=int) + parser.add_argument("-n", "--num-epochs", default=200, type=int) args = parser.parse_args() main(args) diff --git a/examples/contrib/autoname/tree_data.py b/examples/contrib/autoname/tree_data.py index 234c306598..c6dd5044f5 100644 --- a/examples/contrib/autoname/tree_data.py +++ b/examples/contrib/autoname/tree_data.py @@ -90,22 +90,22 @@ def main(args): }, } - print('Step\tLoss') + print("Step\tLoss") loss = 0.0 for step in range(args.num_epochs): loss += inference.step(data) if step and step % 10 == 0: - print('{}\t{:0.5g}'.format(step, loss)) + print("{}\t{:0.5g}".format(step, loss)) loss = 0.0 - print('Parameters:') + print("Parameters:") for name, value in sorted(pyro.get_param_store().items()): - print('{} = {}'.format(name, value.detach().cpu().numpy())) + print("{} = {}".format(name, value.detach().cpu().numpy())) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert 
pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', default=100, type=int) + parser.add_argument("-n", "--num-epochs", default=100, type=int) args = parser.parse_args() main(args) diff --git a/examples/contrib/cevae/synthetic.py b/examples/contrib/cevae/synthetic.py index bdba5bd7a8..80c2747b2b 100644 --- a/examples/contrib/cevae/synthetic.py +++ b/examples/contrib/cevae/synthetic.py @@ -37,7 +37,7 @@ def generate_data(args): y = dist.Bernoulli(logits=3 * (z + 2 * (2 * t - 2))).sample() # Compute true ite for evaluation (via Monte Carlo approximation). - t0_t1 = torch.tensor([[0.], [1.]]) + t0_t1 = torch.tensor([[0.0], [1.0]]) y_t0, y_t1 = dist.Bernoulli(logits=3 * (z + 2 * (2 * t0_t1 - 2))).mean true_ite = y_t1 - y_t0 return x, t, y, true_ite @@ -45,7 +45,7 @@ def generate_data(args): def main(args): if args.cuda: - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") # Generate synthetic data. pyro.set_rng_seed(args.seed) @@ -54,17 +54,23 @@ def main(args): # Train. pyro.set_rng_seed(args.seed) pyro.clear_param_store() - cevae = CEVAE(feature_dim=args.feature_dim, - latent_dim=args.latent_dim, - hidden_dim=args.hidden_dim, - num_layers=args.num_layers, - num_samples=10) - cevae.fit(x_train, t_train, y_train, - num_epochs=args.num_epochs, - batch_size=args.batch_size, - learning_rate=args.learning_rate, - learning_rate_decay=args.learning_rate_decay, - weight_decay=args.weight_decay) + cevae = CEVAE( + feature_dim=args.feature_dim, + latent_dim=args.latent_dim, + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + num_samples=10, + ) + cevae.fit( + x_train, + t_train, + y_train, + num_epochs=args.num_epochs, + batch_size=args.batch_size, + learning_rate=args.learning_rate, + learning_rate_decay=args.learning_rate_decay, + weight_decay=args.weight_decay, + ) # Evaluate. 
x_test, t_test, y_test, true_ite = generate_data(args) @@ -80,8 +86,10 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description="Causal Effect Variational Autoencoder") + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="Causal Effect Variational Autoencoder" + ) parser.add_argument("--num-data", default=1000, type=int) parser.add_argument("--feature-dim", default=5, type=int) parser.add_argument("--latent-dim", default=20, type=int) diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index a90658050a..720f8bef88 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -9,7 +9,7 @@ import pyro from pyro.contrib.epidemiology.models import RegionalSIRModel -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) def Model(args, data): @@ -24,22 +24,30 @@ def generate_data(args): model = Model(args, extended_data) logging.info("Simulating from a {}".format(type(model).__name__)) for attempt in range(100): - samples = model.generate({"R0": args.basic_reproduction_number, - "rho_c1": 10 * args.response_rate, - "rho_c0": 10 * (1 - args.response_rate)}) - obs = samples["obs"][:args.duration] + samples = model.generate( + { + "R0": args.basic_reproduction_number, + "rho_c1": 10 * args.response_rate, + "rho_c0": 10 * (1 - args.response_rate), + } + ) + obs = samples["obs"][: args.duration] S2I = samples["S2I"] obs_sum = int(obs.sum()) - S2I_sum = int(S2I[:args.duration].sum()) + S2I_sum = int(S2I[: args.duration].sum()) if obs_sum >= args.min_observations: - logging.info("Observed {:d}/{:d} infections:\n{}".format( - obs_sum, S2I_sum, " ".join(str(int(x)) for x in obs[:, 0]))) + logging.info( + "Observed {:d}/{:d} infections:\n{}".format( + obs_sum, S2I_sum, " ".join(str(int(x)) for x in obs[:, 0]) + ) + ) return {"S2I": S2I, "obs": obs} - raise ValueError("Failed to generate {} observations. Try increasing " - "--population or decreasing --min-observations" - .format(args.min_observations)) + raise ValueError( + "Failed to generate {} observations. 
Try increasing " + "--population or decreasing --min-observations".format(args.min_observations) + ) def infer_mcmc(args, model): @@ -51,20 +59,23 @@ def hook_fn(kernel, *unused): if args.verbose: logging.info("potential = {:0.6g}".format(e)) - mcmc = model.fit_mcmc(heuristic_num_particles=args.smc_particles, - heuristic_ess_threshold=args.ess_threshold, - warmup_steps=args.warmup_steps, - num_samples=args.num_samples, - max_tree_depth=args.max_tree_depth, - num_quant_bins=args.num_bins, - haar=args.haar, - haar_full_mass=args.haar_full_mass, - jit_compile=args.jit, - hook_fn=hook_fn) + mcmc = model.fit_mcmc( + heuristic_num_particles=args.smc_particles, + heuristic_ess_threshold=args.ess_threshold, + warmup_steps=args.warmup_steps, + num_samples=args.num_samples, + max_tree_depth=args.max_tree_depth, + num_quant_bins=args.num_bins, + haar=args.haar, + haar_full_mass=args.haar_full_mass, + jit_compile=args.jit, + hook_fn=hook_fn, + ) mcmc.summary() if args.plot: import matplotlib.pyplot as plt + plt.figure(figsize=(6, 3)) plt.plot(energies) plt.xlabel("MCMC step") @@ -74,16 +85,19 @@ def hook_fn(kernel, *unused): def infer_svi(args, model): - losses = model.fit_svi(heuristic_num_particles=args.smc_particles, - heuristic_ess_threshold=args.ess_threshold, - num_samples=args.num_samples, - num_steps=args.svi_steps, - num_particles=args.svi_particles, - haar=args.haar, - jit=args.jit) + losses = model.fit_svi( + heuristic_num_particles=args.smc_particles, + heuristic_ess_threshold=args.ess_threshold, + num_samples=args.num_samples, + num_steps=args.svi_steps, + num_particles=args.svi_particles, + haar=args.haar, + jit=args.jit, + ) if args.plot: import matplotlib.pyplot as plt + plt.figure(figsize=(6, 3)) plt.plot(losses) plt.xlabel("SVI step") @@ -98,27 +112,36 @@ def predict(args, model, truth): median = S2I.median(dim=0).values lines = ["Median prediction of new infections (starting on day 0):"] for r in range(args.num_regions): - lines.append("Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r]))))) + lines.append( + "Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r])))) + ) logging.info("\n".join(lines)) # Optionally plot the latent and forecasted series of new infections. 
if args.plot: import matplotlib.pyplot as plt - fig, axes = plt.subplots(args.num_regions, sharex=True, - figsize=(6, 1 + args.num_regions)) + + fig, axes = plt.subplots( + args.num_regions, sharex=True, figsize=(6, 1 + args.num_regions) + ) time = torch.arange(args.duration + args.forecast) p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values for r, ax in enumerate(axes): - ax.fill_between(time, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI") + ax.fill_between( + time, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI" + ) ax.plot(time, median[:, r], "r-", label="median") - ax.plot(time[:args.duration], model.data[:, r], "k.", label="observed") + ax.plot(time[: args.duration], model.data[:, r], "k.", label="observed") ax.plot(time, truth[:, r], "k--", label="truth") ax.axvline(args.duration - 0.5, color="gray", lw=1) ax.set_xlim(0, len(time) - 1) ax.set_ylim(0, None) - axes[0].set_title("New infections among {} regions each of size {}" - .format(args.num_regions, args.population)) + axes[0].set_title( + "New infections among {} regions each of size {}".format( + args.num_regions, args.population + ) + ) axes[args.num_regions // 2].set_ylabel("inf./day") axes[-1].set_xlabel("day after first infection") axes[-1].legend(loc="upper left") @@ -143,9 +166,10 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser( - description="Regional compartmental epidemiology modeling using HMC") + description="Regional compartmental epidemiology modeling using HMC" + ) parser.add_argument("-p", "--population", default=1000, type=int) parser.add_argument("-r", "--num-regions", default=2, type=int) parser.add_argument("-c", "--coupling", default=0.1, type=float) @@ -192,4 +216,5 @@ def main(args): if args.plot: import matplotlib.pyplot as plt + plt.show() diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index bcb3e60d74..817a892d6f 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -23,7 +23,7 @@ SuperspreadingSIRModel, ) -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) def Model(args, data): @@ -35,14 +35,17 @@ def Model(args, data): elif args.incubation_time > 0: assert args.incubation_time > 1 if args.concentration < math.inf: - return SuperspreadingSEIRModel(args.population, args.incubation_time, - args.recovery_time, data) + return SuperspreadingSEIRModel( + args.population, args.incubation_time, args.recovery_time, data + ) elif args.overdispersion > 0: - return OverdispersedSEIRModel(args.population, args.incubation_time, - args.recovery_time, data) + return OverdispersedSEIRModel( + args.population, args.incubation_time, args.recovery_time, data + ) else: - return SimpleSEIRModel(args.population, args.incubation_time, - args.recovery_time, data) + return SimpleSEIRModel( + args.population, args.incubation_time, args.recovery_time, data + ) else: if args.concentration < math.inf: return SuperspreadingSIRModel(args.population, args.recovery_time, data) @@ -57,31 +60,44 @@ def generate_data(args): model = Model(args, extended_data) logging.info("Simulating from a {}".format(type(model).__name__)) for attempt in range(100): - samples = model.generate({"R0": args.basic_reproduction_number, - "rho": 
args.response_rate, - "k": args.concentration, - "od": args.overdispersion}) - obs = samples["obs"][:args.duration] + samples = model.generate( + { + "R0": args.basic_reproduction_number, + "rho": args.response_rate, + "k": args.concentration, + "od": args.overdispersion, + } + ) + obs = samples["obs"][: args.duration] new_I = samples.get("S2I", samples.get("E2I")) obs_sum = int(obs.sum()) - new_I_sum = int(new_I[:args.duration].sum()) + new_I_sum = int(new_I[: args.duration].sum()) assert 0 <= args.min_obs_portion < args.max_obs_portion <= 1 min_obs = int(math.ceil(args.min_obs_portion * args.population)) max_obs = int(math.floor(args.max_obs_portion * args.population)) if min_obs <= obs_sum <= max_obs: - logging.info("Observed {:d}/{:d} infections:\n{}".format( - obs_sum, new_I_sum, " ".join(str(int(x)) for x in obs))) + logging.info( + "Observed {:d}/{:d} infections:\n{}".format( + obs_sum, new_I_sum, " ".join(str(int(x)) for x in obs) + ) + ) return {"new_I": new_I, "obs": obs} if obs_sum < min_obs: - raise ValueError("Failed to generate >={} observations. " - "Try decreasing --min-obs-portion (currently {})." - .format(min_obs, args.min_obs_portion)) + raise ValueError( + "Failed to generate >={} observations. " + "Try decreasing --min-obs-portion (currently {}).".format( + min_obs, args.min_obs_portion + ) + ) else: - raise ValueError("Failed to generate <={} observations. " - "Try increasing --max-obs-portion (currently {})." - .format(max_obs, args.max_obs_portion)) + raise ValueError( + "Failed to generate <={} observations. " + "Try increasing --max-obs-portion (currently {}).".format( + max_obs, args.max_obs_portion + ) + ) def infer_mcmc(args, model): @@ -94,23 +110,26 @@ def hook_fn(kernel, *unused): if args.verbose: logging.info("potential = {:0.6g}".format(e)) - mcmc = model.fit_mcmc(heuristic_num_particles=args.smc_particles, - heuristic_ess_threshold=args.ess_threshold, - warmup_steps=args.warmup_steps, - num_samples=args.num_samples, - num_chains=args.num_chains, - mp_context="spawn" if parallel else None, - max_tree_depth=args.max_tree_depth, - arrowhead_mass=args.arrowhead_mass, - num_quant_bins=args.num_bins, - haar=args.haar, - haar_full_mass=args.haar_full_mass, - jit_compile=args.jit, - hook_fn=None if parallel else hook_fn) + mcmc = model.fit_mcmc( + heuristic_num_particles=args.smc_particles, + heuristic_ess_threshold=args.ess_threshold, + warmup_steps=args.warmup_steps, + num_samples=args.num_samples, + num_chains=args.num_chains, + mp_context="spawn" if parallel else None, + max_tree_depth=args.max_tree_depth, + arrowhead_mass=args.arrowhead_mass, + num_quant_bins=args.num_bins, + haar=args.haar, + haar_full_mass=args.haar_full_mass, + jit_compile=args.jit, + hook_fn=None if parallel else hook_fn, + ) mcmc.summary() if args.plot and energies: import matplotlib.pyplot as plt + plt.figure(figsize=(6, 3)) plt.plot(energies) plt.xlabel("MCMC step") @@ -122,16 +141,19 @@ def hook_fn(kernel, *unused): def infer_svi(args, model): - losses = model.fit_svi(heuristic_num_particles=args.smc_particles, - heuristic_ess_threshold=args.ess_threshold, - num_samples=args.num_samples, - num_steps=args.svi_steps, - num_particles=args.svi_particles, - haar=args.haar, - jit=args.jit) + losses = model.fit_svi( + heuristic_num_particles=args.smc_particles, + heuristic_ess_threshold=args.ess_threshold, + num_samples=args.num_samples, + num_steps=args.svi_steps, + num_particles=args.svi_particles, + haar=args.haar, + jit=args.jit, + ) if args.plot: import matplotlib.pyplot as plt + 
plt.figure(figsize=(6, 3)) plt.plot(losses) plt.xlabel("SVI step") @@ -154,8 +176,11 @@ def evaluate(args, model, samples): for name, key in names.items(): mean = samples[key].mean().item() std = samples[key].std().item() - logging.info("{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}" - .format(key, getattr(args, name), mean, std)) + logging.info( + "{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}".format( + key, getattr(args, name), mean, std + ) + ) # Optionally plot histograms and pairwise correlations. if args.plot: @@ -191,8 +216,13 @@ def evaluate(args, model, samples): ax = axes[i][j] ax.set_xticks(()) ax.set_yticks(()) - ax.scatter(covariates[j][1], -covariates[i][1], - lw=0, color="darkblue", alpha=0.3) + ax.scatter( + covariates[j][1], + -covariates[i][1], + lw=0, + color="darkblue", + alpha=0.3, + ) plt.tight_layout() plt.subplots_adjust(wspace=0, hspace=0) @@ -204,10 +234,10 @@ def unconstrain(constraint, value): covariates = [("R1", unconstrain(constraints.positive, samples["R0"]))] if not args.heterogeneous: covariates.append( - ("rho", unconstrain(constraints.unit_interval, samples["rho"]))) + ("rho", unconstrain(constraints.unit_interval, samples["rho"])) + ) if "k" in samples: - covariates.append( - ("k", unconstrain(constraints.positive, samples["k"]))) + covariates.append(("k", unconstrain(constraints.positive, samples["k"]))) constraint = constraints.interval(-0.5, model.population + 0.5) for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)): covariates.append((name, unconstrain(constraint, aux))) @@ -235,19 +265,23 @@ def predict(args, model, truth): new_I = samples.get("S2I", samples.get("E2I")) median = new_I.median(dim=0).values - logging.info("Median prediction of new infections (starting on day 0):\n{}" - .format(" ".join(map(str, map(int, median))))) + logging.info( + "Median prediction of new infections (starting on day 0):\n{}".format( + " ".join(map(str, map(int, median))) + ) + ) # Optionally plot the latent and forecasted series of new infections. 
if args.plot: import matplotlib.pyplot as plt + plt.figure() time = torch.arange(args.duration + args.forecast) p05 = new_I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values p95 = new_I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") plt.plot(time, median, "r-", label="median") - plt.plot(time[:args.duration], obs, "k.", label="observed") + plt.plot(time[: args.duration], obs, "k.", label="observed") if truth is not None: plt.plot(time, truth, "k--", label="truth") plt.axvline(args.duration - 0.5, color="gray", lw=1) @@ -268,7 +302,7 @@ def predict(args, model, truth): p95 = Re.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") plt.plot(time, median, "r-", label="median") - plt.plot(time[:args.duration], obs, "k.", label="observed") + plt.plot(time[: args.duration], obs, "k.", label="observed") plt.axvline(args.duration - 0.5, color="gray", lw=1) plt.xlim(0, len(time) - 1) plt.ylim(0, None) @@ -300,9 +334,10 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser( - description="Compartmental epidemiology modeling using HMC") + description="Compartmental epidemiology modeling using HMC" + ) parser.add_argument("-p", "--population", default=1000, type=float) parser.add_argument("-m", "--min-obs-portion", default=0.01, type=float) parser.add_argument("-M", "--max-obs-portion", default=0.99, type=float) @@ -310,12 +345,22 @@ def main(args): parser.add_argument("-f", "--forecast", default=10, type=int) parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float) parser.add_argument("-tau", "--recovery-time", default=7.0, type=float) - parser.add_argument("-e", "--incubation-time", default=0.0, type=float, - help="If zero, use SIR model; if > 1 use SEIR model.") - parser.add_argument("-k", "--concentration", default=math.inf, type=float, - help="If finite, use a superspreader model.") + parser.add_argument( + "-e", + "--incubation-time", + default=0.0, + type=float, + help="If zero, use SIR model; if > 1 use SEIR model.", + ) + parser.add_argument( + "-k", + "--concentration", + default=math.inf, + type=float, + help="If finite, use a superspreader model.", + ) parser.add_argument("-rho", "--response-rate", default=0.5, type=float) - parser.add_argument("-o", "--overdispersion", default=0., type=float) + parser.add_argument("-o", "--overdispersion", default=0.0, type=float) parser.add_argument("-hg", "--heterogeneous", action="store_true") parser.add_argument("--infer", default="mcmc") parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") @@ -357,4 +402,5 @@ def main(args): if args.plot: import matplotlib.pyplot as plt + plt.show() diff --git a/examples/contrib/forecast/bart.py b/examples/contrib/forecast/bart.py index 98ccd55aa2..2e3c452426 100644 --- a/examples/contrib/forecast/bart.py +++ b/examples/contrib/forecast/bart.py @@ -53,50 +53,55 @@ def model(self, zero_data, covariates): assert dim == 2 # Data is bivariate: (arrivals, departures). # Sample global parameters. 
- noise_scale = pyro.sample("noise_scale", - dist.LogNormal(torch.full((dim,), -3.), 1.).to_event(1)) + noise_scale = pyro.sample( + "noise_scale", dist.LogNormal(torch.full((dim,), -3.0), 1.0).to_event(1) + ) assert noise_scale.shape[-1:] == (dim,) - trans_timescale = pyro.sample("trans_timescale", - dist.LogNormal(torch.zeros(dim), 1).to_event(1)) + trans_timescale = pyro.sample( + "trans_timescale", dist.LogNormal(torch.zeros(dim), 1).to_event(1) + ) assert trans_timescale.shape[-1:] == (dim,) trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period)) trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim,)) assert trans_loc.shape[-1:] == (dim,) - trans_scale = pyro.sample("trans_scale", - dist.LogNormal(torch.zeros(dim), 0.1).to_event(1)) - trans_corr = pyro.sample("trans_corr", - dist.LKJCholesky(dim, torch.ones(()))) + trans_scale = pyro.sample( + "trans_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1) + ) + trans_corr = pyro.sample("trans_corr", dist.LKJCholesky(dim, torch.ones(()))) trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr assert trans_scale_tril.shape[-2:] == (dim, dim) - obs_scale = pyro.sample("obs_scale", - dist.LogNormal(torch.zeros(dim), 0.1).to_event(1)) - obs_corr = pyro.sample("obs_corr", - dist.LKJCholesky(dim, torch.ones(()))) + obs_scale = pyro.sample( + "obs_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1) + ) + obs_corr = pyro.sample("obs_corr", dist.LKJCholesky(dim, torch.ones(()))) obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr assert obs_scale_tril.shape[-2:] == (dim, dim) # Note the initial seasonality should be sampled in a plate with the # same dim as the time_plate, dim=-1. That way we can repeat the dim # below using periodic_repeat(). - with pyro.plate("season_plate", period, dim=-1): - season_init = pyro.sample("season_init", - dist.Normal(torch.zeros(dim), 1).to_event(1)) + with pyro.plate("season_plate", period, dim=-1): + season_init = pyro.sample( + "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1) + ) assert season_init.shape[-2:] == (period, dim) # Sample independent noise at each time step. with self.time_plate: - season_noise = pyro.sample("season_noise", - dist.Normal(0, noise_scale).to_event(1)) + season_noise = pyro.sample( + "season_noise", dist.Normal(0, noise_scale).to_event(1) + ) assert season_noise.shape[-2:] == (duration, dim) # Construct a prediction. This prediction has an exactly repeated # seasonal part plus slow seasonal drift. We use two deterministic, # linear functions to transform our diagonal Normal noise to nontrivial # samples from a Gaussian process. - prediction = (periodic_repeat(season_init, duration, dim=-2) + - periodic_cumsum(season_noise, period, dim=-2)) + prediction = periodic_repeat(season_init, duration, dim=-2) + periodic_cumsum( + season_noise, period, dim=-2 + ) assert prediction.shape[-2:] == (duration, dim) # Construct a joint noise model. This model is a GaussianHMM, whose @@ -107,8 +112,9 @@ def model(self, zero_data, covariates): trans_dist = dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril) obs_mat = torch.eye(dim) obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale_tril) - noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, - duration=duration) + noise_model = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) assert noise_model.event_shape == (duration, dim) # The final statement registers our noise model and prediction. 
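The noise-model construction above composes a dist.GaussianHMM from an initial distribution, a transition matrix and transition noise, and an observation matrix and observation noise. A minimal standalone sketch of that same pattern follows; the shapes (hidden_dim=2, obs_dim=2, duration=5) are illustrative assumptions, not values from bart.py:

import torch
import pyro.distributions as dist

# Illustrative shapes only: 2-dim hidden state, 2-dim observation, 5 time steps.
hidden_dim, obs_dim, duration = 2, 2, 5
init_dist = dist.Normal(torch.zeros(hidden_dim), 1.0).to_event(1)
trans_mat = 0.9 * torch.eye(hidden_dim)
trans_dist = dist.Normal(torch.zeros(hidden_dim), 0.1).to_event(1)
obs_mat = torch.eye(hidden_dim, obs_dim)
obs_dist = dist.Normal(torch.zeros(obs_dim), 0.5).to_event(1)
noise_model = dist.GaussianHMM(
    init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration
)
assert noise_model.event_shape == (duration, obs_dim)
series = noise_model.rsample()  # one (duration, obs_dim) noise series
log_density = noise_model.log_prob(series)  # scalar log-density of that series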
@@ -139,12 +145,16 @@ def transform(pred, truth): "log_every": args.log_every, "dct_gradients": args.dct, } - metrics = backtest(data, covariates, Model, - train_window=args.train_window, - test_window=args.test_window, - stride=args.stride, - num_samples=args.num_samples, - forecaster_options=forecaster_options) + metrics = backtest( + data, + covariates, + Model, + train_window=args.train_window, + test_window=args.test_window, + stride=args.stride, + num_samples=args.num_samples, + forecaster_options=forecaster_options, + ) for name in ["mae", "rmse", "crps"]: values = [m[name] for m in metrics] @@ -155,7 +165,7 @@ def transform(pred, truth): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Bart Ridership Forecasting Example") parser.add_argument("--train-window", default=2160, type=int) parser.add_argument("--test-window", default=336, type=int) diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 77463972bd..0a36ac6e82 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -64,7 +64,7 @@ from pyroapi import distributions as dist from pyroapi import handlers, infer, optim, pyro, pyro_backend -logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.DEBUG) +logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.DEBUG) # Add another handler for logging debugging events (e.g. for profiling) # in a separate stream that can be captured. @@ -98,16 +98,17 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True): # Our prior on transition probabilities will be: # stay in the same state with 90% probability; uniformly jump to another # state with 10% probability. - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) # We put a weak prior on the conditional probability of a tone sounding. # We know that on average about 4 of 88 tones are active, so we'll set a # rough weak prior of 10% of the notes being active at any one time. - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, data_dim]) - .to_event(2)) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2), + ) # In this first model we'll sequentially iterate over sequences in a # minibatch; this will make it easy to reason about tensor shapes. tones_plate = pyro.plate("tones", data_dim, dim=-1) @@ -119,11 +120,19 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True): # On the next line, we'll overwrite the value of x with an updated # value. If we wanted to record all x values, we could instead # write x[t] = pyro.sample(...x[t-1]...). 
- x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}_{}".format(i, t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate: - pyro.sample("y_{}_{}".format(i, t), dist.Bernoulli(probs_y[x.squeeze(-1)]), - obs=sequence[t]) + pyro.sample( + "y_{}_{}".format(i, t), + dist.Bernoulli(probs_y[x.squeeze(-1)]), + obs=sequence[t], + ) + + # To see how enumeration changes the shapes of these sample sites, we can use # the Trace.format_shapes() to print shapes at each site: # $ python examples/hmm.py -m 0 -n 1 -b 1 -t 5 --print-shapes @@ -182,13 +191,14 @@ def model_1(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.shape == (num_sequences,) assert lengths.max() <= max_length with handlers.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, data_dim]) - .to_event(2)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) # We subsample batch_size items out of num_sequences items. Note that since # we're using dim=-1 for the notes plate, we need to batch over a different @@ -204,11 +214,19 @@ def model_1(sequences, lengths, args, batch_size=None, include_prior=True): # need to trigger a new jit compile stage. for t in pyro.markov(range(max_length if args.jit else lengths.max())): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x.squeeze(-1)]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[x.squeeze(-1)]), + obs=sequences[batch, t], + ) + + # Let's see how batching changes the shapes of sample sites: # $ python examples/hmm.py -m 1 -n 1 -t 5 --batch-size=10 --print-shapes # ... 
@@ -260,27 +278,34 @@ def model_2(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.shape == (num_sequences,) assert lengths.max() <= max_length with handlers.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, 2, data_dim]) - .to_event(3)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] x, y = 0, 0 for t in pyro.markov(range(max_length if args.jit else lengths.max())): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) # Note the broadcasting tricks here: to index probs_y on tensors x and y, # we also need a final tensor for the tones dimension. This is conveniently # provided by the plate associated with that dimension. with tones_plate as tones: - y = pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x, y, tones]), - obs=sequences[batch, t]).long() + y = pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[x, y, tones]), + obs=sequences[batch, t], + ).long() # Next consider a Factorial HMM with two hidden states. @@ -305,29 +330,38 @@ def model_3(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.max() <= max_length hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x with handlers.mask(mask=include_prior): - probs_w = pyro.sample("probs_w", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([hidden_dim, hidden_dim, data_dim]) - .to_event(3)) + probs_w = pyro.sample( + "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_x = pyro.sample( + "probs_x", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] w, x = 0, 0 for t in pyro.markov(range(max_length if args.jit else lengths.max())): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): - w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]), - infer={"enumerate": "parallel"}) - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + w = pyro.sample( + "w_{}".format(t), + dist.Categorical(probs_w[w]), + infer={"enumerate": "parallel"}, + ) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate as tones: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[w, x, tones]), + obs=sequences[batch, t], + ) # By adding a dependency of x on w, we 
generalize to a @@ -351,17 +385,19 @@ def model_4(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.max() <= max_length hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x with handlers.mask(mask=include_prior): - probs_w = pyro.sample("probs_w", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .expand_by([hidden_dim]) - .to_event(2)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([hidden_dim, hidden_dim, data_dim]) - .to_event(3)) + probs_w = pyro.sample( + "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) + .expand_by([hidden_dim]) + .to_event(2), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] @@ -371,14 +407,22 @@ def model_4(sequences, lengths, args, batch_size=None, include_prior=True): w = x = torch.tensor(0, dtype=torch.long) for t in pyro.markov(range(max_length if args.jit else lengths.max())): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): - w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]), - infer={"enumerate": "parallel"}) - x = pyro.sample("x_{}".format(t), - dist.Categorical(Vindex(probs_x)[w, x]), - infer={"enumerate": "parallel"}) + w = pyro.sample( + "w_{}".format(t), + dist.Categorical(probs_w[w]), + infer={"enumerate": "parallel"}, + ) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(Vindex(probs_x)[w, x]), + infer={"enumerate": "parallel"}, + ) with tones_plate as tones: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[w, x, tones]), + obs=sequences[batch, t], + ) # Next let's consider a neural HMM model. @@ -405,8 +449,12 @@ def forward(self, x, y): # a bernoulli variable y. Whereas x will typically be enumerated, y will be observed. # We apply x_to_hidden independently from y_to_hidden, then broadcast the non-enumerated # y part up to the enumerated x part in the + operation. 
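A standalone sketch (illustrative only; the shapes and the hidden size of 32 are made up) of the one-hot scatter and the broadcast described in the comment above, which the reindented lines below implement:

import torch

hidden_dim, data_dim, batch_size = 16, 51, 8
x = torch.randint(hidden_dim, (hidden_dim, batch_size, 1))  # enumerated hidden state
y = torch.rand(batch_size, data_dim)                        # observed tones, no enum dim

# one-hot encode x along a new trailing hidden_dim axis
x_onehot = y.new_zeros(x.shape[:-1] + (hidden_dim,)).scatter_(-1, x, 1)
assert x_onehot.shape == (hidden_dim, batch_size, hidden_dim)

h_x = torch.rand(hidden_dim, batch_size, 32)  # stand-in for x_to_hidden(x_onehot)
h_y = torch.rand(batch_size, 32)              # stand-in for y_to_hidden(y_conv)
h = h_x + h_y                                 # y broadcasts up to the enumerated x
assert h.shape == (hidden_dim, batch_size, 32)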
- x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_(-1, x, 1) - y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape(y.shape[:-1] + (-1,)) + x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_( + -1, x, 1 + ) + y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape( + y.shape[:-1] + (-1,) + ) h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv)) return self.hidden_to_logits(h) @@ -431,23 +479,29 @@ def model_5(sequences, lengths, args, batch_size=None, include_prior=True): pyro.module("tones_generator", tones_generator) with handlers.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] x = 0 y = torch.zeros(data_dim) for t in pyro.markov(range(max_length if args.jit else lengths.max())): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) # Note that since each tone depends on all tones at a previous time step # the tones at different time steps now need to live in separate plates. with pyro.plate("tones_{}".format(t), data_dim, dim=-1): - y = pyro.sample("y_{}".format(t), - dist.Bernoulli(logits=tones_generator(x, y)), - obs=sequences[batch, t]) + y = pyro.sample( + "y_{}".format(t), + dist.Bernoulli(logits=tones_generator(x, y)), + obs=sequences[batch, t], + ) # Next let's consider a second-order HMM model @@ -474,24 +528,38 @@ def model_6(sequences, lengths, args, batch_size=None, include_prior=False): if not args.raftery_parameterization: # Explicitly parameterize the full tensor of transition probabilities, which # has hidden_dim cubed entries. - probs_x = pyro.param("probs_x", torch.rand(hidden_dim, hidden_dim, hidden_dim), - constraint=constraints.simplex) + probs_x = pyro.param( + "probs_x", + torch.rand(hidden_dim, hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) else: # Use the more parsimonious "Raftery" parameterization of # the tensor of transition probabilities. See reference: # Raftery, A. E. A model for high-order markov chains. # Journal of the Royal Statistical Society. 1985. 
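As a worked illustration of the Raftery-style mixture built in the reformatted lines that follow (standalone sketch; hidden_dim=4 and the softmax initialization are arbitrary), a (hidden_dim, hidden_dim) tensor and a (hidden_dim, 1, hidden_dim) tensor broadcast to the full (hidden_dim, hidden_dim, hidden_dim) transition tensor:

import torch

hidden_dim = 4
probs_x1 = torch.softmax(torch.rand(hidden_dim, hidden_dim), dim=-1)
probs_x2 = torch.softmax(torch.rand(hidden_dim, hidden_dim), dim=-1)
mix_lambda = 0.5

probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)
assert probs_x.shape == (hidden_dim, hidden_dim, hidden_dim)
# every probs_x[i, j] is still a normalized transition distribution
assert torch.allclose(probs_x.sum(-1), torch.ones(hidden_dim, hidden_dim))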
- probs_x1 = pyro.param("probs_x1", torch.rand(hidden_dim, hidden_dim), - constraint=constraints.simplex) - probs_x2 = pyro.param("probs_x2", torch.rand(hidden_dim, hidden_dim), - constraint=constraints.simplex) - mix_lambda = pyro.param("mix_lambda", torch.tensor(0.5), constraint=constraints.unit_interval) + probs_x1 = pyro.param( + "probs_x1", + torch.rand(hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) + probs_x2 = pyro.param( + "probs_x2", + torch.rand(hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) + mix_lambda = pyro.param( + "mix_lambda", torch.tensor(0.5), constraint=constraints.unit_interval + ) # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim) probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2) - probs_y = pyro.param("probs_y", torch.rand(hidden_dim, data_dim), - constraint=constraints.unit_interval) + probs_y = pyro.param( + "probs_y", + torch.rand(hidden_dim, data_dim), + constraint=constraints.unit_interval, + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] @@ -501,12 +569,18 @@ def model_6(sequences, lengths, args, batch_size=None, include_prior=False): for t in pyro.markov(range(lengths.max()), history=2): with handlers.mask(mask=(t < lengths).unsqueeze(-1)): probs_x_t = Vindex(probs_x)[x_prev, x_curr] - x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t), dist.Categorical(probs_x_t), - infer={"enumerate": "parallel"}) + x_prev, x_curr = x_curr, pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x_t), + infer={"enumerate": "parallel"}, + ) with tones_plate: probs_y_t = probs_y[x_curr.squeeze(-1)] - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y_t), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y_t), + obs=sequences[batch, t], + ) # Let's go back to our initial model and make it even faster: we'll support @@ -519,13 +593,14 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.shape == (num_sequences,) assert lengths.max() <= max_length with handlers.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, data_dim]) - .to_event(2)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) # Note that since we're using dim=-2 for the time dimension, we need # to batch sequences over a different dimension, here dim=-3. @@ -536,13 +611,23 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True): # To vectorize time dimension we use pyro.vectorized_markov(name=...). # With the help of Vindex and additional unsqueezes we can ensure that # dimensions line up properly. 
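The Vindex(sequences)[batch, t] indexing used in the loop below can be previewed with a standalone sketch (illustrative only; all sizes are made up): the index tensors broadcast against each other and the trailing data dimension is kept.

import torch
from pyro.ops.indexing import Vindex

num_sequences, max_length, data_dim = 5, 7, 3
sequences = torch.rand(num_sequences, max_length, data_dim)
batch = torch.tensor([0, 2, 4]).unsqueeze(-1)  # (batch_size, 1), lives on plate dim -2
t = torch.arange(max_length)                   # vectorized time indices

obs = Vindex(sequences)[batch, t]
assert obs.shape == (3, max_length, data_dim)  # (batch_size, time, tones)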
- for t in pyro.vectorized_markov(name="time", size=int(max_length if args.jit else lengths.max()), dim=-2): + for t in pyro.vectorized_markov( + name="time", size=int(max_length if args.jit else lengths.max()), dim=-2 + ): with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)): - x_curr = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x_prev]), - infer={"enumerate": "parallel"}) + x_curr = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x_prev]), + infer={"enumerate": "parallel"}, + ) with tones_plate: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x_curr.squeeze(-1)]), - obs=Vindex(sequences)[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[x_curr.squeeze(-1)]), + obs=Vindex(sequences)[batch, t], + ) + + # Let's see how vectorizing time dimension changes the shapes of sample sites: # $ python examples/hmm.py -m 7 --funsor -n 1 --batch-size=10 --print-shapes # ... @@ -574,33 +659,38 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True): # finally vectorized t_curr (torch.arange(1, 72)). -models = {name[len('model_'):]: model - for name, model in globals().items() - if name.startswith('model_')} +models = { + name[len("model_") :]: model + for name, model in globals().items() + if name.startswith("model_") +} def main(args): if args.cuda: - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") - logging.info('Loading data') + logging.info("Loading data") data = poly.load_data(poly.JSB_CHORALES) - logging.info('-' * 40) + logging.info("-" * 40) model = models[args.model] - logging.info('Training {} on {} sequences'.format( - model.__name__, len(data['train']['sequences']))) - sequences = data['train']['sequences'] - lengths = data['train']['sequence_lengths'] + logging.info( + "Training {} on {} sequences".format( + model.__name__, len(data["train"]["sequences"]) + ) + ) + sequences = data["train"]["sequences"] + lengths = data["train"]["sequence_lengths"] # find all the notes that are present at least once in the training set - present_notes = ((sequences == 1).sum(0).sum(0) > 0) + present_notes = (sequences == 1).sum(0).sum(0) > 0 # remove notes that are never played (we remove 37/88 notes) sequences = sequences[..., present_notes] if args.truncate: lengths = lengths.clamp(max=args.truncate) - sequences = sequences[:, :args.truncate] + sequences = sequences[:, : args.truncate] num_observations = float(lengths.sum()) pyro.set_rng_seed(args.seed) pyro.clear_param_store() @@ -609,7 +699,9 @@ def main(args): # out the hidden state x. This is accomplished via an automatic guide that # learns point estimates of all of our conditional probability tables, # named probs_*. - guide = AutoDelta(handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))) + guide = AutoDelta( + handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")) + ) # To help debug our tensor shapes, let's print the shape of each site's # distribution, value, and log_prob tensor. 
Note this information is @@ -622,10 +714,11 @@ def main(args): else: first_available_dim = -3 guide_trace = handlers.trace(guide).get_trace( - sequences, lengths, args=args, batch_size=args.batch_size) + sequences, lengths, args=args, batch_size=args.batch_size + ) model_trace = handlers.trace( - handlers.replay(handlers.enum(model, first_available_dim), guide_trace)).get_trace( - sequences, lengths, args=args, batch_size=args.batch_size) + handlers.replay(handlers.enum(model, first_available_dim), guide_trace) + ).get_trace(sequences, lengths, args=args, batch_size=args.batch_size) logging.info(model_trace.format_shapes()) # Bind non-PyTorch parameters to make these functions jittable. @@ -634,7 +727,7 @@ def main(args): # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting. # All of our models have two plates: "data" and "tones". - optimizer = optim.Adam({'lr': args.learning_rate}) + optimizer = optim.Adam({"lr": args.learning_rate}) if args.tmc: if args.jit and not args.funsor: raise NotImplementedError("jit support not yet added for TraceTMC_ELBO") @@ -642,12 +735,19 @@ def main(args): elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2) tmc_model = handlers.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} if msg["infer"].get("enumerate", None) == "parallel" else {}) # noqa: E501 + lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {}, + ) # noqa: E501 svi = infer.SVI(tmc_model, guide, optimizer, elbo) else: if args.model == "7": assert args.funsor - Elbo = infer.JitTraceMarkovEnum_ELBO if args.jit else infer.TraceMarkovEnum_ELBO + Elbo = ( + infer.JitTraceMarkovEnum_ELBO + if args.jit + else infer.TraceMarkovEnum_ELBO + ) else: Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO if args.model == "0": @@ -656,30 +756,43 @@ def main(args): max_plate_nesting = 3 else: max_plate_nesting = 2 - elbo = Elbo(max_plate_nesting=max_plate_nesting, - strict_enumeration_warning=True, - jit_options={"time_compilation": args.time_compilation}) + elbo = Elbo( + max_plate_nesting=max_plate_nesting, + strict_enumeration_warning=True, + jit_options={"time_compilation": args.time_compilation}, + ) svi = infer.SVI(model, guide, optimizer, elbo) # We'll train on small minibatches. - logging.info('Step\tLoss') + logging.info("Step\tLoss") for step in range(args.num_steps): loss = svi.step(sequences, lengths, batch_size=args.batch_size) - logging.info('{: >5d}\t{}'.format(step, loss / num_observations)) + logging.info("{: >5d}\t{}".format(step, loss / num_observations)) if args.jit and args.time_compilation: - logging.debug('time to compile: {} s.'.format(elbo._differentiable_loss.compile_time)) + logging.debug( + "time to compile: {} s.".format(elbo._differentiable_loss.compile_time) + ) # We evaluate on the entire training dataset, # excluding the prior term so our results are comparable across models. - train_loss = elbo.loss(model, guide, sequences, lengths, batch_size=sequences.shape[0], include_prior=False) - logging.info('training loss = {}'.format(train_loss / num_observations)) + train_loss = elbo.loss( + model, + guide, + sequences, + lengths, + batch_size=sequences.shape[0], + include_prior=False, + ) + logging.info("training loss = {}".format(train_loss / num_observations)) # Finally we evaluate on the test dataset. 
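The expose_fn pattern used to build the guide earlier in this main() hunk can be reproduced on a toy model as follows (standalone sketch; the toy model, its site names, and sizes are made up for illustration):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta

def toy_model():
    probs = pyro.sample("probs_x", dist.Dirichlet(torch.ones(3)))
    with pyro.plate("data", 5):
        pyro.sample("x", dist.Categorical(probs), infer={"enumerate": "parallel"})

# Only sites whose names start with "probs_" receive MAP (Delta) parameters;
# the discrete site "x" is left to be handled by enumeration in the ELBO.
guide = AutoDelta(
    poutine.block(toy_model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)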
- logging.info('-' * 40) - logging.info('Evaluating on {} test sequences'.format(len(data['test']['sequences']))) - sequences = data['test']['sequences'][..., present_notes] - lengths = data['test']['sequence_lengths'] + logging.info("-" * 40) + logging.info( + "Evaluating on {} test sequences".format(len(data["test"]["sequences"])) + ) + sequences = data["test"]["sequences"][..., present_notes] + lengths = data["test"]["sequence_lengths"] if args.truncate: lengths = lengths.clamp(max=args.truncate) num_observations = float(lengths.sum()) @@ -687,21 +800,36 @@ def main(args): # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. - test_loss = elbo.loss(model, guide, sequences, lengths, batch_size=sequences.shape[0], include_prior=False) - logging.info('test loss = {}'.format(test_loss / num_observations)) + test_loss = elbo.loss( + model, + guide, + sequences, + lengths, + batch_size=sequences.shape[0], + include_prior=False, + ) + logging.info("test loss = {}".format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. - capacity = sum(value.reshape(-1).size(0) - for value in pyro.get_param_store().values()) - logging.info('model_{} capacity = {} parameters'.format(args.model, capacity)) - - -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales") - parser.add_argument("-m", "--model", default="1", type=str, - help="one of: {}".format(", ".join(sorted(models.keys())))) + capacity = sum( + value.reshape(-1).size(0) for value in pyro.get_param_store().values() + ) + logging.info("model_{} capacity = {} parameters".format(args.model, capacity)) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="MAP Baum-Welch learning Bach Chorales" + ) + parser.add_argument( + "-m", + "--model", + default="1", + type=str, + help="one of: {}".format(", ".join(sorted(models.keys()))), + ) parser.add_argument("-n", "--num-steps", default=50, type=int) parser.add_argument("-b", "--batch-size", default=8, type=int) parser.add_argument("-d", "--hidden-dim", default=16, type=int) @@ -711,21 +839,25 @@ def main(args): parser.add_argument("-t", "--truncate", type=int) parser.add_argument("-p", "--print-shapes", action="store_true") parser.add_argument("--seed", default=0, type=int) - parser.add_argument('--cuda', action='store_true') - parser.add_argument('--jit', action='store_true') - parser.add_argument('--time-compilation', action='store_true') - parser.add_argument('-rp', '--raftery-parameterization', action='store_true') - parser.add_argument('--tmc', action='store_true', - help="Use Tensor Monte Carlo instead of exact enumeration " - "to estimate the marginal likelihood. 
You probably don't want to do this, " - "except to see that TMC makes Monte Carlo gradient estimation feasible " - "even with very large numbers of non-reparametrized variables.") - parser.add_argument('--tmc-num-samples', default=10, type=int) - parser.add_argument('--funsor', action='store_true') + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true") + parser.add_argument("--time-compilation", action="store_true") + parser.add_argument("-rp", "--raftery-parameterization", action="store_true") + parser.add_argument( + "--tmc", + action="store_true", + help="Use Tensor Monte Carlo instead of exact enumeration " + "to estimate the marginal likelihood. You probably don't want to do this, " + "except to see that TMC makes Monte Carlo gradient estimation feasible " + "even with very large numbers of non-reparametrized variables.", + ) + parser.add_argument("--tmc-num-samples", default=10, type=int) + parser.add_argument("--funsor", action="store_true") args = parser.parse_args() if args.funsor: import funsor + funsor.set_backend("torch") PYRO_BACKEND = "contrib.funsor" else: diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index a2e8672ca9..c27f7dbbde 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -73,9 +73,15 @@ def train(args, train_loader, gpmodule, optimizer, loss_fn, epoch): optimizer.step() batch_idx = batch_idx + 1 if batch_idx % args.log_interval == 0: - print("Train Epoch: {:2d} [{:5d}/{} ({:2.0f}%)]\tLoss: {:.6f}" - .format(epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss)) + print( + "Train Epoch: {:2d} [{:5d}/{} ({:2.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss, + ) + ) def test(args, test_loader, gpmodule): @@ -93,24 +99,35 @@ def test(args, test_loader, gpmodule): # compare prediction and target to count accuracy correct += pred.eq(target).long().cpu().sum().item() - print("\nTest set: Accuracy: {}/{} ({:.2f}%)\n" - .format(correct, len(test_loader.dataset), 100. 
* correct / len(test_loader.dataset))) + print( + "\nTest set: Accuracy: {}/{} ({:.2f}%)\n".format( + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) def main(args): - data_dir = args.data_dir if args.data_dir is not None else get_data_directory(__file__) - train_loader = get_data_loader(dataset_name='MNIST', - data_dir=data_dir, - batch_size=args.batch_size, - dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], - is_training_set=True, - shuffle=True) - test_loader = get_data_loader(dataset_name='MNIST', - data_dir=data_dir, - batch_size=args.test_batch_size, - dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], - is_training_set=False, - shuffle=False) + data_dir = ( + args.data_dir if args.data_dir is not None else get_data_directory(__file__) + ) + train_loader = get_data_loader( + dataset_name="MNIST", + data_dir=data_dir, + batch_size=args.batch_size, + dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], + is_training_set=True, + shuffle=True, + ) + test_loader = get_data_loader( + dataset_name="MNIST", + data_dir=data_dir, + batch_size=args.test_batch_size, + dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], + is_training_set=False, + shuffle=False, + ) if args.cuda: train_loader.num_workers = 1 test_loader.num_workers = 1 @@ -130,7 +147,7 @@ def main(args): batches.append(data) if i >= ((args.num_inducing - 1) // args.batch_size): break - Xu = torch.cat(batches)[:args.num_inducing].clone() + Xu = torch.cat(batches)[: args.num_inducing].clone() if args.binary: likelihood = gp.likelihoods.Binary() @@ -144,9 +161,17 @@ def main(args): latent_shape = torch.Size([10]) # Turns on "whiten" flag will help optimization for variational models. - gpmodule = gp.models.VariationalSparseGP(X=Xu, y=None, kernel=deep_kernel, Xu=Xu, - likelihood=likelihood, latent_shape=latent_shape, - num_data=60000, whiten=True, jitter=2e-6) + gpmodule = gp.models.VariationalSparseGP( + X=Xu, + y=None, + kernel=deep_kernel, + Xu=Xu, + likelihood=likelihood, + latent_shape=latent_shape, + num_data=60000, + whiten=True, + jitter=2e-6, + ) if args.cuda: gpmodule.cuda() @@ -160,35 +185,77 @@ def main(args): train(args, train_loader, gpmodule, optimizer, loss_fn, epoch) with torch.no_grad(): test(args, test_loader, gpmodule) - print("Amount of time spent for epoch {}: {}s\n" - .format(epoch, int(time.time() - start_time))) - - -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') - parser.add_argument('--data-dir', type=str, default=None, metavar='PATH', - help='default directory to cache MNIST data') - parser.add_argument('--num-inducing', type=int, default=70, metavar='N', - help='number of inducing input (default: 70)') - parser.add_argument('--binary', action='store_true', default=False, - help='do binary classification') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--cuda', action='store_true', default=False, - help='enables CUDA training') - 
parser.add_argument('--jit', action='store_true', default=False, - help='enables PyTorch jit') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') + print( + "Amount of time spent for epoch {}: {}s\n".format( + epoch, int(time.time() - start_time) + ) + ) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser(description="Pyro GP MNIST Example") + parser.add_argument( + "--data-dir", + type=str, + default=None, + metavar="PATH", + help="default directory to cache MNIST data", + ) + parser.add_argument( + "--num-inducing", + type=int, + default=70, + metavar="N", + help="number of inducing input (default: 70)", + ) + parser.add_argument( + "--binary", action="store_true", default=False, help="do binary classification" + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--cuda", action="store_true", default=False, help="enables CUDA training" + ) + parser.add_argument( + "--jit", action="store_true", default=False, help="enables PyTorch jit" + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) args = parser.parse_args() pyro.set_rng_seed(args.seed) diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py index a122be7266..4bf8f24440 100644 --- a/examples/contrib/mue/FactorMuE.py +++ b/examples/contrib/mue/FactorMuE.py @@ -51,9 +51,10 @@ def generate_data(small_test, include_stop, device): else: mult_dat = 10 - seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, - device=device) + seqs = ["BABBA"] * mult_dat + ["BAAB"] * mult_dat + ["BABBB"] * mult_dat + dataset = BiosequenceDataset( + seqs, "list", "AB", include_stop=include_stop, device=device + ) return dataset @@ -62,28 +63,34 @@ def main(args): # Load dataset. if args.cpu_data and args.cuda: - device = torch.device('cpu') + device = torch.device("cpu") else: device = None if args.test: dataset = generate_data(args.small, args.include_stop, device) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, - include_stop=args.include_stop, - device=device) + dataset = BiosequenceDataset( + args.file, + "fasta", + args.alphabet, + include_stop=args.include_stop, + device=device, + ) args.batch_size = min([dataset.data_size, args.batch_size]) - if args.split > 0.: + if args.split > 0.0: # Train test split. 
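The train/test split performed in the lines that follow can be read as this standalone sketch (illustrative only; it uses itertools.accumulate in place of the private torch._utils._accumulate helper and a toy dataset of ten rows):

import itertools
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(-1))
heldout_num = 3
data_lengths = [len(dataset) - heldout_num, heldout_num]

indices = torch.randperm(sum(data_lengths)).tolist()
dataset_train, dataset_test = [
    Subset(dataset, indices[offset - length : offset])
    for offset, length in zip(itertools.accumulate(data_lengths), data_lengths)
]
assert len(dataset_train) == 7 and len(dataset_test) == 3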
- heldout_num = int(np.ceil(args.split*len(dataset))) + heldout_num = int(np.ceil(args.split * len(dataset))) data_lengths = [len(dataset) - heldout_num, heldout_num] # Specific data split seed, for comparability across models and # parameter initializations. pyro.set_rng_seed(args.rng_data_seed) indices = torch.randperm(sum(data_lengths)).tolist() dataset_train, dataset_test = [ - torch.utils.data.Subset(dataset, indices[(offset - length):offset]) - for offset, length in zip(torch._utils._accumulate(data_lengths), - data_lengths)] + torch.utils.data.Subset(dataset, indices[(offset - length) : offset]) + for offset, length in zip( + torch._utils._accumulate(data_lengths), data_lengths + ) + ] else: dataset_train = dataset dataset_test = None @@ -92,38 +99,47 @@ def main(args): pyro.set_rng_seed(args.rng_seed) # Construct model. - model = FactorMuE(dataset.max_length, dataset.alphabet_length, - args.z_dim, - batch_size=args.batch_size, - latent_seq_length=args.latent_seq_length, - indel_factor_dependence=args.indel_factor, - indel_prior_scale=args.indel_prior_scale, - indel_prior_bias=args.indel_prior_bias, - inverse_temp_prior=args.inverse_temp_prior, - weights_prior_scale=args.weights_prior_scale, - offset_prior_scale=args.offset_prior_scale, - z_prior_distribution=args.z_prior, - ARD_prior=args.ARD_prior, - substitution_matrix=(not args.no_substitution_matrix), - substitution_prior_scale=args.substitution_prior_scale, - latent_alphabet_length=args.latent_alphabet, - cuda=args.cuda, - pin_memory=args.pin_mem) + model = FactorMuE( + dataset.max_length, + dataset.alphabet_length, + args.z_dim, + batch_size=args.batch_size, + latent_seq_length=args.latent_seq_length, + indel_factor_dependence=args.indel_factor, + indel_prior_scale=args.indel_prior_scale, + indel_prior_bias=args.indel_prior_bias, + inverse_temp_prior=args.inverse_temp_prior, + weights_prior_scale=args.weights_prior_scale, + offset_prior_scale=args.offset_prior_scale, + z_prior_distribution=args.z_prior, + ARD_prior=args.ARD_prior, + substitution_matrix=(not args.no_substitution_matrix), + substitution_prior_scale=args.substitution_prior_scale, + latent_alphabet_length=args.latent_alphabet, + cuda=args.cuda, + pin_memory=args.pin_mem, + ) # Infer with SVI. - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': args.learning_rate}, - 'milestones': json.loads(args.milestones), - 'gamma': args.learning_gamma}) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": args.learning_rate}, + "milestones": json.loads(args.milestones), + "gamma": args.learning_gamma, + } + ) n_epochs = args.n_epochs - losses = model.fit_svi(dataset_train, n_epochs, args.anneal, - args.batch_size, scheduler, args.jit) + losses = model.fit_svi( + dataset_train, n_epochs, args.anneal, args.batch_size, scheduler, args.jit + ) # Evaluate. train_lp, test_lp, train_perplex, test_perplex = model.evaluate( - dataset_train, dataset_test, args.jit) - print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) - print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) + dataset_train, dataset_test, args.jit + ) + print("train logp: {} perplex: {}".format(train_lp, train_perplex)) + print("test logp: {} perplex: {}".format(test_lp, test_perplex)) # Get latent space embedding. 
z_locs, z_scales = model.embed(dataset) @@ -133,21 +149,25 @@ def main(args): if not args.no_plots: plt.figure(figsize=(6, 6)) plt.plot(losses) - plt.xlabel('step') - plt.ylabel('loss') + plt.xlabel("step") + plt.ylabel("loss") if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, "FactorMuE_plot.loss_{}.pdf".format(time_stamp) + ) + ) plt.figure(figsize=(6, 6)) plt.scatter(z_locs[:, 0], z_locs[:, 1]) - plt.xlabel(r'$z_1$') - plt.ylabel(r'$z_2$') + plt.xlabel(r"$z_1$") + plt.ylabel(r"$z_2$") if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, "FactorMuE_plot.latent_{}.pdf".format(time_stamp) + ) + ) if not args.indel_factor: # Plot indel parameters. See statearrangers.py for details on the @@ -156,128 +176,246 @@ def main(args): insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) plt.plot(insert_expect[:, :, 1].cpu().numpy()) - plt.xlabel('position') - plt.ylabel('probability of insert') - plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) + plt.xlabel("position") + plt.ylabel("probability of insert") + plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"]) if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, + "FactorMuE_plot.insert_prob_{}.pdf".format(time_stamp), + ) + ) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].cpu().numpy()) - plt.xlabel('position') - plt.ylabel('probability of delete') - plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) + plt.xlabel("position") + plt.ylabel("probability of delete") + plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"]) if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, + "FactorMuE_plot.delete_prob_{}.pdf".format(time_stamp), + ) + ) if not args.no_save: - pyro.get_param_store().save(os.path.join( - args.out_folder, - 'FactorMuE_results.params_{}.out'.format(time_stamp))) - with open(os.path.join( + pyro.get_param_store().save( + os.path.join( + args.out_folder, "FactorMuE_results.params_{}.out".format(time_stamp) + ) + ) + with open( + os.path.join( args.out_folder, - 'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)), - 'w') as ow: - ow.write('train_lp,test_lp,train_perplex,test_perplex\n') - ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, - test_perplex)) - np.savetxt(os.path.join( - args.out_folder, - 'FactorMuE_results.embed_loc_{}.txt'.format( - time_stamp)), - z_locs.cpu().numpy()) - np.savetxt(os.path.join( - args.out_folder, - 'FactorMuE_results.embed_scale_{}.txt'.format( - time_stamp)), - z_scales.cpu().numpy()) - with open(os.path.join( + "FactorMuE_results.evaluation_{}.txt".format(time_stamp), + ), + "w", + ) as ow: + ow.write("train_lp,test_lp,train_perplex,test_perplex\n") + ow.write( + "{},{},{},{}\n".format(train_lp, test_lp, train_perplex, test_perplex) + ) + np.savetxt( + os.path.join( + args.out_folder, "FactorMuE_results.embed_loc_{}.txt".format(time_stamp) + ), + z_locs.cpu().numpy(), + ) + np.savetxt( + os.path.join( args.out_folder, - 
'FactorMuE_results.input_{}.txt'.format(time_stamp)), - 'w') as ow: - ow.write('[args]\n') + "FactorMuE_results.embed_scale_{}.txt".format(time_stamp), + ), + z_scales.cpu().numpy(), + ) + with open( + os.path.join( + args.out_folder, "FactorMuE_results.input_{}.txt".format(time_stamp) + ), + "w", + ) as ow: + ow.write("[args]\n") for elem in list(args.__dict__.keys()): - ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) + ow.write("{} = {}\n".format(elem, args.__getattribute__(elem))) -if __name__ == '__main__': +if __name__ == "__main__": # Parse command line arguments. parser = argparse.ArgumentParser(description="Factor MuE model.") - parser.add_argument("--test", action='store_true', default=False, - help='Run with generated example dataset.') - parser.add_argument("--small", action='store_true', default=False, - help='Run with small example dataset.') + parser.add_argument( + "--test", + action="store_true", + default=False, + help="Run with generated example dataset.", + ) + parser.add_argument( + "--small", + action="store_true", + default=False, + help="Run with small example dataset.", + ) parser.add_argument("-r", "--rng-seed", default=0, type=int) parser.add_argument("--rng-data-seed", default=0, type=int) - parser.add_argument("-f", "--file", default=None, type=str, - help='Input file (fasta format).') - parser.add_argument("-a", "--alphabet", default='amino-acid', - help='Alphabet (amino-acid OR dna OR ATGC ...).') - parser.add_argument("-zdim", "--z-dim", default=2, type=int, - help='z space dimension.') - parser.add_argument("-b", "--batch-size", default=10, type=int, - help='Batch size.') - parser.add_argument("-M", "--latent-seq-length", default=None, type=int, - help='Latent sequence length.') - parser.add_argument("-idfac", "--indel-factor", default=False, - action='store_true', - help='Indel parameters depend on latent variable.') - parser.add_argument("-zdist", "--z-prior", default='Normal', - help='Latent prior distribution (normal or Laplace).') - parser.add_argument("-ard", "--ARD-prior", default=False, - action='store_true', - help='Use automatic relevance detection prior.') - parser.add_argument("--no-substitution-matrix", default=False, - action='store_true', - help='Do not use substitution matrix.') - parser.add_argument("-D", "--latent-alphabet", default=None, type=int, - help='Latent alphabet length.') - parser.add_argument("--include-stop", default=False, action='store_true', - help='Include stop symbol at the end of each sequence.') - parser.add_argument("--indel-prior-scale", default=1., type=float, - help=('Indel prior scale parameter ' + - '(when indel-factor=False).')) - parser.add_argument("--indel-prior-bias", default=10., type=float, - help='Indel prior bias parameter.') - parser.add_argument("--inverse-temp-prior", default=100., type=float, - help='Inverse temperature prior mean.') - parser.add_argument("--weights-prior-scale", default=1., type=float, - help='Factor parameter prior scale.') - parser.add_argument("--offset-prior-scale", default=1., type=float, - help='Offset parameter prior scale.') - parser.add_argument("--substitution-prior-scale", default=10., type=float, - help='Substitution matrix prior scale.') - parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, - help='Learning rate for Adam optimizer.') - parser.add_argument("--milestones", default='[]', type=str, - help='Milestones for multistage learning rate.') - parser.add_argument("--learning-gamma", default=0.5, type=float, - help='Gamma parameter 
for multistage learning rate.') - parser.add_argument("-e", "--n-epochs", default=10, type=int, - help='Number of epochs of training.') - parser.add_argument("--anneal", default=0., type=float, - help='Number of epochs to anneal beta over.') - parser.add_argument("--no-plots", default=False, action='store_true', - help='Make plots.') - parser.add_argument("--no-save", default=False, action='store_true', - help='Do not save plots and results.') - parser.add_argument("-outf", "--out-folder", default='.', - help='Folder to save plots.') - parser.add_argument("--split", default=0.2, type=float, - help=('Fraction of dataset to holdout for testing')) - parser.add_argument("--jit", default=False, action='store_true', - help='JIT compile the ELBO.') - parser.add_argument("--cuda", default=False, action='store_true', - help='Use GPU.') - parser.add_argument("--cpu-data", default=False, action='store_true', - help='Keep data on CPU (for large datasets).') - parser.add_argument("--pin-mem", default=False, action='store_true', - help='Use pin_memory for faster CPU to GPU transfer.') + parser.add_argument( + "-f", "--file", default=None, type=str, help="Input file (fasta format)." + ) + parser.add_argument( + "-a", + "--alphabet", + default="amino-acid", + help="Alphabet (amino-acid OR dna OR ATGC ...).", + ) + parser.add_argument( + "-zdim", "--z-dim", default=2, type=int, help="z space dimension." + ) + parser.add_argument("-b", "--batch-size", default=10, type=int, help="Batch size.") + parser.add_argument( + "-M", + "--latent-seq-length", + default=None, + type=int, + help="Latent sequence length.", + ) + parser.add_argument( + "-idfac", + "--indel-factor", + default=False, + action="store_true", + help="Indel parameters depend on latent variable.", + ) + parser.add_argument( + "-zdist", + "--z-prior", + default="Normal", + help="Latent prior distribution (normal or Laplace).", + ) + parser.add_argument( + "-ard", + "--ARD-prior", + default=False, + action="store_true", + help="Use automatic relevance detection prior.", + ) + parser.add_argument( + "--no-substitution-matrix", + default=False, + action="store_true", + help="Do not use substitution matrix.", + ) + parser.add_argument( + "-D", + "--latent-alphabet", + default=None, + type=int, + help="Latent alphabet length.", + ) + parser.add_argument( + "--include-stop", + default=False, + action="store_true", + help="Include stop symbol at the end of each sequence.", + ) + parser.add_argument( + "--indel-prior-scale", + default=1.0, + type=float, + help=("Indel prior scale parameter " + "(when indel-factor=False)."), + ) + parser.add_argument( + "--indel-prior-bias", + default=10.0, + type=float, + help="Indel prior bias parameter.", + ) + parser.add_argument( + "--inverse-temp-prior", + default=100.0, + type=float, + help="Inverse temperature prior mean.", + ) + parser.add_argument( + "--weights-prior-scale", + default=1.0, + type=float, + help="Factor parameter prior scale.", + ) + parser.add_argument( + "--offset-prior-scale", + default=1.0, + type=float, + help="Offset parameter prior scale.", + ) + parser.add_argument( + "--substitution-prior-scale", + default=10.0, + type=float, + help="Substitution matrix prior scale.", + ) + parser.add_argument( + "-lr", + "--learning-rate", + default=0.001, + type=float, + help="Learning rate for Adam optimizer.", + ) + parser.add_argument( + "--milestones", + default="[]", + type=str, + help="Milestones for multistage learning rate.", + ) + parser.add_argument( + "--learning-gamma", + default=0.5, + 
type=float, + help="Gamma parameter for multistage learning rate.", + ) + parser.add_argument( + "-e", "--n-epochs", default=10, type=int, help="Number of epochs of training." + ) + parser.add_argument( + "--anneal", + default=0.0, + type=float, + help="Number of epochs to anneal beta over.", + ) + parser.add_argument( + "--no-plots", default=False, action="store_true", help="Make plots." + ) + parser.add_argument( + "--no-save", + default=False, + action="store_true", + help="Do not save plots and results.", + ) + parser.add_argument( + "-outf", "--out-folder", default=".", help="Folder to save plots." + ) + parser.add_argument( + "--split", + default=0.2, + type=float, + help=("Fraction of dataset to holdout for testing"), + ) + parser.add_argument( + "--jit", default=False, action="store_true", help="JIT compile the ELBO." + ) + parser.add_argument("--cuda", default=False, action="store_true", help="Use GPU.") + parser.add_argument( + "--cpu-data", + default=False, + action="store_true", + help="Keep data on CPU (for large datasets).", + ) + parser.add_argument( + "--pin-mem", + default=False, + action="store_true", + help="Use pin_memory for faster CPU to GPU transfer.", + ) args = parser.parse_args() if args.cuda: diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py index 61df67f039..ef1a4ef3ad 100644 --- a/examples/contrib/mue/ProfileHMM.py +++ b/examples/contrib/mue/ProfileHMM.py @@ -55,9 +55,10 @@ def generate_data(small_test, include_stop, device): else: mult_dat = 10 - seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat - dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, - device=device) + seqs = ["BABBA"] * mult_dat + ["BAAB"] * mult_dat + ["BABBB"] * mult_dat + dataset = BiosequenceDataset( + seqs, "list", "AB", include_stop=include_stop, device=device + ) return dataset @@ -68,28 +69,34 @@ def main(args): # Load dataset. if args.cpu_data and args.cuda: - device = torch.device('cpu') + device = torch.device("cpu") else: device = None if args.test: dataset = generate_data(args.small, args.include_stop, device) else: - dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, - include_stop=args.include_stop, - device=device) + dataset = BiosequenceDataset( + args.file, + "fasta", + args.alphabet, + include_stop=args.include_stop, + device=device, + ) args.batch_size = min([dataset.data_size, args.batch_size]) - if args.split > 0.: + if args.split > 0.0: # Train test split. - heldout_num = int(np.ceil(args.split*len(dataset))) + heldout_num = int(np.ceil(args.split * len(dataset))) data_lengths = [len(dataset) - heldout_num, heldout_num] # Specific data split seed, for comparability across models and # parameter initializations. 
pyro.set_rng_seed(args.rng_data_seed) indices = torch.randperm(sum(data_lengths)).tolist() dataset_train, dataset_test = [ - torch.utils.data.Subset(dataset, indices[(offset - length):offset]) - for offset, length in zip(torch._utils._accumulate(data_lengths), - data_lengths)] + torch.utils.data.Subset(dataset, indices[(offset - length) : offset]) + for offset, length in zip( + torch._utils._accumulate(data_lengths), data_lengths + ) + ] else: dataset_train = dataset dataset_test = None @@ -98,129 +105,213 @@ def main(args): latent_seq_length = args.latent_seq_length if latent_seq_length is None: latent_seq_length = int(dataset.max_length * 1.1) - model = ProfileHMM(latent_seq_length, dataset.alphabet_length, - prior_scale=args.prior_scale, - indel_prior_bias=args.indel_prior_bias, - cuda=args.cuda, - pin_memory=args.pin_mem) + model = ProfileHMM( + latent_seq_length, + dataset.alphabet_length, + prior_scale=args.prior_scale, + indel_prior_bias=args.indel_prior_bias, + cuda=args.cuda, + pin_memory=args.pin_mem, + ) # Infer with SVI. - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': args.learning_rate}, - 'milestones': json.loads(args.milestones), - 'gamma': args.learning_gamma}) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": args.learning_rate}, + "milestones": json.loads(args.milestones), + "gamma": args.learning_gamma, + } + ) n_epochs = args.n_epochs - losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler, - args.jit) + losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler, args.jit) # Evaluate. train_lp, test_lp, train_perplex, test_perplex = model.evaluate( - dataset_train, dataset_test, args.jit) - print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) - print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) + dataset_train, dataset_test, args.jit + ) + print("train logp: {} perplex: {}".format(train_lp, train_perplex)) + print("test logp: {} perplex: {}".format(test_lp, test_perplex)) # Plots. 
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") if not args.no_plots: plt.figure(figsize=(6, 6)) plt.plot(losses) - plt.xlabel('step') - plt.ylabel('loss') + plt.xlabel("step") + plt.ylabel("loss") if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'ProfileHMM_plot.loss_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, "ProfileHMM_plot.loss_{}.pdf".format(time_stamp) + ) + ) plt.figure(figsize=(6, 6)) insert = pyro.param("insert_q_mn").detach() insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) plt.plot(insert_expect[:, :, 1].cpu().numpy()) - plt.xlabel('position') - plt.ylabel('probability of insert') - plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) + plt.xlabel("position") + plt.ylabel("probability of insert") + plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"]) if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'ProfileHMM_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, + "ProfileHMM_plot.insert_prob_{}.pdf".format(time_stamp), + ) + ) plt.figure(figsize=(6, 6)) delete = pyro.param("delete_q_mn").detach() delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) plt.plot(delete_expect[:, :, 1].cpu().numpy()) - plt.xlabel('position') - plt.ylabel('probability of delete') - plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) + plt.xlabel("position") + plt.ylabel("probability of delete") + plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"]) if not args.no_save: - plt.savefig(os.path.join( - args.out_folder, - 'ProfileHMM_plot.delete_prob_{}.pdf'.format(time_stamp))) + plt.savefig( + os.path.join( + args.out_folder, + "ProfileHMM_plot.delete_prob_{}.pdf".format(time_stamp), + ) + ) if not args.no_save: - pyro.get_param_store().save(os.path.join( + pyro.get_param_store().save( + os.path.join( + args.out_folder, "ProfileHMM_results.params_{}.out".format(time_stamp) + ) + ) + with open( + os.path.join( args.out_folder, - 'ProfileHMM_results.params_{}.out'.format(time_stamp))) - with open(os.path.join( - args.out_folder, - 'ProfileHMM_results.evaluation_{}.txt'.format(time_stamp)), - 'w') as ow: - ow.write('train_lp,test_lp,train_perplex,test_perplex\n') - ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, - test_perplex)) - with open(os.path.join( - args.out_folder, - 'ProfileHMM_results.input_{}.txt'.format(time_stamp)), - 'w') as ow: - ow.write('[args]\n') + "ProfileHMM_results.evaluation_{}.txt".format(time_stamp), + ), + "w", + ) as ow: + ow.write("train_lp,test_lp,train_perplex,test_perplex\n") + ow.write( + "{},{},{},{}\n".format(train_lp, test_lp, train_perplex, test_perplex) + ) + with open( + os.path.join( + args.out_folder, "ProfileHMM_results.input_{}.txt".format(time_stamp) + ), + "w", + ) as ow: + ow.write("[args]\n") for elem in list(args.__dict__.keys()): - ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) + ow.write("{} = {}\n".format(elem, args.__getattribute__(elem))) -if __name__ == '__main__': +if __name__ == "__main__": # Parse command line arguments. 
parser = argparse.ArgumentParser(description="Profile HMM model.") - parser.add_argument("--test", action='store_true', default=False, - help='Run with generated example dataset.') - parser.add_argument("--small", action='store_true', default=False, - help='Run with small example dataset.') + parser.add_argument( + "--test", + action="store_true", + default=False, + help="Run with generated example dataset.", + ) + parser.add_argument( + "--small", + action="store_true", + default=False, + help="Run with small example dataset.", + ) parser.add_argument("-r", "--rng-seed", default=0, type=int) parser.add_argument("--rng-data-seed", default=0, type=int) - parser.add_argument("-f", "--file", default=None, type=str, - help='Input file (fasta format).') - parser.add_argument("-a", "--alphabet", default='amino-acid', - help='Alphabet (amino-acid OR dna OR ATGC ...).') - parser.add_argument("-b", "--batch-size", default=10, type=int, - help='Batch size.') - parser.add_argument("-M", "--latent-seq-length", default=None, type=int, - help='Latent sequence length.') - parser.add_argument("--include-stop", default=False, action='store_true', - help='Include stop symbol at the end of each sequence.') - parser.add_argument("--prior-scale", default=1., type=float, - help='Prior scale parameter (all parameters).') - parser.add_argument("--indel-prior-bias", default=10., type=float, - help='Indel prior bias parameter.') - parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, - help='Learning rate for Adam optimizer.') - parser.add_argument("--milestones", default='[]', type=str, - help='Milestones for multistage learning rate.') - parser.add_argument("--learning-gamma", default=0.5, type=float, - help='Gamma parameter for multistage learning rate.') - parser.add_argument("-e", "--n-epochs", default=10, type=int, - help='Number of epochs of training.') - parser.add_argument("--no-plots", default=False, action='store_true', - help='Make plots.') - parser.add_argument("--no-save", default=False, action='store_true', - help='Do not save plots and results.') - parser.add_argument("-outf", "--out-folder", default='.', - help='Folder to save plots.') - parser.add_argument("--split", default=0.2, type=float, - help=('Fraction of dataset to holdout for testing')) - parser.add_argument("--jit", default=False, action='store_true', - help='JIT compile the ELBO.') - parser.add_argument("--cuda", default=False, action='store_true', - help='Use GPU.') - parser.add_argument("--cpu-data", default=False, action='store_true', - help='Keep data on CPU (for large datasets).') - parser.add_argument("--pin-mem", default=False, action='store_true', - help='Use pin_memory for faster GPU transfer.') + parser.add_argument( + "-f", "--file", default=None, type=str, help="Input file (fasta format)." 
+ ) + parser.add_argument( + "-a", + "--alphabet", + default="amino-acid", + help="Alphabet (amino-acid OR dna OR ATGC ...).", + ) + parser.add_argument("-b", "--batch-size", default=10, type=int, help="Batch size.") + parser.add_argument( + "-M", + "--latent-seq-length", + default=None, + type=int, + help="Latent sequence length.", + ) + parser.add_argument( + "--include-stop", + default=False, + action="store_true", + help="Include stop symbol at the end of each sequence.", + ) + parser.add_argument( + "--prior-scale", + default=1.0, + type=float, + help="Prior scale parameter (all parameters).", + ) + parser.add_argument( + "--indel-prior-bias", + default=10.0, + type=float, + help="Indel prior bias parameter.", + ) + parser.add_argument( + "-lr", + "--learning-rate", + default=0.001, + type=float, + help="Learning rate for Adam optimizer.", + ) + parser.add_argument( + "--milestones", + default="[]", + type=str, + help="Milestones for multistage learning rate.", + ) + parser.add_argument( + "--learning-gamma", + default=0.5, + type=float, + help="Gamma parameter for multistage learning rate.", + ) + parser.add_argument( + "-e", "--n-epochs", default=10, type=int, help="Number of epochs of training." + ) + parser.add_argument( + "--no-plots", default=False, action="store_true", help="Make plots." + ) + parser.add_argument( + "--no-save", + default=False, + action="store_true", + help="Do not save plots and results.", + ) + parser.add_argument( + "-outf", "--out-folder", default=".", help="Folder to save plots." + ) + parser.add_argument( + "--split", + default=0.2, + type=float, + help=("Fraction of dataset to holdout for testing"), + ) + parser.add_argument( + "--jit", default=False, action="store_true", help="JIT compile the ELBO." + ) + parser.add_argument("--cuda", default=False, action="store_true", help="Use GPU.") + parser.add_argument( + "--cpu-data", + default=False, + action="store_true", + help="Keep data on CPU (for large datasets).", + ) + parser.add_argument( + "--pin-mem", + default=False, + action="store_true", + help="Use pin_memory for faster GPU transfer.", + ) args = parser.parse_args() if args.cuda: diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py index 0e0ccbf561..835ac49012 100644 --- a/examples/contrib/oed/ab_test.py +++ b/examples/contrib/oed/ab_test.py @@ -50,15 +50,15 @@ # Set up regression model dimensions N = 100 # number of participants -p = 2 # number of features -prior_sds = torch.tensor([10., 2.5]) +p = 2 # number of features +prior_sds = torch.tensor([10.0, 2.5]) # Model and guide using known obs_sd model, guide = zero_mean_unit_obs_sd_lm(prior_sds) def estimated_ape(ns, num_vi_steps): - designs = [group_assignment_matrix(torch.tensor([n1, N-n1])) for n1 in ns] + designs = [group_assignment_matrix(torch.tensor([n1, N - n1])) for n1 in ns] X = torch.stack(designs) est_ape = vi_eig( model, @@ -68,10 +68,13 @@ def estimated_ape(ns, num_vi_steps): vi_parameters={ "guide": guide, "optim": optim.Adam({"lr": 0.05}), - "loss": TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss, - "num_steps": num_vi_steps}, + "loss": TraceEnum_ELBO( + strict_enumeration_warning=False + ).differentiable_loss, + "num_steps": num_vi_steps, + }, is_parameters={"num_samples": 1}, - eig=False + eig=False, ) return est_ape @@ -79,12 +82,12 @@ def estimated_ape(ns, num_vi_steps): def true_ape(ns): """Analytic APE""" true_ape = [] - prior_cov = torch.diag(prior_sds**2) - designs = [group_assignment_matrix(torch.tensor([n1, N-n1])) for n1 
in ns] + prior_cov = torch.diag(prior_sds ** 2) + designs = [group_assignment_matrix(torch.tensor([n1, N - n1])) for n1 in ns] for i in range(len(ns)): x = designs[i] - posterior_cov = analytic_posterior_cov(prior_cov, x, torch.tensor(1.)) - true_ape.append(0.5*torch.logdet(2*np.pi*np.e*posterior_cov)) + posterior_cov = analytic_posterior_cov(prior_cov, x, torch.tensor(1.0)) + true_ape.append(0.5 * torch.logdet(2 * np.pi * np.e * posterior_cov)) return torch.tensor(true_ape) @@ -101,13 +104,18 @@ def main(num_vi_steps, num_bo_steps, seed): num_acqs = [2, 10] for f, noise, num_acquisitions in zip(estimators, noises, num_acqs): - X = torch.tensor([25., 75.]) + X = torch.tensor([25.0, 75.0]) y = f(X) gpmodel = gp.models.GPRegression( - X, y, gp.kernels.Matern52(input_dim=1, lengthscale=torch.tensor(10.)), - noise=torch.tensor(noise), jitter=1e-6) - gpbo = GPBayesOptimizer(constraints.interval(0, 100), gpmodel, - num_acquisitions=num_acquisitions) + X, + y, + gp.kernels.Matern52(input_dim=1, lengthscale=torch.tensor(10.0)), + noise=torch.tensor(noise), + jitter=1e-6, + ) + gpbo = GPBayesOptimizer( + constraints.interval(0, 100), gpmodel, num_acquisitions=num_acquisitions + ) pyro.clear_param_store() for i in range(num_bo_steps): result = gpbo.get_step(f, None, verbose=True) @@ -117,11 +125,12 @@ def main(num_vi_steps, num_bo_steps, seed): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="A/B test experiment design using VI") parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int) - parser.add_argument('--num-bo-steps', nargs="?", default=5, type=int) - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') + parser.add_argument("--num-bo-steps", nargs="?", default=5, type=int) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) args = parser.parse_args() main(args.num_vi_steps, args.num_bo_steps, args.seed) diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index 6132dee48a..6d797d3ecf 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -38,9 +38,14 @@ def update_posterior(self, X, y): y = torch.cat([self.gpmodel.y, y]) self.gpmodel.set_data(X, y) optimizer = torch.optim.Adam(self.gpmodel.parameters(), lr=0.001) - gp.util.train(self.gpmodel, optimizer, - loss_fn=TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss, - retain_graph=True) + gp.util.train( + self.gpmodel, + optimizer, + loss_fn=TraceEnum_ELBO( + strict_enumeration_warning=False + ).differentiable_loss, + retain_graph=True, + ) def find_a_candidate(self, differentiable, x_init): """Given a starting point, `x_init`, takes one LBFGS step @@ -59,12 +64,13 @@ def find_a_candidate(self, differentiable, x_init): def closure(): minimizer.zero_grad() - if (torch.log(torch.abs(unconstrained_x)) > 25.).any(): - return torch.tensor(float('inf')) + if (torch.log(torch.abs(unconstrained_x)) > 25.0).any(): + return torch.tensor(float("inf")) x = transform_to(self.constraints)(unconstrained_x) y = differentiable(x) - autograd.backward(unconstrained_x, - autograd.grad(y, unconstrained_x, retain_graph=True)) + autograd.backward( + unconstrained_x, autograd.grad(y, unconstrained_x, retain_graph=True) + ) return y minimizer.step(closure) @@ -89,8 +95,9 @@ def opt_differentiable(self, differentiable, num_candidates=5): 
candidates = [] values = [] for j in range(num_candidates): - x_init = (torch.empty(1, dtype=self.gpmodel.X.dtype, device=self.gpmodel.X.device) - .uniform_(self.constraints.lower_bound, self.constraints.upper_bound)) + x_init = torch.empty( + 1, dtype=self.gpmodel.X.dtype, device=self.gpmodel.X.device + ).uniform_(self.constraints.lower_bound, self.constraints.upper_bound) x, y = self.find_a_candidate(differentiable, x_init) if torch.isnan(y): continue diff --git a/examples/contrib/timeseries/gp_models.py b/examples/contrib/timeseries/gp_models.py index 5dd90cd6a1..0d7e3236f4 100644 --- a/examples/contrib/timeseries/gp_models.py +++ b/examples/contrib/timeseries/gp_models.py @@ -25,7 +25,7 @@ def main(args): if not args.test: download_data() T_forecast = 349 - data = np.loadtxt('eeg.dat', delimiter=',', skiprows=19) + data = np.loadtxt("eeg.dat", delimiter=",", skiprows=19) print("[raw data shape] {}".format(data.shape)) data = torch.tensor(data[::20, :-1]).double() print("[data shape after thinning] {}".format(data.shape)) @@ -47,18 +47,29 @@ def main(args): # set up model if args.model == "imgp": - gp = IndependentMaternGP(nu=1.5, obs_dim=obs_dim, - length_scale_init=1.5 * torch.ones(obs_dim)).double() + gp = IndependentMaternGP( + nu=1.5, obs_dim=obs_dim, length_scale_init=1.5 * torch.ones(obs_dim) + ).double() elif args.model == "lcmgp": num_gps = 9 - gp = LinearlyCoupledMaternGP(nu=1.5, obs_dim=obs_dim, num_gps=num_gps, - length_scale_init=1.5 * torch.ones(num_gps)).double() + gp = LinearlyCoupledMaternGP( + nu=1.5, + obs_dim=obs_dim, + num_gps=num_gps, + length_scale_init=1.5 * torch.ones(num_gps), + ).double() # set up optimizer - adam = torch.optim.Adam(gp.parameters(), lr=args.init_learning_rate, - betas=(args.beta1, 0.999), amsgrad=True) + adam = torch.optim.Adam( + gp.parameters(), + lr=args.init_learning_rate, + betas=(args.beta1, 0.999), + amsgrad=True, + ) # we decay the learning rate over the course of training - gamma = (args.final_learning_rate / args.init_learning_rate) ** (1.0 / args.num_steps) + gamma = (args.final_learning_rate / args.init_learning_rate) ** ( + 1.0 / args.num_steps + ) scheduler = torch.optim.lr_scheduler.ExponentialLR(adam, gamma=gamma) report_frequency = 10 @@ -82,30 +93,37 @@ def main(args): # do rolling prediction print("doing one-step-ahead forecasting...") - onestep_means, onestep_stds = np.zeros((T_onestep, obs_dim)), np.zeros((T_onestep, obs_dim)) + onestep_means, onestep_stds = np.zeros((T_onestep, obs_dim)), np.zeros( + (T_onestep, obs_dim) + ) for t in range(T_onestep): # predict one step into the future, conditioning on all previous data. 
# note that each call to forecast() conditions on more data than the previous call dts = torch.tensor([1.0]).double() - pred_dist = gp.forecast(data[0:T_train + t, :], dts) + pred_dist = gp.forecast(data[0 : T_train + t, :], dts) onestep_means[t, :] = pred_dist.loc.data.numpy() if args.model == "imgp": onestep_stds[t, :] = pred_dist.scale.data.numpy() elif args.model == "lcmgp": - onestep_stds[t, :] = pred_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2).data.numpy() + onestep_stds[t, :] = pred_dist.covariance_matrix.diagonal( + dim1=-1, dim2=-2 + ).data.numpy() # do (non-rolling) multi-step forecasting print("doing multi-step forecasting...") dts = (1 + torch.arange(T_multistep)).double() - pred_dist = gp.forecast(data[0:T_train + T_onestep, :], dts) + pred_dist = gp.forecast(data[0 : T_train + T_onestep, :], dts) multistep_means = pred_dist.loc.data.numpy() if args.model == "imgp": multistep_stds = pred_dist.scale.data.numpy() elif args.model == "lcmgp": - multistep_stds = pred_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2).data.numpy() + multistep_stds = pred_dist.covariance_matrix.diagonal( + dim1=-1, dim2=-2 + ).data.numpy() import matplotlib - matplotlib.use('Agg') # noqa: E402 + + matplotlib.use("Agg") # noqa: E402 import matplotlib.pyplot as plt f, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) @@ -116,46 +134,70 @@ def main(args): which = [0, 4, 10][k] # plot raw data - ax.plot(to_seconds * np.arange(T), data[:, which], 'ko', markersize=2, label='Data') + ax.plot( + to_seconds * np.arange(T), + data[:, which], + "ko", + markersize=2, + label="Data", + ) # plot mean predictions for one-step-ahead forecasts - ax.plot(to_seconds * (T_train + np.arange(T_onestep)), - onestep_means[:, which], ls='solid', color='b', label='One-step') + ax.plot( + to_seconds * (T_train + np.arange(T_onestep)), + onestep_means[:, which], + ls="solid", + color="b", + label="One-step", + ) # plot 90% confidence intervals for one-step-ahead forecasts - ax.fill_between(to_seconds * (T_train + np.arange(T_onestep)), - onestep_means[:, which] - 1.645 * onestep_stds[:, which], - onestep_means[:, which] + 1.645 * onestep_stds[:, which], - color='b', alpha=0.20) + ax.fill_between( + to_seconds * (T_train + np.arange(T_onestep)), + onestep_means[:, which] - 1.645 * onestep_stds[:, which], + onestep_means[:, which] + 1.645 * onestep_stds[:, which], + color="b", + alpha=0.20, + ) # plot mean predictions for multi-step-ahead forecasts - ax.plot(to_seconds * (T_train + T_onestep + np.arange(T_multistep)), - multistep_means[:, which], ls='solid', color='r', label='Multi-step') + ax.plot( + to_seconds * (T_train + T_onestep + np.arange(T_multistep)), + multistep_means[:, which], + ls="solid", + color="r", + label="Multi-step", + ) # plot 90% confidence intervals for multi-step-ahead forecasts - ax.fill_between(to_seconds * (T_train + T_onestep + np.arange(T_multistep)), - multistep_means[:, which] - 1.645 * multistep_stds[:, which], - multistep_means[:, which] + 1.645 * multistep_stds[:, which], - color='r', alpha=0.20) + ax.fill_between( + to_seconds * (T_train + T_onestep + np.arange(T_multistep)), + multistep_means[:, which] - 1.645 * multistep_stds[:, which], + multistep_means[:, which] + 1.645 * multistep_stds[:, which], + color="r", + alpha=0.20, + ) ax.set_ylabel("$y_{%d}$" % (which + 1), fontsize=20) - ax.tick_params(axis='both', which='major', labelsize=14) + ax.tick_params(axis="both", which="major", labelsize=14) if k == 1: - ax.legend(loc='upper left', fontsize=16) + ax.legend(loc="upper left", 
fontsize=16) plt.tight_layout(pad=0.7) - plt.savefig('eeg.{}.pdf'.format(args.model)) + plt.savefig("eeg.{}.pdf".format(args.model)) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="contrib.timeseries example usage") parser.add_argument("-n", "--num-steps", default=300, type=int) parser.add_argument("-s", "--seed", default=0, type=int) - parser.add_argument("-m", "--model", default="imgp", type=str, choices=["imgp", "lcmgp"]) + parser.add_argument( + "-m", "--model", default="imgp", type=str, choices=["imgp", "lcmgp"] + ) parser.add_argument("-ilr", "--init-learning-rate", default=0.01, type=float) parser.add_argument("-flr", "--final-learning-rate", default=0.0003, type=float) parser.add_argument("-b1", "--beta1", default=0.50, type=float) - parser.add_argument("--test", action='store_true') - parser.add_argument("--plot", action='store_true') + parser.add_argument("--test", action="store_true") + parser.add_argument("--plot", action="store_true") args = parser.parse_args() main(args) diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index 23e1591016..6c4a0a18e0 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -34,13 +34,20 @@ def __init__(self, masked_with=-1): def forward(self, input, target): target = target.view(input.shape) - loss = F.binary_cross_entropy(input, target, reduction='none') + loss = F.binary_cross_entropy(input, target, reduction="none") loss[target == self.masked_with] = 0 return loss.sum() -def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, - early_stop_patience, model_path): +def train( + device, + dataloaders, + dataset_sizes, + learning_rate, + num_epochs, + early_stop_patience, + model_path, +): # Train baseline baseline_net = BaselineNet(500, 500) @@ -51,8 +58,8 @@ def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, early_stop_count = 0 for epoch in range(num_epochs): - for phase in ['train', 'val']: - if phase == 'train': + for phase in ["train", "val"]: + if phase == "train": baseline_net.train() else: baseline_net.eval() @@ -60,30 +67,33 @@ def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, running_loss = 0.0 num_preds = 0 - bar = tqdm(dataloaders[phase], - desc='NN Epoch {} {}'.format(epoch, phase).ljust(20)) + bar = tqdm( + dataloaders[phase], desc="NN Epoch {} {}".format(epoch, phase).ljust(20) + ) for i, batch in enumerate(bar): - inputs = batch['input'].to(device) - outputs = batch['output'].to(device) + inputs = batch["input"].to(device) + outputs = batch["output"].to(device) optimizer.zero_grad() - with torch.set_grad_enabled(phase == 'train'): + with torch.set_grad_enabled(phase == "train"): preds = baseline_net(inputs) loss = criterion(preds, outputs) / inputs.size(0) - if phase == 'train': + if phase == "train": loss.backward() optimizer.step() running_loss += loss.item() num_preds += 1 if i % 10 == 0: - bar.set_postfix(loss='{:.2f}'.format(running_loss / num_preds), - early_stop_count=early_stop_count) + bar.set_postfix( + loss="{:.2f}".format(running_loss / num_preds), + early_stop_count=early_stop_count, + ) epoch_loss = running_loss / dataset_sizes[phase] # deep copy the model - if phase == 'val': + if phase == "val": if epoch_loss < best_loss: best_loss = epoch_loss best_model_wts = copy.deepcopy(baseline_net.state_dict()) diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py index 
f499aaf452..e6a6fd8eda 100644 --- a/examples/cvae/cvae.py +++ b/examples/cvae/cvae.py @@ -81,7 +81,7 @@ def model(self, xs, ys=None): # sample the handwriting style from the prior distribution, which is # modulated by the input xs. prior_loc, prior_scale = self.prior_net(xs, y_hat) - zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1)) + zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1)) # the output y is generated from the distribution pθ(y|x, z) loc = self.generation_net(zs) @@ -90,16 +90,17 @@ def model(self, xs, ys=None): # In training, we will only sample in the masked image mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1) mask_ys = ys[xs == -1].view(batch_size, -1) - pyro.sample('y', - dist.Bernoulli(mask_loc, validate_args=False) - .to_event(1), - obs=mask_ys) + pyro.sample( + "y", + dist.Bernoulli(mask_loc, validate_args=False).to_event(1), + obs=mask_ys, + ) else: # In testing, no need to sample: the output is already a # probability in [0, 1] range, which better represent pixel # values considering grayscale. If we sample, we will force # each pixel to be either 0 or 1, killing the grayscale - pyro.deterministic('y', loc.detach()) + pyro.deterministic("y", loc.detach()) # return the loc so we can visualize it later return loc @@ -119,8 +120,16 @@ def guide(self, xs, ys=None): pyro.sample("z", dist.Normal(loc, scale).to_event(1)) -def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, - early_stop_patience, model_path, pre_trained_baseline_net): +def train( + device, + dataloaders, + dataset_sizes, + learning_rate, + num_epochs, + early_stop_patience, + model_path, + pre_trained_baseline_net, +): # clear param store pyro.clear_param_store() @@ -136,18 +145,20 @@ def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, for epoch in range(num_epochs): # Each epoch has a training and validation phase - for phase in ['train', 'val']: + for phase in ["train", "val"]: running_loss = 0.0 num_preds = 0 # Iterate over data. 
- bar = tqdm(dataloaders[phase], - desc='CVAE Epoch {} {}'.format(epoch, phase).ljust(20)) + bar = tqdm( + dataloaders[phase], + desc="CVAE Epoch {} {}".format(epoch, phase).ljust(20), + ) for i, batch in enumerate(bar): - inputs = batch['input'].to(device) - outputs = batch['output'].to(device) + inputs = batch["input"].to(device) + outputs = batch["output"].to(device) - if phase == 'train': + if phase == "train": loss = svi.step(inputs, outputs) else: loss = svi.evaluate_loss(inputs, outputs) @@ -156,12 +167,14 @@ def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, running_loss += loss / inputs.size(0) num_preds += 1 if i % 10 == 0: - bar.set_postfix(loss='{:.2f}'.format(running_loss / num_preds), - early_stop_count=early_stop_count) + bar.set_postfix( + loss="{:.2f}".format(running_loss / num_preds), + early_stop_count=early_stop_count, + ) epoch_loss = running_loss / dataset_sizes[phase] # deep copy the model - if phase == 'val': + if phase == "val": if epoch_loss < best_loss: best_loss = epoch_loss torch.save(cvae_net.state_dict(), model_path) diff --git a/examples/cvae/main.py b/examples/cvae/main.py index 0d1fd39596..0c65ab95ef 100644 --- a/examples/cvae/main.py +++ b/examples/cvae/main.py @@ -13,22 +13,25 @@ def main(args): - device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda - else "cpu") + device = torch.device( + "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu" + ) results = [] columns = [] for num_quadrant_inputs in args.num_quadrant_inputs: # adds an s in case of plural quadrants - maybes = 's' if num_quadrant_inputs > 1 else '' + maybes = "s" if num_quadrant_inputs > 1 else "" - print('Training with {} quadrant{} as input...' - .format(num_quadrant_inputs, maybes)) + print( + "Training with {} quadrant{} as input...".format( + num_quadrant_inputs, maybes + ) + ) # Dataset datasets, dataloaders, dataset_sizes = get_data( - num_quadrant_inputs=num_quadrant_inputs, - batch_size=128 + num_quadrant_inputs=num_quadrant_inputs, batch_size=128 ) # Train baseline @@ -39,7 +42,7 @@ def main(args): learning_rate=args.learning_rate, num_epochs=args.num_epochs, early_stop_patience=args.early_stop_patience, - model_path='baseline_net_q{}.pth'.format(num_quadrant_inputs) + model_path="baseline_net_q{}.pth".format(num_quadrant_inputs), ) # Train CVAE @@ -50,8 +53,8 @@ def main(args): learning_rate=args.learning_rate, num_epochs=args.num_epochs, early_stop_patience=args.early_stop_patience, - model_path='cvae_net_q{}.pth'.format(num_quadrant_inputs), - pre_trained_baseline_net=baseline_net + model_path="cvae_net_q{}.pth".format(num_quadrant_inputs), + pre_trained_baseline_net=baseline_net, ) # Visualize conditional predictions @@ -62,7 +65,7 @@ def main(args): pre_trained_cvae=cvae_net, num_images=args.num_images, num_samples=args.num_samples, - image_path='cvae_plot_q{}.png'.format(num_quadrant_inputs) + image_path="cvae_plot_q{}.png".format(num_quadrant_inputs), ) # Retrieve conditional log likelihood @@ -72,38 +75,63 @@ def main(args): pre_trained_baseline=baseline_net, pre_trained_cvae=cvae_net, num_particles=args.num_particles, - col_name='{} quadrant{}'.format(num_quadrant_inputs, maybes) + col_name="{} quadrant{}".format(num_quadrant_inputs, maybes), ) results.append(df) - columns.append('{} quadrant{}'.format(num_quadrant_inputs, maybes)) + columns.append("{} quadrant{}".format(num_quadrant_inputs, maybes)) results = pd.concat(results, axis=1, ignore_index=True) results.columns = columns - results.loc['Performance gap', :] = 
results.iloc[0, :] - results.iloc[1, :] - results.to_csv('results.csv') + results.loc["Performance gap", :] = results.iloc[0, :] - results.iloc[1, :] + results.to_csv("results.csv") -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-nq', '--num-quadrant-inputs', metavar='N', type=int, - nargs='+', default=[1, 2, 3], - help='num of quadrants to use as inputs') - parser.add_argument('-n', '--num-epochs', default=101, type=int, - help='number of training epochs') - parser.add_argument('-esp', '--early-stop-patience', default=3, type=int, - help='early stop patience') - parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, - help='learning rate') - parser.add_argument('--cuda', action='store_true', default=False, - help='whether to use cuda') - parser.add_argument('-vi', '--num-images', default=10, type=int, - help='number of images to visualize') - parser.add_argument('-vs', '--num-samples', default=10, type=int, - help='number of samples to visualize per image') - parser.add_argument('-p', '--num-particles', default=10, type=int, - help='n of particles to estimate logpθ(y|x,z) in ELBO') + parser.add_argument( + "-nq", + "--num-quadrant-inputs", + metavar="N", + type=int, + nargs="+", + default=[1, 2, 3], + help="num of quadrants to use as inputs", + ) + parser.add_argument( + "-n", "--num-epochs", default=101, type=int, help="number of training epochs" + ) + parser.add_argument( + "-esp", "--early-stop-patience", default=3, type=int, help="early stop patience" + ) + parser.add_argument( + "-lr", "--learning-rate", default=1.0e-3, type=float, help="learning rate" + ) + parser.add_argument( + "--cuda", action="store_true", default=False, help="whether to use cuda" + ) + parser.add_argument( + "-vi", + "--num-images", + default=10, + type=int, + help="number of images to visualize", + ) + parser.add_argument( + "-vs", + "--num-samples", + default=10, + type=int, + help="number of samples to visualize per image", + ) + parser.add_argument( + "-p", + "--num-particles", + default=10, + type=int, + help="n of particles to estimate logpθ(y|x,z) in ELBO", + ) args = parser.parse_args() main(args) diff --git a/examples/cvae/mnist.py b/examples/cvae/mnist.py index e9686b64f4..a645292788 100644 --- a/examples/cvae/mnist.py +++ b/examples/cvae/mnist.py @@ -19,7 +19,7 @@ def __len__(self): def __getitem__(self, item): image, digit = self.original[item] - sample = {'original': image, 'digit': digit} + sample = {"original": image, "digit": digit} if self.transform: sample = self.transform(sample) @@ -28,9 +28,10 @@ def __getitem__(self, item): class ToTensor: def __call__(self, sample): - sample['original'] = functional.to_tensor(sample['original']) - sample['digit'] = torch.as_tensor(np.asarray(sample['digit']), - dtype=torch.int64) + sample["original"] = functional.to_tensor(sample["original"]) + sample["digit"] = torch.as_tensor( + np.asarray(sample["digit"]), dtype=torch.int64 + ) return sample @@ -41,55 +42,52 @@ class MaskImages: quadrant(s) setting their pixels with -1. 
Additionally, the transformation adds the target output in the sample dict as the complementary of the input """ + def __init__(self, num_quadrant_inputs, mask_with=-1): if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4: - raise ValueError('Number of quadrants as inputs must be 1, 2 or 3') + raise ValueError("Number of quadrants as inputs must be 1, 2 or 3") self.num = num_quadrant_inputs self.mask_with = mask_with def __call__(self, sample): - tensor = sample['original'].squeeze() + tensor = sample["original"].squeeze() out = tensor.detach().clone() h, w = tensor.shape # removes the bottom left quadrant from the target output - out[h // 2:, :w // 2] = self.mask_with + out[h // 2 :, : w // 2] = self.mask_with # if num of quadrants to be used as input is 2, # also removes the top left quadrant from the target output if self.num == 2: - out[:, :w // 2] = self.mask_with + out[:, : w // 2] = self.mask_with # if num of quadrants to be used as input is 3, # also removes the top right quadrant from the target output if self.num == 3: - out[:h // 2, :] = self.mask_with + out[: h // 2, :] = self.mask_with # now, sets the input as complementary inp = tensor.clone() inp[out != -1] = self.mask_with - sample['input'] = inp - sample['output'] = out + sample["input"] = inp + sample["output"] = out return sample def get_data(num_quadrant_inputs, batch_size): - transforms = Compose([ - ToTensor(), - MaskImages(num_quadrant_inputs=num_quadrant_inputs) - ]) + transforms = Compose( + [ToTensor(), MaskImages(num_quadrant_inputs=num_quadrant_inputs)] + ) datasets, dataloaders, dataset_sizes = {}, {}, {} - for mode in ['train', 'val']: + for mode in ["train", "val"]: datasets[mode] = CVAEMNIST( - '../data', - download=True, - transform=transforms, - train=mode == 'train' + "../data", download=True, transform=transforms, train=mode == "train" ) dataloaders[mode] = DataLoader( datasets[mode], batch_size=batch_size, - shuffle=mode == 'train', - num_workers=0 + shuffle=mode == "train", + num_workers=0, ) dataset_sizes[mode] = len(datasets[mode]) diff --git a/examples/cvae/util.py b/examples/cvae/util.py index 87650298ef..57caa55e1d 100644 --- a/examples/cvae/util.py +++ b/examples/cvae/util.py @@ -23,44 +23,50 @@ def imshow(inp, image_path=None): inp = np.concatenate([space, inp], axis=1) ax = plt.axes(frameon=False, xticks=[], yticks=[]) - ax.text(0, 23, 'Inputs:') - ax.text(0, 23 + 28 + 3, 'Truth:') - ax.text(0, 23 + (28 + 3) * 2, 'NN:') - ax.text(0, 23 + (28 + 3) * 3, 'CVAE:') + ax.text(0, 23, "Inputs:") + ax.text(0, 23 + 28 + 3, "Truth:") + ax.text(0, 23 + (28 + 3) * 2, "NN:") + ax.text(0, 23 + (28 + 3) * 3, "CVAE:") ax.imshow(inp) if image_path is not None: Path(image_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(image_path, bbox_inches='tight', pad_inches=0.1) + plt.savefig(image_path, bbox_inches="tight", pad_inches=0.1) else: plt.show() plt.clf() -def visualize(device, num_quadrant_inputs, pre_trained_baseline, - pre_trained_cvae, num_images, num_samples, image_path=None): +def visualize( + device, + num_quadrant_inputs, + pre_trained_baseline, + pre_trained_cvae, + num_images, + num_samples, + image_path=None, +): # Load sample random data datasets, _, dataset_sizes = get_data( - num_quadrant_inputs=num_quadrant_inputs, - batch_size=num_images + num_quadrant_inputs=num_quadrant_inputs, batch_size=num_images ) - dataloader = DataLoader(datasets['val'], batch_size=num_images, shuffle=True) + dataloader = DataLoader(datasets["val"], batch_size=num_images, shuffle=True) batch = 
next(iter(dataloader)) - inputs = batch['input'].to(device) - outputs = batch['output'].to(device) - originals = batch['original'].to(device) + inputs = batch["input"].to(device) + outputs = batch["output"].to(device) + originals = batch["original"].to(device) # Make predictions with torch.no_grad(): baseline_preds = pre_trained_baseline(inputs).view(outputs.shape) - predictive = Predictive(pre_trained_cvae.model, - guide=pre_trained_cvae.guide, - num_samples=num_samples) - cvae_preds = predictive(inputs)['y'].view(num_samples, num_images, 28, 28) + predictive = Predictive( + pre_trained_cvae.model, guide=pre_trained_cvae.guide, num_samples=num_samples + ) + cvae_preds = predictive(inputs)["y"].view(num_samples, num_images, 28, 28) # Predictions are only made in the pixels not masked. This completes # the input quadrant with the prediction for the missing quadrants, for @@ -91,20 +97,34 @@ def visualize(device, num_quadrant_inputs, pre_trained_baseline, cvae_tensor[:, (i + 1) * 28, :] = 0.3 # concatenate all tensors - grid_tensor = torch.cat([inputs_tensor, separator_tensor, originals_tensor, - separator_tensor, baseline_tensor, - separator_tensor, cvae_tensor], dim=1) + grid_tensor = torch.cat( + [ + inputs_tensor, + separator_tensor, + originals_tensor, + separator_tensor, + baseline_tensor, + separator_tensor, + cvae_tensor, + ], + dim=1, + ) # plot tensors imshow(grid_tensor, image_path=image_path) -def generate_table(device, num_quadrant_inputs, pre_trained_baseline, - pre_trained_cvae, num_particles, col_name): +def generate_table( + device, + num_quadrant_inputs, + pre_trained_baseline, + pre_trained_cvae, + num_particles, + col_name, +): # Load sample random data datasets, dataloaders, dataset_sizes = get_data( - num_quadrant_inputs=num_quadrant_inputs, - batch_size=32 + num_quadrant_inputs=num_quadrant_inputs, batch_size=32 ) # Load sample data @@ -115,14 +135,13 @@ def generate_table(device, num_quadrant_inputs, pre_trained_baseline, cvae_mc_cll = 0.0 num_preds = 0 - df = pd.DataFrame(index=['NN (baseline)', 'CVAE (Monte Carlo)'], - columns=[col_name]) + df = pd.DataFrame(index=["NN (baseline)", "CVAE (Monte Carlo)"], columns=[col_name]) # Iterate over data. 
- bar = tqdm(dataloaders['val'], desc='Generating predictions'.ljust(20)) + bar = tqdm(dataloaders["val"], desc="Generating predictions".ljust(20)) for batch in bar: - inputs = batch['input'].to(device) - outputs = batch['output'].to(device) + inputs = batch["input"].to(device) + outputs = batch["output"].to(device) num_preds += 1 # Compute negative log likelihood for the baseline NN @@ -131,9 +150,9 @@ def generate_table(device, num_quadrant_inputs, pre_trained_baseline, baseline_cll += criterion(preds, outputs).item() / inputs.size(0) # Compute the negative conditional log likelihood for the CVAE - cvae_mc_cll += loss_fn(pre_trained_cvae.model, - pre_trained_cvae.guide, - inputs, outputs).detach().item() / inputs.size(0) + cvae_mc_cll += loss_fn( + pre_trained_cvae.model, pre_trained_cvae.guide, inputs, outputs + ).detach().item() / inputs.size(0) df.iloc[0, 0] = baseline_cll / num_preds df.iloc[1, 0] = cvae_mc_cll / num_preds diff --git a/examples/dmm.py b/examples/dmm.py index 5bff3315e8..c58cbe63b0 100644 --- a/examples/dmm.py +++ b/examples/dmm.py @@ -150,22 +150,40 @@ class DMM(nn.Module): variational distribution (the guide) for the Deep Markov Model """ - def __init__(self, input_dim=88, z_dim=100, emission_dim=100, - transition_dim=200, rnn_dim=600, num_layers=1, rnn_dropout_rate=0.0, - num_iafs=0, iaf_dim=50, use_cuda=False): + def __init__( + self, + input_dim=88, + z_dim=100, + emission_dim=100, + transition_dim=200, + rnn_dim=600, + num_layers=1, + rnn_dropout_rate=0.0, + num_iafs=0, + iaf_dim=50, + use_cuda=False, + ): super().__init__() # instantiate PyTorch modules used in the model and guide below self.emitter = Emitter(input_dim, z_dim, emission_dim) self.trans = GatedTransition(z_dim, transition_dim) self.combiner = Combiner(z_dim, rnn_dim) # dropout just takes effect on inner layers of rnn - rnn_dropout_rate = 0. if num_layers == 1 else rnn_dropout_rate - self.rnn = nn.RNN(input_size=input_dim, hidden_size=rnn_dim, nonlinearity='relu', - batch_first=True, bidirectional=False, num_layers=num_layers, - dropout=rnn_dropout_rate) + rnn_dropout_rate = 0.0 if num_layers == 1 else rnn_dropout_rate + self.rnn = nn.RNN( + input_size=input_dim, + hidden_size=rnn_dim, + nonlinearity="relu", + batch_first=True, + bidirectional=False, + num_layers=num_layers, + dropout=rnn_dropout_rate, + ) # if we're using normalizing flows, instantiate those too - self.iafs = [affine_autoregressive(z_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs)] + self.iafs = [ + affine_autoregressive(z_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs) + ] self.iafs_modules = nn.ModuleList(self.iafs) # define a (trainable) parameters z_0 and z_q_0 that help define the probability @@ -182,8 +200,14 @@ def __init__(self, input_dim=88, z_dim=100, emission_dim=100, self.cuda() # the model p(x_{1:T} | z_{1:T}) p(z_{1:T}) - def model(self, mini_batch, mini_batch_reversed, mini_batch_mask, - mini_batch_seq_lengths, annealing_factor=1.0): + def model( + self, + mini_batch, + mini_batch_reversed, + mini_batch_mask, + mini_batch_seq_lengths, + annealing_factor=1.0, + ): # this is the number of time steps we need to process in the mini-batch T_max = mini_batch.size(1) @@ -214,27 +238,37 @@ def model(self, mini_batch, mini_batch_reversed, mini_batch_mask, # note that we use the reshape method so that the univariate Normal distribution # is treated as a multivariate Normal distribution with a diagonal covariance. 
with poutine.scale(scale=annealing_factor): - z_t = pyro.sample("z_%d" % t, - dist.Normal(z_loc, z_scale) - .mask(mini_batch_mask[:, t - 1:t]) - .to_event(1)) + z_t = pyro.sample( + "z_%d" % t, + dist.Normal(z_loc, z_scale) + .mask(mini_batch_mask[:, t - 1 : t]) + .to_event(1), + ) # compute the probabilities that parameterize the bernoulli likelihood emission_probs_t = self.emitter(z_t) # the next statement instructs pyro to observe x_t according to the # bernoulli distribution p(x_t|z_t) - pyro.sample("obs_x_%d" % t, - dist.Bernoulli(emission_probs_t) - .mask(mini_batch_mask[:, t - 1:t]) - .to_event(1), - obs=mini_batch[:, t - 1, :]) + pyro.sample( + "obs_x_%d" % t, + dist.Bernoulli(emission_probs_t) + .mask(mini_batch_mask[:, t - 1 : t]) + .to_event(1), + obs=mini_batch[:, t - 1, :], + ) # the latent sampled at this time step will be conditioned upon # in the next time step so keep track of it z_prev = z_t # the guide q(z_{1:T} | x_{1:T}) (i.e. the variational distribution) - def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask, - mini_batch_seq_lengths, annealing_factor=1.0): + def guide( + self, + mini_batch, + mini_batch_reversed, + mini_batch_mask, + mini_batch_seq_lengths, + annealing_factor=1.0, + ): # this is the number of time steps we need to process in the mini-batch T_max = mini_batch.size(1) @@ -243,7 +277,9 @@ def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask, # if on gpu we need the fully broadcast view of the rnn initial state # to be in contiguous gpu memory - h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous() + h_0_contig = self.h_0.expand( + 1, mini_batch.size(0), self.rnn.hidden_size + ).contiguous() # push the observed x's through the rnn; # rnn_output contains the hidden state at each time step rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig) @@ -265,25 +301,32 @@ def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask, # parameterized by self.iafs to the base distribution defined in the previous line # to yield a transformed distribution that we use for q(z_t|...) 
if len(self.iafs) > 0: - z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs) + z_dist = TransformedDistribution( + dist.Normal(z_loc, z_scale), self.iafs + ) assert z_dist.event_shape == (self.z_q_0.size(0),) assert z_dist.batch_shape[-1:] == (len(mini_batch),) else: z_dist = dist.Normal(z_loc, z_scale) assert z_dist.event_shape == () - assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0)) + assert z_dist.batch_shape[-2:] == ( + len(mini_batch), + self.z_q_0.size(0), + ) # sample z_t from the distribution z_dist with pyro.poutine.scale(scale=annealing_factor): if len(self.iafs) > 0: # in output of normalizing flow, all dimensions are correlated (event shape is not empty) - z_t = pyro.sample("z_%d" % t, - z_dist.mask(mini_batch_mask[:, t - 1])) + z_t = pyro.sample( + "z_%d" % t, z_dist.mask(mini_batch_mask[:, t - 1]) + ) else: # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent - z_t = pyro.sample("z_%d" % t, - z_dist.mask(mini_batch_mask[:, t - 1:t]) - .to_event(1)) + z_t = pyro.sample( + "z_%d" % t, + z_dist.mask(mini_batch_mask[:, t - 1 : t]).to_event(1), + ) # the latent sampled at this time step will be conditioned upon in the next time step # so keep track of it z_prev = z_t @@ -292,26 +335,32 @@ def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask, # setup, training, and evaluation def main(args): # setup logging - logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=args.log, filemode='w') + logging.basicConfig( + level=logging.DEBUG, format="%(message)s", filename=args.log, filemode="w" + ) console = logging.StreamHandler() console.setLevel(logging.INFO) - logging.getLogger('').addHandler(console) + logging.getLogger("").addHandler(console) logging.info(args) data = poly.load_data(poly.JSB_CHORALES) - training_seq_lengths = data['train']['sequence_lengths'] - training_data_sequences = data['train']['sequences'] - test_seq_lengths = data['test']['sequence_lengths'] - test_data_sequences = data['test']['sequences'] - val_seq_lengths = data['valid']['sequence_lengths'] - val_data_sequences = data['valid']['sequences'] + training_seq_lengths = data["train"]["sequence_lengths"] + training_data_sequences = data["train"]["sequences"] + test_seq_lengths = data["test"]["sequence_lengths"] + test_data_sequences = data["test"]["sequences"] + val_seq_lengths = data["valid"]["sequence_lengths"] + val_data_sequences = data["valid"]["sequences"] N_train_data = len(training_seq_lengths) N_train_time_slices = float(torch.sum(training_seq_lengths)) - N_mini_batches = int(N_train_data / args.mini_batch_size + - int(N_train_data % args.mini_batch_size > 0)) + N_mini_batches = int( + N_train_data / args.mini_batch_size + + int(N_train_data % args.mini_batch_size > 0) + ) - logging.info("N_train_data: %d avg. training seq. length: %.2f N_mini_batches: %d" % - (N_train_data, training_seq_lengths.float().mean(), N_mini_batches)) + logging.info( + "N_train_data: %d avg. training seq. 
length: %.2f N_mini_batches: %d" + % (N_train_data, training_seq_lengths.float().mean(), N_mini_batches) + ) # how often we do validation/test evaluation during training val_test_frequency = 50 @@ -324,26 +373,55 @@ def rep(x): rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:] repeat_dims = [1] * len(x.size()) repeat_dims[0] = n_eval_samples - return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape) + return ( + x.repeat(repeat_dims) + .reshape(n_eval_samples, -1) + .transpose(1, 0) + .reshape(rep_shape) + ) # get the validation/test data ready for the dmm: pack into sequences, etc. val_seq_lengths = rep(val_seq_lengths) test_seq_lengths = rep(test_seq_lengths) - val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch( - torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences), - val_seq_lengths, cuda=args.cuda) - test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch( - torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences), - test_seq_lengths, cuda=args.cuda) + ( + val_batch, + val_batch_reversed, + val_batch_mask, + val_seq_lengths, + ) = poly.get_mini_batch( + torch.arange(n_eval_samples * val_data_sequences.shape[0]), + rep(val_data_sequences), + val_seq_lengths, + cuda=args.cuda, + ) + ( + test_batch, + test_batch_reversed, + test_batch_mask, + test_seq_lengths, + ) = poly.get_mini_batch( + torch.arange(n_eval_samples * test_data_sequences.shape[0]), + rep(test_data_sequences), + test_seq_lengths, + cuda=args.cuda, + ) # instantiate the dmm - dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs, - iaf_dim=args.iaf_dim, use_cuda=args.cuda) + dmm = DMM( + rnn_dropout_rate=args.rnn_dropout_rate, + num_iafs=args.num_iafs, + iaf_dim=args.iaf_dim, + use_cuda=args.cuda, + ) # setup optimizer - adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2), - "clip_norm": args.clip_norm, "lrd": args.lr_decay, - "weight_decay": args.weight_decay} + adam_params = { + "lr": args.learning_rate, + "betas": (args.beta1, args.beta2), + "clip_norm": args.clip_norm, + "lrd": args.lr_decay, + "weight_decay": args.weight_decay, + } adam = ClippedAdam(adam_params) # setup inference algorithm @@ -351,13 +429,23 @@ def rep(x): if args.jit: raise NotImplementedError("no JIT support yet for TMC") tmc_loss = TraceTMC_ELBO() - dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False) + dmm_guide = config_enumerate( + dmm.guide, + default="parallel", + num_samples=args.tmc_num_samples, + expand=False, + ) svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss) elif args.tmcelbo: if args.jit: raise NotImplementedError("no JIT support yet for TMC ELBO") elbo = TraceEnum_ELBO() - dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False) + dmm_guide = config_enumerate( + dmm.guide, + default="parallel", + num_samples=args.tmc_num_samples, + expand=False, + ) svi = SVI(dmm.model, dmm_guide, adam, loss=elbo) else: elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() @@ -375,8 +463,9 @@ def save_checkpoint(): # loads the model and optimizer states from disk def load_checkpoint(): - assert exists(args.load_opt) and exists(args.load_model), \ - "--load-model and/or --load-opt misspecified" + assert exists(args.load_opt) and exists( + args.load_model + ), "--load-model and/or --load-opt misspecified" logging.info("loading 
model from %s..." % args.load_model) dmm.load_state_dict(torch.load(args.load_model)) logging.info("loading optimizer states from %s..." % args.load_opt) @@ -388,24 +477,40 @@ def process_minibatch(epoch, which_mini_batch, shuffled_indices): if args.annealing_epochs > 0 and epoch < args.annealing_epochs: # compute the KL annealing factor approriate for the current mini-batch in the current epoch min_af = args.minimum_annealing_factor - annealing_factor = min_af + (1.0 - min_af) * \ - (float(which_mini_batch + epoch * N_mini_batches + 1) / - float(args.annealing_epochs * N_mini_batches)) + annealing_factor = min_af + (1.0 - min_af) * ( + float(which_mini_batch + epoch * N_mini_batches + 1) + / float(args.annealing_epochs * N_mini_batches) + ) else: # by default the KL annealing factor is unity annealing_factor = 1.0 # compute which sequences in the training set we should grab - mini_batch_start = (which_mini_batch * args.mini_batch_size) - mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data]) + mini_batch_start = which_mini_batch * args.mini_batch_size + mini_batch_end = np.min( + [(which_mini_batch + 1) * args.mini_batch_size, N_train_data] + ) mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end] # grab a fully prepped mini-batch using the helper function in the data loader - mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \ - = poly.get_mini_batch(mini_batch_indices, training_data_sequences, - training_seq_lengths, cuda=args.cuda) + ( + mini_batch, + mini_batch_reversed, + mini_batch_mask, + mini_batch_seq_lengths, + ) = poly.get_mini_batch( + mini_batch_indices, + training_data_sequences, + training_seq_lengths, + cuda=args.cuda, + ) # do an actual gradient step - loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask, - mini_batch_seq_lengths, annealing_factor) + loss = svi.step( + mini_batch, + mini_batch_reversed, + mini_batch_mask, + mini_batch_seq_lengths, + annealing_factor, + ) # keep track of the training loss return loss @@ -415,17 +520,19 @@ def do_evaluation(): dmm.rnn.eval() # compute the validation and test loss n_samples many times - val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask, - val_seq_lengths) / float(torch.sum(val_seq_lengths)) - test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask, - test_seq_lengths) / float(torch.sum(test_seq_lengths)) + val_nll = svi.evaluate_loss( + val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths + ) / float(torch.sum(val_seq_lengths)) + test_nll = svi.evaluate_loss( + test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths + ) / float(torch.sum(test_seq_lengths)) # put the RNN back into training mode (i.e. 
turn on drop-out if applicable) dmm.rnn.train() return val_nll, test_nll # if checkpoint files provided, load model and optimizer states from disk before we start training - if args.load_opt != '' and args.load_model != '': + if args.load_opt != "" and args.load_model != "": load_checkpoint() ################# @@ -449,44 +556,48 @@ def do_evaluation(): # report training diagnostics times.append(time.time()) epoch_time = times[-1] - times[-2] - logging.info("[training epoch %04d] %.4f \t\t\t\t(dt = %.3f sec)" % - (epoch, epoch_nll / N_train_time_slices, epoch_time)) + logging.info( + "[training epoch %04d] %.4f \t\t\t\t(dt = %.3f sec)" + % (epoch, epoch_nll / N_train_time_slices, epoch_time) + ) # do evaluation on test and validation data and report results if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0: val_nll, test_nll = do_evaluation() - logging.info("[val/test epoch %04d] %.4f %.4f" % (epoch, val_nll, test_nll)) + logging.info( + "[val/test epoch %04d] %.4f %.4f" % (epoch, val_nll, test_nll) + ) # parse command-line arguments and execute the main method -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', type=int, default=5000) - parser.add_argument('-lr', '--learning-rate', type=float, default=0.0003) - parser.add_argument('-b1', '--beta1', type=float, default=0.96) - parser.add_argument('-b2', '--beta2', type=float, default=0.999) - parser.add_argument('-cn', '--clip-norm', type=float, default=10.0) - parser.add_argument('-lrd', '--lr-decay', type=float, default=0.99996) - parser.add_argument('-wd', '--weight-decay', type=float, default=2.0) - parser.add_argument('-mbs', '--mini-batch-size', type=int, default=20) - parser.add_argument('-ae', '--annealing-epochs', type=int, default=1000) - parser.add_argument('-maf', '--minimum-annealing-factor', type=float, default=0.2) - parser.add_argument('-rdr', '--rnn-dropout-rate', type=float, default=0.1) - parser.add_argument('-iafs', '--num-iafs', type=int, default=0) - parser.add_argument('-id', '--iaf-dim', type=int, default=100) - parser.add_argument('-cf', '--checkpoint-freq', type=int, default=0) - parser.add_argument('-lopt', '--load-opt', type=str, default='') - parser.add_argument('-lmod', '--load-model', type=str, default='') - parser.add_argument('-sopt', '--save-opt', type=str, default='') - parser.add_argument('-smod', '--save-model', type=str, default='') - parser.add_argument('--cuda', action='store_true') - parser.add_argument('--jit', action='store_true') - parser.add_argument('--tmc', action='store_true') - parser.add_argument('--tmcelbo', action='store_true') - parser.add_argument('--tmc-num-samples', default=10, type=int) - parser.add_argument('-l', '--log', type=str, default='dmm.log') + parser.add_argument("-n", "--num-epochs", type=int, default=5000) + parser.add_argument("-lr", "--learning-rate", type=float, default=0.0003) + parser.add_argument("-b1", "--beta1", type=float, default=0.96) + parser.add_argument("-b2", "--beta2", type=float, default=0.999) + parser.add_argument("-cn", "--clip-norm", type=float, default=10.0) + parser.add_argument("-lrd", "--lr-decay", type=float, default=0.99996) + parser.add_argument("-wd", "--weight-decay", type=float, default=2.0) + parser.add_argument("-mbs", "--mini-batch-size", type=int, default=20) + parser.add_argument("-ae", "--annealing-epochs", 
type=int, default=1000) + parser.add_argument("-maf", "--minimum-annealing-factor", type=float, default=0.2) + parser.add_argument("-rdr", "--rnn-dropout-rate", type=float, default=0.1) + parser.add_argument("-iafs", "--num-iafs", type=int, default=0) + parser.add_argument("-id", "--iaf-dim", type=int, default=100) + parser.add_argument("-cf", "--checkpoint-freq", type=int, default=0) + parser.add_argument("-lopt", "--load-opt", type=str, default="") + parser.add_argument("-lmod", "--load-model", type=str, default="") + parser.add_argument("-sopt", "--save-opt", type=str, default="") + parser.add_argument("-smod", "--save-model", type=str, default="") + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true") + parser.add_argument("--tmc", action="store_true") + parser.add_argument("--tmcelbo", action="store_true") + parser.add_argument("--tmc-num-samples", default=10, type=int) + parser.add_argument("-l", "--log", type=str, default="dmm.log") args = parser.parse_args() main(args) diff --git a/examples/eight_schools/data.py b/examples/eight_schools/data.py index 56158fa36e..82bfbb55df 100644 --- a/examples/eight_schools/data.py +++ b/examples/eight_schools/data.py @@ -4,5 +4,5 @@ import torch J = 8 -y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) -sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) +y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor) +sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index af7d4d418e..b7ae265f51 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -12,14 +12,14 @@ import pyro.poutine as poutine from pyro.infer import MCMC, NUTS -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) pyro.set_rng_seed(0) def model(sigma): - eta = pyro.sample('eta', dist.Normal(torch.zeros(data.J), torch.ones(data.J))) - mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1))) - tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1))) + eta = pyro.sample("eta", dist.Normal(torch.zeros(data.J), torch.ones(data.J))) + mu = pyro.sample("mu", dist.Normal(torch.zeros(1), 10 * torch.ones(1))) + tau = pyro.sample("tau", dist.HalfCauchy(scale=25 * torch.ones(1))) theta = mu + tau * eta @@ -32,24 +32,38 @@ def conditioned_model(model, sigma, y): def main(args): nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit) - mcmc = MCMC(nuts_kernel, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps, - num_chains=args.num_chains) + mcmc = MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + ) mcmc.run(model, data.sigma, data.y) mcmc.summary(prob=0.5) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='Eight Schools MCMC') - parser.add_argument('--num-samples', type=int, default=1000, - help='number of MCMC samples (default: 1000)') - parser.add_argument('--num-chains', type=int, default=1, - help='number of parallel MCMC chains (default: 1)') - parser.add_argument('--warmup-steps', type=int, default=1000, - help='number of MCMC samples for warmup (default: 1000)') - parser.add_argument('--jit', action='store_true', default=False) +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = 
argparse.ArgumentParser(description="Eight Schools MCMC") + parser.add_argument( + "--num-samples", + type=int, + default=1000, + help="number of MCMC samples (default: 1000)", + ) + parser.add_argument( + "--num-chains", + type=int, + default=1, + help="number of parallel MCMC chains (default: 1)", + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=1000, + help="number of MCMC samples for warmup (default: 1000)", + ) + parser.add_argument("--jit", action="store_true", default=False) args = parser.parse_args() main(args) diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index 2aaa2ad1a1..063a99a587 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -13,7 +13,7 @@ from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) data = torch.stack([y, sigma], dim=1) @@ -22,9 +22,9 @@ def model(data): sigma = data[:, 1] with pyro.plate("data", J): - eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J))) - mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1))) - tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1))) + eta = pyro.sample("eta", dist.Normal(torch.zeros(J), torch.ones(J))) + mu = pyro.sample("mu", dist.Normal(torch.zeros(1), 10 * torch.ones(1))) + tau = pyro.sample("tau", dist.HalfCauchy(scale=25 * torch.ones(1))) theta = mu + tau * eta @@ -46,22 +46,26 @@ def guide(data): m_mu_param = pyro.param("loc_mu", loc_mu) s_mu_param = pyro.param("scale_mu", scale_mu, constraint=constraints.positive) m_logtau_param = pyro.param("loc_logtau", loc_logtau) - s_logtau_param = pyro.param("scale_logtau", scale_logtau, constraint=constraints.positive) + s_logtau_param = pyro.param( + "scale_logtau", scale_logtau, constraint=constraints.positive + ) # guide distributions dist_eta = dist.Normal(m_eta_param, s_eta_param) dist_mu = dist.Normal(m_mu_param, s_mu_param) - dist_tau = dist.TransformedDistribution(dist.Normal(m_logtau_param, s_logtau_param), - transforms=transforms.ExpTransform()) + dist_tau = dist.TransformedDistribution( + dist.Normal(m_logtau_param, s_logtau_param), + transforms=transforms.ExpTransform(), + ) with pyro.plate("data", J): - pyro.sample('eta', dist_eta) - pyro.sample('mu', dist_mu) - pyro.sample('tau', dist_tau) + pyro.sample("eta", dist_eta) + pyro.sample("mu", dist_mu) + pyro.sample("tau", dist_tau) def main(args): - optim = Adam({'lr': args.lr}) + optim = Adam({"lr": args.lr}) elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(model, guide, optim, loss=elbo) @@ -76,14 +80,16 @@ def main(args): logging.info(value.detach().cpu().numpy()) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='Eight Schools SVI') - parser.add_argument('--lr', type=float, default=0.01, - help='learning rate (default: 0.01)') - parser.add_argument('--num-epochs', type=int, default=1000, - help='number of epochs (default: 1000)') - parser.add_argument('--jit', action='store_true', default=False) +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser(description="Eight Schools SVI") + parser.add_argument( + "--lr", type=float, default=0.01, help="learning rate (default: 0.01)" + ) + parser.add_argument( + "--num-epochs", type=int, default=1000, help="number of epochs (default: 1000)" + ) + 
parser.add_argument("--jit", action="store_true", default=False) args = parser.parse_args() main(args) diff --git a/examples/einsum.py b/examples/einsum.py index 3265a87531..1c257d00f8 100644 --- a/examples/einsum.py +++ b/examples/einsum.py @@ -41,7 +41,7 @@ def jit_prob(equation, *operands, **kwargs): This is cheap but less numerically stable than using the torch_log backend. """ - key = 'prob', equation, kwargs['plates'] + key = "prob", equation, kwargs["plates"] if key not in _CACHE: # This simply wraps einsum for jit compilation. @@ -59,12 +59,14 @@ def jit_logprob(equation, *operands, **kwargs): This simulates evaluating an undirected graphical model. """ - key = 'logprob', equation, kwargs['plates'] + key = "logprob", equation, kwargs["plates"] if key not in _CACHE: # This simply wraps einsum for jit compilation. def _einsum(*operands): - return einsum(equation, *operands, backend='pyro.ops.einsum.torch_log', **kwargs) + return einsum( + equation, *operands, backend="pyro.ops.einsum.torch_log", **kwargs + ) _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False) @@ -77,12 +79,14 @@ def jit_gradient(equation, *operands, **kwargs): This is simulates training an undirected graphical model. """ - key = 'gradient', equation, kwargs['plates'] + key = "gradient", equation, kwargs["plates"] if key not in _CACHE: # This wraps einsum for jit compilation, but we will call backward on the result. def _einsum(*operands): - return einsum(equation, *operands, backend='pyro.ops.einsum.torch_log', **kwargs) + return einsum( + equation, *operands, backend="pyro.ops.einsum.torch_log", **kwargs + ) _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False) @@ -95,8 +99,9 @@ def _einsum(*operands): losses = (losses,) # Run backward pass. - grads = tuple(grad(loss, operands, retain_graph=True, allow_unused=True) - for loss in losses) + grads = tuple( + grad(loss, operands, retain_graph=True, allow_unused=True) for loss in losses + ) return grads @@ -106,8 +111,8 @@ def _jit_adjoint(equation, *operands, **kwargs): This simulates serving predictions from an undirected graphical model. """ - backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_sample') - key = backend, equation, tuple(x.shape for x in operands), kwargs['plates'] + backend = kwargs.pop("backend", "pyro.ops.einsum.torch_sample") + key = backend, equation, tuple(x.shape for x in operands), kwargs["plates"] if key not in _CACHE: # This wraps a complete adjoint algorithm call. @@ -141,19 +146,25 @@ def _forward_backward(*operands): def jit_map(equation, *operands, **kwargs): - return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_map', **kwargs) + return _jit_adjoint( + equation, *operands, backend="pyro.ops.einsum.torch_map", **kwargs + ) def jit_sample(equation, *operands, **kwargs): - return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_sample', **kwargs) + return _jit_adjoint( + equation, *operands, backend="pyro.ops.einsum.torch_sample", **kwargs + ) def jit_marginal(equation, *operands, **kwargs): - return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_marginal', **kwargs) + return _jit_adjoint( + equation, *operands, backend="pyro.ops.einsum.torch_marginal", **kwargs + ) def time_fn(fn, equation, *operands, **kwargs): - iters = kwargs.pop('iters') + iters = kwargs.pop("iters") _CACHE.clear() # Avoid memory leaks. 
fn(equation, *operands, **kwargs) @@ -167,45 +178,55 @@ def time_fn(fn, equation, *operands, **kwargs): def main(args): if args.cuda: - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") else: - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_tensor_type("torch.FloatTensor") - if args.method == 'all': - for method in ['prob', 'logprob', 'gradient', 'marginal', 'map', 'sample']: + if args.method == "all": + for method in ["prob", "logprob", "gradient", "marginal", "map", "sample"]: args.method = method main(args) return - print('Plate size Time per iteration of {} (ms)'.format(args.method)) - fn = globals()['jit_{}'.format(args.method)] + print("Plate size Time per iteration of {} (ms)".format(args.method)) + fn = globals()["jit_{}".format(args.method)] equation = args.equation plates = args.plates - inputs, outputs = equation.split('->') - inputs = inputs.split(',') + inputs, outputs = equation.split("->") + inputs = inputs.split(",") # Vary all plate sizes at the same time. for plate_size in range(8, 1 + args.max_plate_size, 8): operands = [] for dims in inputs: - shape = torch.Size([plate_size if d in plates else args.dim_size - for d in dims]) + shape = torch.Size( + [plate_size if d in plates else args.dim_size for d in dims] + ) operands.append((torch.empty(shape).uniform_() + 0.5).requires_grad_()) - time = time_fn(fn, equation, *operands, plates=plates, modulo_total=True, - iters=args.iters) - print('{: <11s} {:0.4g}'.format('{} ** {}'.format(plate_size, len(args.plates)), time * 1e3)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='plated einsum profiler') - parser.add_argument('-e', '--equation', default='a,abi,bcij,adj,deij->') - parser.add_argument('-p', '--plates', default='ij') - parser.add_argument('-d', '--dim-size', default=32, type=int) - parser.add_argument('-s', '--max-plate-size', default=32, type=int) - parser.add_argument('-n', '--iters', default=10, type=int) - parser.add_argument('--cuda', action='store_true') - parser.add_argument('-m', '--method', default='all', - help='one of: prob, logprob, gradient, marginal, map, sample') + time = time_fn( + fn, equation, *operands, plates=plates, modulo_total=True, iters=args.iters + ) + print( + "{: <11s} {:0.4g}".format( + "{} ** {}".format(plate_size, len(args.plates)), time * 1e3 + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="plated einsum profiler") + parser.add_argument("-e", "--equation", default="a,abi,bcij,adj,deij->") + parser.add_argument("-p", "--plates", default="ij") + parser.add_argument("-d", "--dim-size", default=32, type=int) + parser.add_argument("-s", "--max-plate-size", default=32, type=int) + parser.add_argument("-n", "--iters", default=10, type=int) + parser.add_argument("--cuda", action="store_true") + parser.add_argument( + "-m", + "--method", + default="all", + help="one of: prob, logprob, gradient, marginal, map, sample", + ) args = parser.parse_args() main(args) diff --git a/examples/hmm.py b/examples/hmm.py index 2fa15de058..5700065b09 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -53,7 +53,7 @@ from pyro.optim import Adam from pyro.util import ignore_jit_warnings -logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.DEBUG) +logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.DEBUG) # Add another handler for logging debugging events (e.g. 
for profiling) # in a separate stream that can be captured. @@ -87,16 +87,17 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True): # Our prior on transition probabilities will be: # stay in the same state with 90% probability; uniformly jump to another # state with 10% probability. - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) # We put a weak prior on the conditional probability of a tone sounding. # We know that on average about 4 of 88 tones are active, so we'll set a # rough weak prior of 10% of the notes being active at any one time. - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, data_dim]) - .to_event(2)) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2), + ) # In this first model we'll sequentially iterate over sequences in a # minibatch; this will make it easy to reason about tensor shapes. tones_plate = pyro.plate("tones", data_dim, dim=-1) @@ -108,11 +109,19 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True): # On the next line, we'll overwrite the value of x with an updated # value. If we wanted to record all x values, we could instead # write x[t] = pyro.sample(...x[t-1]...). - x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}_{}".format(i, t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate: - pyro.sample("y_{}_{}".format(i, t), dist.Bernoulli(probs_y[x.squeeze(-1)]), - obs=sequence[t]) + pyro.sample( + "y_{}_{}".format(i, t), + dist.Bernoulli(probs_y[x.squeeze(-1)]), + obs=sequence[t], + ) + + # To see how enumeration changes the shapes of these sample sites, we can use # the Trace.format_shapes() to print shapes at each site: # $ python examples/hmm.py -m 0 -n 1 -b 1 -t 5 --print-shapes @@ -171,13 +180,14 @@ def model_1(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.shape == (num_sequences,) assert lengths.max() <= max_length with poutine.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, data_dim]) - .to_event(2)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) # We subsample batch_size items out of num_sequences items. Note that since # we're using dim=-1 for the notes plate, we need to batch over a different @@ -193,11 +203,19 @@ def model_1(sequences, lengths, args, batch_size=None, include_prior=True): # need to trigger a new jit compile stage. 
for t in pyro.markov(range(max_length if args.jit else lengths.max())): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x.squeeze(-1)]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[x.squeeze(-1)]), + obs=sequences[batch, t], + ) + + # Let's see how batching changes the shapes of sample sites: # $ python examples/hmm.py -m 1 -n 1 -t 5 --batch-size=10 --print-shapes # ... @@ -249,27 +267,34 @@ def model_2(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.shape == (num_sequences,) assert lengths.max() <= max_length with poutine.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([args.hidden_dim, 2, data_dim]) - .to_event(3)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] x, y = 0, 0 for t in pyro.markov(range(max_length if args.jit else lengths.max())): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) # Note the broadcasting tricks here: to index probs_y on tensors x and y, # we also need a final tensor for the tones dimension. This is conveniently # provided by the plate associated with that dimension. with tones_plate as tones: - y = pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x, y, tones]), - obs=sequences[batch, t]).long() + y = pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[x, y, tones]), + obs=sequences[batch, t], + ).long() # Next consider a Factorial HMM with two hidden states. 
@@ -294,29 +319,38 @@ def model_3(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.max() <= max_length hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x with poutine.mask(mask=include_prior): - probs_w = pyro.sample("probs_w", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([hidden_dim, hidden_dim, data_dim]) - .to_event(3)) + probs_w = pyro.sample( + "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_x = pyro.sample( + "probs_x", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] w, x = 0, 0 for t in pyro.markov(range(max_length if args.jit else lengths.max())): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): - w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]), - infer={"enumerate": "parallel"}) - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + w = pyro.sample( + "w_{}".format(t), + dist.Categorical(probs_w[w]), + infer={"enumerate": "parallel"}, + ) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) with tones_plate as tones: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[w, x, tones]), + obs=sequences[batch, t], + ) # By adding a dependency of x on w, we generalize to a @@ -340,17 +374,19 @@ def model_4(sequences, lengths, args, batch_size=None, include_prior=True): assert lengths.max() <= max_length hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x with poutine.mask(mask=include_prior): - probs_w = pyro.sample("probs_w", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .to_event(1)) - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) - .expand_by([hidden_dim]) - .to_event(2)) - probs_y = pyro.sample("probs_y", - dist.Beta(0.1, 0.9) - .expand([hidden_dim, hidden_dim, data_dim]) - .to_event(3)) + probs_w = pyro.sample( + "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1) + ) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1) + .expand_by([hidden_dim]) + .to_event(2), + ) + probs_y = pyro.sample( + "probs_y", + dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3), + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] @@ -360,14 +396,22 @@ def model_4(sequences, lengths, args, batch_size=None, include_prior=True): w = x = torch.tensor(0, dtype=torch.long) for t in pyro.markov(range(max_length if args.jit else lengths.max())): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): - w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]), - infer={"enumerate": "parallel"}) - x = pyro.sample("x_{}".format(t), - dist.Categorical(Vindex(probs_x)[w, x]), - infer={"enumerate": "parallel"}) + w = pyro.sample( + "w_{}".format(t), + dist.Categorical(probs_w[w]), + 
infer={"enumerate": "parallel"}, + ) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(Vindex(probs_x)[w, x]), + infer={"enumerate": "parallel"}, + ) with tones_plate as tones: - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y[w, x, tones]), + obs=sequences[batch, t], + ) # Next let's consider a neural HMM model. @@ -394,8 +438,12 @@ def forward(self, x, y): # a bernoulli variable y. Whereas x will typically be enumerated, y will be observed. # We apply x_to_hidden independently from y_to_hidden, then broadcast the non-enumerated # y part up to the enumerated x part in the + operation. - x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_(-1, x, 1) - y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape(y.shape[:-1] + (-1,)) + x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_( + -1, x, 1 + ) + y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape( + y.shape[:-1] + (-1,) + ) h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv)) return self.hidden_to_logits(h) @@ -420,23 +468,29 @@ def model_5(sequences, lengths, args, batch_size=None, include_prior=True): pyro.module("tones_generator", tones_generator) with poutine.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] x = 0 y = torch.zeros(data_dim) for t in pyro.markov(range(max_length if args.jit else lengths.max())): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): - x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x[x]), + infer={"enumerate": "parallel"}, + ) # Note that since each tone depends on all tones at a previous time step # the tones at different time steps now need to live in separate plates. with pyro.plate("tones_{}".format(t), data_dim, dim=-1): - y = pyro.sample("y_{}".format(t), - dist.Bernoulli(logits=tones_generator(x, y)), - obs=sequences[batch, t]) + y = pyro.sample( + "y_{}".format(t), + dist.Bernoulli(logits=tones_generator(x, y)), + obs=sequences[batch, t], + ) # Next let's consider a second-order HMM model @@ -463,24 +517,38 @@ def model_6(sequences, lengths, args, batch_size=None, include_prior=False): if not args.raftery_parameterization: # Explicitly parameterize the full tensor of transition probabilities, which # has hidden_dim cubed entries. - probs_x = pyro.param("probs_x", torch.rand(hidden_dim, hidden_dim, hidden_dim), - constraint=constraints.simplex) + probs_x = pyro.param( + "probs_x", + torch.rand(hidden_dim, hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) else: # Use the more parsimonious "Raftery" parameterization of # the tensor of transition probabilities. See reference: # Raftery, A. E. A model for high-order markov chains. # Journal of the Royal Statistical Society. 1985. 
- probs_x1 = pyro.param("probs_x1", torch.rand(hidden_dim, hidden_dim), - constraint=constraints.simplex) - probs_x2 = pyro.param("probs_x2", torch.rand(hidden_dim, hidden_dim), - constraint=constraints.simplex) - mix_lambda = pyro.param("mix_lambda", torch.tensor(0.5), constraint=constraints.unit_interval) + probs_x1 = pyro.param( + "probs_x1", + torch.rand(hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) + probs_x2 = pyro.param( + "probs_x2", + torch.rand(hidden_dim, hidden_dim), + constraint=constraints.simplex, + ) + mix_lambda = pyro.param( + "mix_lambda", torch.tensor(0.5), constraint=constraints.unit_interval + ) # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim) probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2) - probs_y = pyro.param("probs_y", torch.rand(hidden_dim, data_dim), - constraint=constraints.unit_interval) + probs_y = pyro.param( + "probs_y", + torch.rand(hidden_dim, data_dim), + constraint=constraints.unit_interval, + ) tones_plate = pyro.plate("tones", data_dim, dim=-1) with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch: lengths = lengths[batch] @@ -490,12 +558,18 @@ def model_6(sequences, lengths, args, batch_size=None, include_prior=False): for t in pyro.markov(range(lengths.max()), history=2): with poutine.mask(mask=(t < lengths).unsqueeze(-1)): probs_x_t = Vindex(probs_x)[x_prev, x_curr] - x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t), dist.Categorical(probs_x_t), - infer={"enumerate": "parallel"}) + x_prev, x_curr = x_curr, pyro.sample( + "x_{}".format(t), + dist.Categorical(probs_x_t), + infer={"enumerate": "parallel"}, + ) with tones_plate: probs_y_t = probs_y[x_curr.squeeze(-1)] - pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y_t), - obs=sequences[batch, t]) + pyro.sample( + "y_{}".format(t), + dist.Bernoulli(probs_y_t), + obs=sequences[batch, t], + ) # Next we demonstrate how to parallelize the neural HMM above using Pyro's @@ -515,51 +589,59 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True): pyro.module("tones_generator", tones_generator) with poutine.mask(mask=include_prior): - probs_x = pyro.sample("probs_x", - dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1) - .to_event(1)) + probs_x = pyro.sample( + "probs_x", + dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1), + ) with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch: lengths = lengths[batch] - y = sequences[batch] if args.jit else sequences[batch, :lengths.max()] + y = sequences[batch] if args.jit else sequences[batch, : lengths.max()] x = torch.arange(args.hidden_dim) t = torch.arange(y.size(1)) - init_logits = torch.full((args.hidden_dim,), -float('inf')) + init_logits = torch.full((args.hidden_dim,), -float("inf")) init_logits[0] = 0 trans_logits = probs_x.log() with ignore_jit_warnings(): - obs_dist = dist.Bernoulli(logits=tones_generator(x, y.unsqueeze(-2))).to_event(1) + obs_dist = dist.Bernoulli( + logits=tones_generator(x, y.unsqueeze(-2)) + ).to_event(1) obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1)) hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist) pyro.sample("y", hmm_dist, obs=y) -models = {name[len('model_'):]: model - for name, model in globals().items() - if name.startswith('model_')} +models = { + name[len("model_") :]: model + for name, model in globals().items() + if 
name.startswith("model_") +} def main(args): if args.cuda: - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") - logging.info('Loading data') + logging.info("Loading data") data = poly.load_data(poly.JSB_CHORALES) - logging.info('-' * 40) + logging.info("-" * 40) model = models[args.model] - logging.info('Training {} on {} sequences'.format( - model.__name__, len(data['train']['sequences']))) - sequences = data['train']['sequences'] - lengths = data['train']['sequence_lengths'] + logging.info( + "Training {} on {} sequences".format( + model.__name__, len(data["train"]["sequences"]) + ) + ) + sequences = data["train"]["sequences"] + lengths = data["train"]["sequence_lengths"] # find all the notes that are present at least once in the training set - present_notes = ((sequences == 1).sum(0).sum(0) > 0) + present_notes = (sequences == 1).sum(0).sum(0) > 0 # remove notes that are never played (we remove 37/88 notes) sequences = sequences[..., present_notes] if args.truncate: lengths = lengths.clamp(max=args.truncate) - sequences = sequences[:, :args.truncate] + sequences = sequences[:, : args.truncate] num_observations = float(lengths.sum()) pyro.set_rng_seed(args.seed) pyro.clear_param_store() @@ -568,7 +650,9 @@ def main(args): # out the hidden state x. This is accomplished via an automatic guide that # learns point estimates of all of our conditional probability tables, # named probs_*. - guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))) + guide = AutoDelta( + poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")) + ) # To help debug our tensor shapes, let's print the shape of each site's # distribution, value, and log_prob tensor. Note this information is @@ -576,49 +660,59 @@ def main(args): if args.print_shapes: first_available_dim = -2 if model is model_0 else -3 guide_trace = poutine.trace(guide).get_trace( - sequences, lengths, args=args, batch_size=args.batch_size) + sequences, lengths, args=args, batch_size=args.batch_size + ) model_trace = poutine.trace( - poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace( - sequences, lengths, args=args, batch_size=args.batch_size) + poutine.replay(poutine.enum(model, first_available_dim), guide_trace) + ).get_trace(sequences, lengths, args=args, batch_size=args.batch_size) logging.info(model_trace.format_shapes()) # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting. # All of our models have two plates: "data" and "tones". 
- optim = Adam({'lr': args.learning_rate}) + optim = Adam({"lr": args.learning_rate}) if args.tmc: if args.jit: raise NotImplementedError("jit support not yet added for TraceTMC_ELBO") elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2) tmc_model = poutine.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} if msg["infer"].get("enumerate", None) == "parallel" else {}) # noqa: E501 + lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {}, + ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO - elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2, - strict_enumeration_warning=(model is not model_7), - jit_options={"time_compilation": args.time_compilation}) + elbo = Elbo( + max_plate_nesting=1 if model is model_0 else 2, + strict_enumeration_warning=(model is not model_7), + jit_options={"time_compilation": args.time_compilation}, + ) svi = SVI(model, guide, optim, elbo) # We'll train on small minibatches. - logging.info('Step\tLoss') + logging.info("Step\tLoss") for step in range(args.num_steps): loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) - logging.info('{: >5d}\t{}'.format(step, loss / num_observations)) + logging.info("{: >5d}\t{}".format(step, loss / num_observations)) if args.jit and args.time_compilation: - logging.debug('time to compile: {} s.'.format(elbo._differentiable_loss.compile_time)) + logging.debug( + "time to compile: {} s.".format(elbo._differentiable_loss.compile_time) + ) # We evaluate on the entire training dataset, # excluding the prior term so our results are comparable across models. train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False) - logging.info('training loss = {}'.format(train_loss / num_observations)) + logging.info("training loss = {}".format(train_loss / num_observations)) # Finally we evaluate on the test dataset. - logging.info('-' * 40) - logging.info('Evaluating on {} test sequences'.format(len(data['test']['sequences']))) - sequences = data['test']['sequences'][..., present_notes] - lengths = data['test']['sequence_lengths'] + logging.info("-" * 40) + logging.info( + "Evaluating on {} test sequences".format(len(data["test"]["sequences"])) + ) + sequences = data["test"]["sequences"][..., present_notes] + lengths = data["test"]["sequence_lengths"] if args.truncate: lengths = lengths.clamp(max=args.truncate) num_observations = float(lengths.sum()) @@ -626,21 +720,31 @@ def main(args): # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. - test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) - logging.info('test loss = {}'.format(test_loss / num_observations)) + test_loss = elbo.loss( + model, guide, sequences, lengths, args=args, include_prior=False + ) + logging.info("test loss = {}".format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. 
- capacity = sum(value.reshape(-1).size(0) - for value in pyro.get_param_store().values()) - logging.info('{} capacity = {} parameters'.format(model.__name__, capacity)) - - -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description="MAP Baum-Welch learning Bach Chorales") - parser.add_argument("-m", "--model", default="1", type=str, - help="one of: {}".format(", ".join(sorted(models.keys())))) + capacity = sum( + value.reshape(-1).size(0) for value in pyro.get_param_store().values() + ) + logging.info("{} capacity = {} parameters".format(model.__name__, capacity)) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="MAP Baum-Welch learning Bach Chorales" + ) + parser.add_argument( + "-m", + "--model", + default="1", + type=str, + help="one of: {}".format(", ".join(sorted(models.keys()))), + ) parser.add_argument("-n", "--num-steps", default=50, type=int) parser.add_argument("-b", "--batch-size", default=8, type=int) parser.add_argument("-d", "--hidden-dim", default=16, type=int) @@ -650,15 +754,18 @@ def main(args): parser.add_argument("-t", "--truncate", type=int) parser.add_argument("-p", "--print-shapes", action="store_true") parser.add_argument("--seed", default=0, type=int) - parser.add_argument('--cuda', action='store_true') - parser.add_argument('--jit', action='store_true') - parser.add_argument('--time-compilation', action='store_true') - parser.add_argument('-rp', '--raftery-parameterization', action='store_true') - parser.add_argument('--tmc', action='store_true', - help="Use Tensor Monte Carlo instead of exact enumeration " - "to estimate the marginal likelihood. You probably don't want to do this, " - "except to see that TMC makes Monte Carlo gradient estimation feasible " - "even with very large numbers of non-reparametrized variables.") - parser.add_argument('--tmc-num-samples', default=10, type=int) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true") + parser.add_argument("--time-compilation", action="store_true") + parser.add_argument("-rp", "--raftery-parameterization", action="store_true") + parser.add_argument( + "--tmc", + action="store_true", + help="Use Tensor Monte Carlo instead of exact enumeration " + "to estimate the marginal likelihood. 
You probably don't want to do this, " + "except to see that TMC makes Monte Carlo gradient estimation feasible " + "even with very large numbers of non-reparametrized variables.", + ) + parser.add_argument("--tmc-num-samples", default=10, type=int) args = parser.parse_args() main(args) diff --git a/examples/inclined_plane.py b/examples/inclined_plane.py index 1ad2e6fdc0..b8614e9132 100644 --- a/examples/inclined_plane.py +++ b/examples/inclined_plane.py @@ -32,17 +32,23 @@ # the forward simulator, which does numerical integration of the equations of motion # in steps of size dt, and optionally includes measurement noise + def simulate(mu, length=2.0, phi=np.pi / 6.0, dt=0.005, noise_sigma=None): T = torch.zeros(()) velocity = torch.zeros(()) displacement = torch.zeros(()) - acceleration = torch.tensor(little_g * np.sin(phi)) - \ - torch.tensor(little_g * np.cos(phi)) * mu - - if acceleration.numpy() <= 0.0: # the box doesn't slide if the friction is too large - return torch.tensor(1.0e5) # return a very large time instead of infinity - - while displacement.numpy() < length: # otherwise slide to the end of the inclined plane + acceleration = ( + torch.tensor(little_g * np.sin(phi)) - torch.tensor(little_g * np.cos(phi)) * mu + ) + + if ( + acceleration.numpy() <= 0.0 + ): # the box doesn't slide if the friction is too large + return torch.tensor(1.0e5) # return a very large time instead of infinity + + while ( + displacement.numpy() < length + ): # otherwise slide to the end of the inclined plane displacement += velocity * dt velocity += acceleration * dt T += dt @@ -56,6 +62,7 @@ def simulate(mu, length=2.0, phi=np.pi / 6.0, dt=0.005, noise_sigma=None): # analytic formula that the simulator above is computing via # numerical integration (no measurement noise) + def analytic_T(mu, length=2.0, phi=np.pi / 6.0): numerator = 2.0 * length denominator = little_g * (np.sin(phi) - mu * np.cos(phi)) @@ -66,8 +73,12 @@ def analytic_T(mu, length=2.0, phi=np.pi / 6.0): print("generating simulated data using the true coefficient of friction %.3f" % mu0) N_obs = 20 torch.manual_seed(2) -observed_data = torch.tensor([simulate(torch.tensor(mu0), noise_sigma=time_measurement_sigma) - for _ in range(N_obs)]) +observed_data = torch.tensor( + [ + simulate(torch.tensor(mu0), noise_sigma=time_measurement_sigma) + for _ in range(N_obs) + ] +) observed_mean = np.mean([T.item() for T in observed_data]) @@ -102,18 +113,29 @@ def main(args): # report results inferred_mu = posterior_mean.item() inferred_mu_uncertainty = posterior_std_dev.item() - print("the coefficient of friction inferred by pyro is %.3f +- %.3f" % - (inferred_mu, inferred_mu_uncertainty)) + print( + "the coefficient of friction inferred by pyro is %.3f +- %.3f" + % (inferred_mu, inferred_mu_uncertainty) + ) # note that, given the finite step size in the simulator, the simulated descent times will # not precisely match the numbers from the analytic result. 
# in particular the first two numbers reported below should match each other pretty closely # but will be systematically off from the third number - print("the mean observed descent time in the dataset is: %.4f seconds" % observed_mean) - print("the (forward) simulated descent time for the inferred (mean) mu is: %.4f seconds" % - simulate(posterior_mean).item()) - print(("disregarding measurement noise, elementary calculus gives the descent time\n" + - "for the inferred (mean) mu as: %.4f seconds") % analytic_T(posterior_mean.item())) + print( + "the mean observed descent time in the dataset is: %.4f seconds" % observed_mean + ) + print( + "the (forward) simulated descent time for the inferred (mean) mu is: %.4f seconds" + % simulate(posterior_mean).item() + ) + print( + ( + "disregarding measurement noise, elementary calculus gives the descent time\n" + + "for the inferred (mean) mu as: %.4f seconds" + ) + % analytic_T(posterior_mean.item()) + ) """ ################## EXERCISE ################### @@ -122,9 +144,9 @@ def main(args): """ -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=500, type=int) + parser.add_argument("-n", "--num-samples", default=500, type=int) args = parser.parse_args() main(args) diff --git a/examples/lda.py b/examples/lda.py index 0e6a7be726..9a58c0a06c 100644 --- a/examples/lda.py +++ b/examples/lda.py @@ -32,7 +32,7 @@ from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO from pyro.optim import ClippedAdam -logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO) +logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO) # This is a fully generative model of a batch of documents. @@ -42,9 +42,12 @@ def model(data=None, args=None, batch_size=None): # Globals. with pyro.plate("topics", args.num_topics): - topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.)) - topic_words = pyro.sample("topic_words", - dist.Dirichlet(torch.ones(args.num_words) / args.num_words)) + topic_weights = pyro.sample( + "topic_weights", dist.Gamma(1.0 / args.num_topics, 1.0) + ) + topic_words = pyro.sample( + "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words) + ) # Locals. with pyro.plate("documents", args.num_docs) as ind: @@ -58,10 +61,14 @@ def model(data=None, args=None, batch_size=None): # achieved by specifying infer={"enumerate": "parallel"} and using # TraceEnum_ELBO for inference. Thus we can ignore this variable in # the guide. - word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics), - infer={"enumerate": "parallel"}) - data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), - obs=data) + word_topics = pyro.sample( + "word_topics", + dist.Categorical(doc_topics), + infer={"enumerate": "parallel"}, + ) + data = pyro.sample( + "doc_words", dist.Categorical(topic_words[word_topics]), obs=data + ) return topic_weights, topic_words, data @@ -69,10 +76,12 @@ def model(data=None, args=None, batch_size=None): # We will use amortized inference of the local topic variables, achieved by a # multi-layer perceptron. We'll wrap the guide in an nn.Module. 
def make_predictor(args): - layer_sizes = ([args.num_words] + - [int(s) for s in args.layer_sizes.split('-')] + - [args.num_topics]) - logging.info('Creating MLP with sizes {}'.format(layer_sizes)) + layer_sizes = ( + [args.num_words] + + [int(s) for s in args.layer_sizes.split("-")] + + [args.num_topics] + ) + logging.info("Creating MLP with sizes {}".format(layer_sizes)) layers = [] for in_size, out_size in zip(layer_sizes, layer_sizes[1:]): layer = nn.Linear(in_size, out_size) @@ -87,15 +96,17 @@ def make_predictor(args): def parametrized_guide(predictor, data, args, batch_size=None): # Use a conjugate guide for global variables. topic_weights_posterior = pyro.param( - "topic_weights_posterior", - lambda: torch.ones(args.num_topics), - constraint=constraints.positive) + "topic_weights_posterior", + lambda: torch.ones(args.num_topics), + constraint=constraints.positive, + ) topic_words_posterior = pyro.param( - "topic_words_posterior", - lambda: torch.ones(args.num_topics, args.num_words), - constraint=constraints.greater_than(0.5)) + "topic_words_posterior", + lambda: torch.ones(args.num_topics, args.num_words), + constraint=constraints.greater_than(0.5), + ) with pyro.plate("topics", args.num_topics): - pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.)) + pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0)) pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior)) # Use an amortized guide for local variables. @@ -104,14 +115,15 @@ def parametrized_guide(predictor, data, args, batch_size=None): data = data[:, ind] # The neural network will operate on histograms rather than word # index vectors, so we'll convert the raw data to a histogram. - counts = (torch.zeros(args.num_words, ind.size(0)) - .scatter_add(0, data, torch.ones(data.shape))) + counts = torch.zeros(args.num_words, ind.size(0)).scatter_add( + 0, data, torch.ones(data.shape) + ) doc_topics = predictor(counts.transpose(0, 1)) pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1)) def main(args): - logging.info('Generating data') + logging.info("Generating data") pyro.set_rng_seed(0) pyro.clear_param_store() @@ -119,26 +131,28 @@ def main(args): true_topic_weights, true_topic_words, data = model(args=args) # We'll train using SVI. 
- logging.info('-' * 40) - logging.info('Training on {} documents'.format(args.num_docs)) + logging.info("-" * 40) + logging.info("Training on {} documents".format(args.num_docs)) predictor = make_predictor(args) guide = functools.partial(parametrized_guide, predictor) Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO elbo = Elbo(max_plate_nesting=2) - optim = ClippedAdam({'lr': args.learning_rate}) + optim = ClippedAdam({"lr": args.learning_rate}) svi = SVI(model, guide, optim, elbo) - logging.info('Step\tLoss') + logging.info("Step\tLoss") for step in range(args.num_steps): loss = svi.step(data, args=args, batch_size=args.batch_size) if step % 10 == 0: - logging.info('{: >5d}\t{}'.format(step, loss)) + logging.info("{: >5d}\t{}".format(step, loss)) loss = elbo.loss(model, guide, data, args=args) - logging.info('final loss = {}'.format(loss)) + logging.info("final loss = {}".format(loss)) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation") +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="Amortized Latent Dirichlet Allocation" + ) parser.add_argument("-t", "--num-topics", default=8, type=int) parser.add_argument("-w", "--num-words", default=1024, type=int) parser.add_argument("-d", "--num-docs", default=1000, type=int) @@ -147,6 +161,6 @@ def main(args): parser.add_argument("-l", "--layer-sizes", default="100-100") parser.add_argument("-lr", "--learning-rate", default=0.01, type=float) parser.add_argument("-b", "--batch-size", default=32, type=int) - parser.add_argument('--jit', action='store_true') + parser.add_argument("--jit", action="store_true") args = parser.parse_args() main(args) diff --git a/examples/lkj.py b/examples/lkj.py index 57747267dd..003deca463 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -26,7 +26,9 @@ def model(y): # Vector of variances for each of the d variables theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d, **options))) # Lower cholesky factor of a correlation matrix - concentration = torch.ones((), **options) # Implies a uniform distribution over correlation matrices + concentration = torch.ones( + (), **options + ) # Implies a uniform distribution over correlation matrices L_omega = pyro.sample("L_omega", dist.LKJCholesky(d, concentration)) # Lower cholesky factor of the covariance matrix L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega) @@ -45,19 +47,23 @@ def main(args): if args.cuda: y = y.cuda() nuts_kernel = NUTS(model, jit_compile=False, step_size=1e-5) - MCMC(nuts_kernel, num_samples=args.num_samples, - warmup_steps=args.warmup_steps, num_chains=args.num_chains).run(y) + MCMC( + nuts_kernel, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + num_chains=args.num_chains, + ).run(y) if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Demonstrate the use of an LKJ Prior") parser.add_argument("--num-samples", nargs="?", default=200, type=int) parser.add_argument("--n", nargs="?", default=500, type=int) - parser.add_argument("--num-chains", nargs='?', default=4, type=int) - parser.add_argument("--num-variables", nargs='?', default=5, type=int) - parser.add_argument("--warmup-steps", nargs='?', default=100, type=int) - parser.add_argument("--rng_seed", nargs='?', default=0, type=int) + parser.add_argument("--num-chains", 
nargs="?", default=4, type=int) + parser.add_argument("--num-variables", nargs="?", default=5, type=int) + parser.add_argument("--warmup-steps", nargs="?", default=100, type=int) + parser.add_argument("--rng_seed", nargs="?", default=0, type=int) parser.add_argument("--cuda", action="store_true", default=False) args = parser.parse_args() diff --git a/examples/minipyro.py b/examples/minipyro.py index 445133e51f..2ac0706e48 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -20,15 +20,15 @@ def main(args): # Define a basic model with a single Normal latent random variable `loc` # and a batch of Normally distributed observations. def model(data): - loc = pyro.sample("loc", dist.Normal(0., 1.)) + loc = pyro.sample("loc", dist.Normal(0.0, 1.0)) with pyro.plate("data", len(data), dim=-1): - pyro.sample("obs", dist.Normal(loc, 1.), obs=data) + pyro.sample("obs", dist.Normal(loc, 1.0), obs=data) # Define a guide (i.e. variational distribution) with a Normal # distribution over the latent random variable `loc`. def guide(data): - guide_loc = pyro.param("guide_loc", torch.tensor(0.)) - guide_scale = ops.exp(pyro.param("guide_scale_log", torch.tensor(0.))) + guide_loc = pyro.param("guide_loc", torch.tensor(0.0)) + guide_scale = ops.exp(pyro.param("guide_scale_log", torch.tensor(0.0))) pyro.sample("loc", dist.Normal(guide_loc, guide_scale)) # Generate some data. @@ -65,7 +65,7 @@ def guide(data): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-b", "--backend", default="minipyro") parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index bd4bd02e6a..e7df74cbc9 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -42,7 +42,9 @@ def run_expt(args): optim = args["optim"] lr = args["learnrate"] timesteps = args["timesteps"] - schedule = [] if not args["schedule"] else [int(i) for i in args["schedule"].split(",")] + schedule = ( + [] if not args["schedule"] else [int(i) for i in args["schedule"].split(",")] + ) random_effects = {"group": args["group"], "individual": args["individual"]} pyro.set_rng_seed(seed) # reproducible random effect parameter init @@ -62,14 +64,18 @@ def run_expt(args): loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss with pyro.poutine.trace(param_only=True) as param_capture: loss_fn(model, guide) - params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()] + params = [ + site["value"].unconstrained() for site in param_capture.trace.nodes.values() + ] optimizer = torch.optim.Adam(params, lr=lr) if schedule: - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=schedule, gamma=0.5 + ) schedule_step_loss = False else: - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") schedule_step_loss = True for t in range(timesteps): @@ -81,40 +87,50 @@ def run_expt(args): scheduler.step(loss.item() if schedule_step_loss else t) losses.append(loss.item()) - print("Loss: {}, AIC[{}]: ".format(loss.item(), t), - 2. * loss + 2. 
* num_parameters) + print( + "Loss: {}, AIC[{}]: ".format(loss.item(), t), + 2.0 * loss + 2.0 * num_parameters, + ) # LBFGS elif optim == "lbfgs": loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss with pyro.poutine.trace(param_only=True) as param_capture: loss_fn(model, guide) - params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()] + params = [ + site["value"].unconstrained() for site in param_capture.trace.nodes.values() + ] optimizer = torch.optim.LBFGS(params, lr=lr) if schedule: - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=schedule, gamma=0.5 + ) schedule_step_loss = False else: schedule_step_loss = True - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") for t in range(timesteps): + def closure(): optimizer.zero_grad() loss = loss_fn(model, guide) loss.backward() return loss + loss = optimizer.step(closure) scheduler.step(loss.item() if schedule_step_loss else t) losses.append(loss.item()) - print("Loss: {}, AIC[{}]: ".format(loss.item(), t), - 2. * loss + 2. * num_parameters) + print( + "Loss: {}, AIC[{}]: ".format(loss.item(), t), + 2.0 * loss + 2.0 * num_parameters, + ) else: raise ValueError("{} not supported optimizer".format(optim)) - aic_final = 2. * losses[-1] + 2. * num_parameters + aic_final = 2.0 * losses[-1] + 2.0 * num_parameters print("AIC final: {}".format(aic_final)) results = {} @@ -126,9 +142,23 @@ def closure(): results["aic_num_parameters"] = num_parameters if args["resultsdir"] is not None and os.path.exists(args["resultsdir"]): - re_str = "g" + ("n" if args["group"] is None else "d" if args["group"] == "discrete" else "c") - re_str += "i" + ("n" if args["individual"] is None else "d" if args["individual"] == "discrete" else "c") - results_filename = "expt_{}_{}_{}.json".format(dataset, re_str, str(uuid.uuid4().hex)[0:5]) + re_str = "g" + ( + "n" + if args["group"] is None + else "d" + if args["group"] == "discrete" + else "c" + ) + re_str += "i" + ( + "n" + if args["individual"] is None + else "d" + if args["individual"] == "discrete" + else "c" + ) + results_filename = "expt_{}_{}_{}.json".format( + dataset, re_str, str(uuid.uuid4().hex)[0:5] + ) with open(os.path.join(args["resultsdir"], results_filename), "w") as f: json.dump(results, f) diff --git a/examples/mixed_hmm/model.py b/examples/mixed_hmm/model.py index a4be63780d..963b33d1e2 100644 --- a/examples/mixed_hmm/model.py +++ b/examples/mixed_hmm/model.py @@ -17,30 +17,55 @@ def guide_generic(config): if config["group"]["random"] == "continuous": loc_g = pyro.param("loc_group", lambda: torch.zeros((N_state ** 2,))) - scale_g = pyro.param("scale_group", lambda: torch.ones((N_state ** 2,)), - constraint=constraints.positive) + scale_g = pyro.param( + "scale_group", + lambda: torch.ones((N_state ** 2,)), + constraint=constraints.positive, + ) # initialize individual-level random effect parameters N_c = config["sizes"]["group"] if config["individual"]["random"] == "continuous": - loc_i = pyro.param("loc_individual", lambda: torch.zeros((N_c, N_state ** 2,))) - scale_i = pyro.param("scale_individual", lambda: torch.ones((N_c, N_state ** 2,)), - constraint=constraints.positive) + loc_i = pyro.param( + "loc_individual", + lambda: torch.zeros( + ( + N_c, + N_state ** 2, + ) + ), + ) + scale_i = pyro.param( + "scale_individual", + lambda: 
torch.ones( + ( + N_c, + N_state ** 2, + ) + ), + constraint=constraints.positive, + ) N_c = config["sizes"]["group"] with pyro.plate("group", N_c, dim=-1): if config["group"]["random"] == "continuous": - pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1), - ) # infer={"num_samples": 10}) + pyro.sample( + "eps_g", + dist.Normal(loc_g, scale_g).to_event(1), + ) # infer={"num_samples": 10}) N_s = config["sizes"]["individual"] - with pyro.plate("individual", N_s, dim=-2), poutine.mask(mask=config["individual"]["mask"]): + with pyro.plate("individual", N_s, dim=-2), poutine.mask( + mask=config["individual"]["mask"] + ): # individual-level random effects if config["individual"]["random"] == "continuous": - pyro.sample("eps_i", dist.Normal(loc_i, scale_i).to_event(1), - ) # infer={"num_samples": 10}) + pyro.sample( + "eps_i", + dist.Normal(loc_i, scale_i).to_event(1), + ) # infer={"num_samples": 10}) @config_enumerate @@ -53,7 +78,11 @@ def model_generic(config): # initialize group-level random effect parameterss if config["group"]["random"] == "discrete": - probs_e_g = pyro.param("probs_e_group", lambda: torch.randn((N_v,)).abs(), constraint=constraints.simplex) + probs_e_g = pyro.param( + "probs_e_group", + lambda: torch.randn((N_v,)).abs(), + constraint=constraints.simplex, + ) theta_g = pyro.param("theta_group", lambda: torch.randn((N_v, N_state ** 2))) elif config["group"]["random"] == "continuous": loc_g = torch.zeros((N_state ** 2,)) @@ -62,39 +91,67 @@ def model_generic(config): # initialize individual-level random effect parameters N_c = config["sizes"]["group"] if config["individual"]["random"] == "discrete": - probs_e_i = pyro.param("probs_e_individual", - lambda: torch.randn((N_c, N_v,)).abs(), - constraint=constraints.simplex) - theta_i = pyro.param("theta_individual", - lambda: torch.randn((N_c, N_v, N_state ** 2))) + probs_e_i = pyro.param( + "probs_e_individual", + lambda: torch.randn( + ( + N_c, + N_v, + ) + ).abs(), + constraint=constraints.simplex, + ) + theta_i = pyro.param( + "theta_individual", lambda: torch.randn((N_c, N_v, N_state ** 2)) + ) elif config["individual"]["random"] == "continuous": - loc_i = torch.zeros((N_c, N_state ** 2,)) - scale_i = torch.ones((N_c, N_state ** 2,)) + loc_i = torch.zeros( + ( + N_c, + N_state ** 2, + ) + ) + scale_i = torch.ones( + ( + N_c, + N_state ** 2, + ) + ) # initialize likelihood parameters # observation 1: step size (step ~ Gamma) step_zi_param = pyro.param("step_zi_param", lambda: torch.ones((N_state, 2))) - step_concentration = pyro.param("step_param_concentration", - lambda: torch.randn((N_state,)).abs(), - constraint=constraints.positive) - step_rate = pyro.param("step_param_rate", - lambda: torch.randn((N_state,)).abs(), - constraint=constraints.positive) + step_concentration = pyro.param( + "step_param_concentration", + lambda: torch.randn((N_state,)).abs(), + constraint=constraints.positive, + ) + step_rate = pyro.param( + "step_param_rate", + lambda: torch.randn((N_state,)).abs(), + constraint=constraints.positive, + ) # observation 2: step angle (angle ~ VonMises) - angle_concentration = pyro.param("angle_param_concentration", - lambda: torch.randn((N_state,)).abs(), - constraint=constraints.positive) + angle_concentration = pyro.param( + "angle_param_concentration", + lambda: torch.randn((N_state,)).abs(), + constraint=constraints.positive, + ) angle_loc = pyro.param("angle_param_loc", lambda: torch.randn((N_state,)).abs()) # observation 3: dive activity (omega ~ Beta) omega_zi_param = 
pyro.param("omega_zi_param", lambda: torch.ones((N_state, 2))) - omega_concentration0 = pyro.param("omega_param_concentration0", - lambda: torch.randn((N_state,)).abs(), - constraint=constraints.positive) - omega_concentration1 = pyro.param("omega_param_concentration1", - lambda: torch.randn((N_state,)).abs(), - constraint=constraints.positive) + omega_concentration0 = pyro.param( + "omega_param_concentration0", + lambda: torch.randn((N_state,)).abs(), + constraint=constraints.positive, + ) + omega_concentration1 = pyro.param( + "omega_param_concentration1", + lambda: torch.randn((N_state,)).abs(), + constraint=constraints.positive, + ) # initialize gamma to uniform gamma = torch.zeros((N_state ** 2,)) @@ -108,16 +165,20 @@ def model_generic(config): e_g = pyro.sample("e_g", dist.Categorical(probs_e_g)) eps_g = Vindex(theta_g)[..., e_g, :] elif config["group"]["random"] == "continuous": - eps_g = pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1), - ) # infer={"num_samples": 10}) + eps_g = pyro.sample( + "eps_g", + dist.Normal(loc_g, scale_g).to_event(1), + ) # infer={"num_samples": 10}) else: - eps_g = 0. + eps_g = 0.0 # add group-level random effect to gamma gamma = gamma + eps_g N_s = config["sizes"]["individual"] - with pyro.plate("individual", N_s, dim=-2), poutine.mask(mask=config["individual"]["mask"]): + with pyro.plate("individual", N_s, dim=-2), poutine.mask( + mask=config["individual"]["mask"] + ): # individual-level random effects if config["individual"]["random"] == "discrete": @@ -126,10 +187,12 @@ def model_generic(config): eps_i = Vindex(theta_i)[..., e_i, :] # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v elif config["individual"]["random"] == "continuous": - eps_i = pyro.sample("eps_i", dist.Normal(loc_i, scale_i).to_event(1), - ) # infer={"num_samples": 10}) + eps_i = pyro.sample( + "eps_i", + dist.Normal(loc_i, scale_i).to_event(1), + ) # infer={"num_samples": 10}) else: - eps_i = 0. 
+ eps_i = 0.0 # add individual-level random effect to gamma gamma = gamma + eps_i @@ -142,7 +205,9 @@ def model_generic(config): gamma_t = gamma # per-timestep variable # finally, reshape gamma as batch of transition matrices - gamma_t = gamma_t.reshape(tuple(gamma_t.shape[:-1]) + (N_state, N_state)) + gamma_t = gamma_t.reshape( + tuple(gamma_t.shape[:-1]) + (N_state, N_state) + ) # we've accounted for all effects, now actually compute gamma_y gamma_y = Vindex(gamma_t)[..., y, :] @@ -151,47 +216,61 @@ def model_generic(config): # observation 1: step size step_dist = dist.Gamma( concentration=Vindex(step_concentration)[..., y], - rate=Vindex(step_rate)[..., y] + rate=Vindex(step_rate)[..., y], ) # zero-inflation with MaskedMixture step_zi = Vindex(step_zi_param)[..., y, :] step_zi_mask = config["observations"]["step"][..., t] == MISSING - pyro.sample("step_zi_{}".format(t), - dist.Categorical(logits=step_zi), - obs=step_zi_mask.long()) + pyro.sample( + "step_zi_{}".format(t), + dist.Categorical(logits=step_zi), + obs=step_zi_mask.long(), + ) step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING)) - step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist, step_zi_zero_dist) + step_zi_dist = dist.MaskedMixture( + step_zi_mask, step_dist, step_zi_zero_dist + ) - pyro.sample("step_{}".format(t), - step_zi_dist, - obs=config["observations"]["step"][..., t]) + pyro.sample( + "step_{}".format(t), + step_zi_dist, + obs=config["observations"]["step"][..., t], + ) # observation 2: step angle angle_dist = dist.VonMises( concentration=Vindex(angle_concentration)[..., y], - loc=Vindex(angle_loc)[..., y] + loc=Vindex(angle_loc)[..., y], + ) + pyro.sample( + "angle_{}".format(t), + angle_dist, + obs=config["observations"]["angle"][..., t], ) - pyro.sample("angle_{}".format(t), - angle_dist, - obs=config["observations"]["angle"][..., t]) # observation 3: dive activity omega_dist = dist.Beta( concentration0=Vindex(omega_concentration0)[..., y], - concentration1=Vindex(omega_concentration1)[..., y] + concentration1=Vindex(omega_concentration1)[..., y], ) # zero-inflation with MaskedMixture omega_zi = Vindex(omega_zi_param)[..., y, :] omega_zi_mask = config["observations"]["omega"][..., t] == MISSING - pyro.sample("omega_zi_{}".format(t), - dist.Categorical(logits=omega_zi), - obs=omega_zi_mask.long()) + pyro.sample( + "omega_zi_{}".format(t), + dist.Categorical(logits=omega_zi), + obs=omega_zi_mask.long(), + ) omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING)) - omega_zi_dist = dist.MaskedMixture(omega_zi_mask, omega_dist, omega_zi_zero_dist) + omega_zi_dist = dist.MaskedMixture( + omega_zi_mask, omega_dist, omega_zi_zero_dist + ) - pyro.sample("omega_{}".format(t), - omega_zi_dist, - obs=config["observations"]["omega"][..., t]) + pyro.sample( + "omega_{}".format(t), + omega_zi_dist, + obs=config["observations"]["omega"][..., t], + ) diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py index 609fc69da3..5a24a0c56d 100644 --- a/examples/mixed_hmm/seal_data.py +++ b/examples/mixed_hmm/seal_data.py @@ -30,7 +30,9 @@ def prepare_seal(filename, random_effects): for g, (group, group_df) in enumerate(seal_df.groupby("sex")): for i, (ind, ind_df) in enumerate(group_df.groupby("ID")): for o, obs_key in enumerate(obs_keys): - observations[i, g, 0:len(ind_df), o] = torch.tensor(ind_df[obs_key].values) + observations[i, g, 0 : len(ind_df), o] = torch.tensor( + ind_df[obs_key].values + ) observations[torch.isnan(observations)] = float("-inf") @@ -39,11 +41,11 @@ def 
prepare_seal(filename, random_effects): mask_i = (observations > float("-inf")).any(dim=-1).any(dim=-1) # time nonempty # mask_t handles padding for time series of different length - mask_t = (observations > float("-inf")).all(dim=-1) # include non-inf + mask_t = (observations > float("-inf")).all(dim=-1) # include non-inf # temporary hack to avoid zero-inflation issues # observations[observations == 0.] = MISSING - observations[(observations == 0.) | (observations == float("-inf"))] = MISSING + observations[(observations == 0.0) | (observations == float("-inf"))] = MISSING assert not torch.isnan(observations).any() # observations = observations[..., 5:11, :] # truncate for testing @@ -58,7 +60,11 @@ def prepare_seal(filename, random_effects): "timesteps": observations.shape[2], }, "group": {"random": random_effects["group"], "fixed": None}, - "individual": {"random": random_effects["individual"], "fixed": None, "mask": mask_i}, + "individual": { + "random": random_effects["individual"], + "fixed": None, + "mask": mask_i, + }, "timestep": {"random": None, "fixed": None, "mask": mask_t}, "observations": { "step": observations[..., 0], diff --git a/examples/neutra.py b/examples/neutra.py index b7f2870ed2..046bac18d5 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -40,7 +40,7 @@ from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormalizingFlow from pyro.infer.reparam import NeuTraReparam -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) class BananaShaped(dist.TorchDistribution): @@ -49,8 +49,10 @@ class BananaShaped(dist.TorchDistribution): def __init__(self, a, b, rho=0.9): self.a, self.b, self.rho = broadcast_all(a, b, rho) - self.mvn = dist.MultivariateNormal(torch.tensor([0., 0.]), - covariance_matrix=torch.tensor([[1., self.rho], [self.rho, 1.]])) + self.mvn = dist.MultivariateNormal( + torch.tensor([0.0, 0.0]), + covariance_matrix=torch.tensor([[1.0, self.rho], [self.rho, 1.0]]), + ) super().__init__(event_shape=(2,)) def sample(self, sample_shape=()): @@ -70,12 +72,12 @@ def log_prob(self, x): def model(a, b, rho=0.9): - pyro.sample('x', BananaShaped(a, b, rho)) + pyro.sample("x", BananaShaped(a, b, rho)) def fit_guide(guide, args): pyro.clear_param_store() - adam = optim.Adam({'lr': args.learning_rate}) + adam = optim.Adam({"lr": args.learning_rate}) svi = SVI(model, guide, adam, Trace_ELBO()) for i in range(args.num_steps): loss = svi.step(args.param_a, args.param_b) @@ -103,8 +105,8 @@ def main(args): ax6 = fig.add_subplot(gs[1, 1]) ax7 = fig.add_subplot(gs[2, 1]) ax8 = fig.add_subplot(gs[3, 1]) - xlim = tuple(int(x) for x in args.x_lim.strip().split(',')) - ylim = tuple(int(x) for x in args.y_lim.strip().split(',')) + xlim = tuple(int(x) for x in args.x_lim.strip().split(",")) + ylim = tuple(int(x) for x in args.y_lim.strip().split(",")) assert len(xlim) == 2 assert len(ylim) == 2 @@ -112,102 +114,169 @@ def main(args): x1, x2 = torch.meshgrid([torch.linspace(*xlim, 100), torch.linspace(*ylim, 100)]) d = BananaShaped(args.param_a, args.param_b) p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1))) - ax1.contourf(x1, x2, p, cmap='OrRd',) - ax1.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='BananaShaped distribution: \nlog density') + ax1.contourf( + x1, + x2, + p, + cmap="OrRd", + ) + ax1.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="BananaShaped distribution: \nlog density", + ) # 2. 
Run vanilla HMC - logging.info('\nDrawing samples using vanilla HMC ...') + logging.info("\nDrawing samples using vanilla HMC ...") mcmc = run_hmc(args, model) - vanilla_samples = mcmc.get_samples()['x'].cpu().numpy() - ax2.contourf(x1, x2, p, cmap='OrRd') - ax2.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='Posterior \n(vanilla HMC)') + vanilla_samples = mcmc.get_samples()["x"].cpu().numpy() + ax2.contourf(x1, x2, p, cmap="OrRd") + ax2.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="Posterior \n(vanilla HMC)", + ) sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2) # 3(a). Fit a diagonal normal autoguide - logging.info('\nFitting a DiagNormal autoguide ...') + logging.info("\nFitting a DiagNormal autoguide ...") guide = AutoDiagonalNormal(model, init_scale=0.05) fit_guide(guide, args) - with pyro.plate('N', args.num_samples): - guide_samples = guide()['x'].detach().cpu().numpy() - - ax3.contourf(x1, x2, p, cmap='OrRd') - ax3.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='Posterior \n(DiagNormal autoguide)') + with pyro.plate("N", args.num_samples): + guide_samples = guide()["x"].detach().cpu().numpy() + + ax3.contourf(x1, x2, p, cmap="OrRd") + ax3.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="Posterior \n(DiagNormal autoguide)", + ) sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3) # 3(b). Draw samples using NeuTra HMC - logging.info('\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...') + logging.info("\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...") neutra = NeuTraReparam(guide.requires_grad_(False)) neutra_model = poutine.reparam(model, config=lambda _: neutra) mcmc = run_hmc(args, neutra_model) - zs = mcmc.get_samples()['x_shared_latent'] + zs = mcmc.get_samples()["x_shared_latent"] sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4) - ax4.set(xlabel='x0', ylabel='x1', - title='Posterior (warped) samples \n(DiagNormal + NeuTra HMC)') + ax4.set( + xlabel="x0", + ylabel="x1", + title="Posterior (warped) samples \n(DiagNormal + NeuTra HMC)", + ) samples = neutra.transform_sample(zs) - samples = samples['x'].cpu().numpy() - ax5.contourf(x1, x2, p, cmap='OrRd') - ax5.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='Posterior (transformed) \n(DiagNormal + NeuTra HMC)') + samples = samples["x"].cpu().numpy() + ax5.contourf(x1, x2, p, cmap="OrRd") + ax5.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="Posterior (transformed) \n(DiagNormal + NeuTra HMC)", + ) sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5) # 4(a). Fit a BNAF autoguide - logging.info('\nFitting a BNAF autoguide ...') - guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, block_autoregressive)) + logging.info("\nFitting a BNAF autoguide ...") + guide = AutoNormalizingFlow( + model, partial(iterated, args.num_flows, block_autoregressive) + ) fit_guide(guide, args) - with pyro.plate('N', args.num_samples): - guide_samples = guide()['x'].detach().cpu().numpy() - - ax6.contourf(x1, x2, p, cmap='OrRd') - ax6.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='Posterior \n(BNAF autoguide)') + with pyro.plate("N", args.num_samples): + guide_samples = guide()["x"].detach().cpu().numpy() + + ax6.contourf(x1, x2, p, cmap="OrRd") + ax6.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="Posterior \n(BNAF autoguide)", + ) sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6) # 4(b). 
Draw samples using NeuTra HMC - logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...') + logging.info("\nDrawing samples using BNAF autoguide + NeuTra HMC ...") neutra = NeuTraReparam(guide.requires_grad_(False)) neutra_model = poutine.reparam(model, config=lambda _: neutra) mcmc = run_hmc(args, neutra_model) - zs = mcmc.get_samples()['x_shared_latent'] + zs = mcmc.get_samples()["x_shared_latent"] sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7) - ax7.set(xlabel='x0', ylabel='x1', title='Posterior (warped) samples \n(BNAF + NeuTra HMC)') + ax7.set( + xlabel="x0", + ylabel="x1", + title="Posterior (warped) samples \n(BNAF + NeuTra HMC)", + ) samples = neutra.transform_sample(zs) - samples = samples['x'].cpu().numpy() - ax8.contourf(x1, x2, p, cmap='OrRd') - ax8.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, - title='Posterior (transformed) \n(BNAF + NeuTra HMC)') + samples = samples["x"].cpu().numpy() + ax8.contourf(x1, x2, p, cmap="OrRd") + ax8.set( + xlabel="x0", + ylabel="x1", + xlim=xlim, + ylim=ylim, + title="Posterior (transformed) \n(BNAF + NeuTra HMC)", + ) sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8) - plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf')) - - -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer') - parser.add_argument('-n', '--num-steps', default=10000, type=int, - help='number of SVI steps') - parser.add_argument('-lr', '--learning-rate', default=1e-2, type=float, - help='learning rate for the Adam optimizer') - parser.add_argument('--rng-seed', default=1, type=int, - help='RNG seed') - parser.add_argument('--num-warmup', default=500, type=int, - help='number of warmup steps for NUTS') - parser.add_argument('--num-samples', default=1000, type=int, - help='number of samples to be drawn from NUTS') - parser.add_argument('--param-a', default=1.15, type=float, - help='parameter `a` of BananaShaped distribution') - parser.add_argument('--param-b', default=1., type=float, - help='parameter `b` of BananaShaped distribution') - parser.add_argument('--num-flows', default=1, type=int, - help='number of flows in the BNAF autoguide') - parser.add_argument('--x-lim', default='-3,3', type=str, - help='x limits for the plots') - parser.add_argument('--y-lim', default='0,8', type=str, - help='y limits for the plots') + plt.savefig(os.path.join(os.path.dirname(__file__), "neutra.pdf")) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser( + description="Example illustrating NeuTra Reparametrizer" + ) + parser.add_argument( + "-n", "--num-steps", default=10000, type=int, help="number of SVI steps" + ) + parser.add_argument( + "-lr", + "--learning-rate", + default=1e-2, + type=float, + help="learning rate for the Adam optimizer", + ) + parser.add_argument("--rng-seed", default=1, type=int, help="RNG seed") + parser.add_argument( + "--num-warmup", default=500, type=int, help="number of warmup steps for NUTS" + ) + parser.add_argument( + "--num-samples", + default=1000, + type=int, + help="number of samples to be drawn from NUTS", + ) + parser.add_argument( + "--param-a", + default=1.15, + type=float, + help="parameter `a` of BananaShaped distribution", + ) + parser.add_argument( + "--param-b", + default=1.0, + type=float, + help="parameter `b` of BananaShaped distribution", + ) + parser.add_argument( + "--num-flows", default=1, type=int, help="number of flows in the 
BNAF autoguide" + ) + parser.add_argument( + "--x-lim", default="-3,3", type=str, help="x limits for the plots" + ) + parser.add_argument( + "--y-lim", default="0,8", type=str, help="y limits for the plots" + ) args = parser.parse_args() main(args) diff --git a/examples/rsa/generics.py b/examples/rsa/generics.py index 27de269cd3..2482a51220 100644 --- a/examples/rsa/generics.py +++ b/examples/rsa/generics.py @@ -40,28 +40,41 @@ def discretize_beta_pdf(bins, gamma, delta): discretized version of the Beta pdf used for approximately integrating via Search """ shape_alpha = gamma * delta - shape_beta = (1.-gamma) * delta + shape_beta = (1.0 - gamma) * delta return torch.tensor( - list(map(lambda x: (x ** (shape_alpha-1)) * ((1.-x)**(shape_beta-1)), bins))) + list( + map( + lambda x: (x ** (shape_alpha - 1)) * ((1.0 - x) ** (shape_beta - 1)), + bins, + ) + ) + ) @Marginal def structured_prior_model(params): - propertyIsPresent = pyro.sample("propertyIsPresent", - dist.Bernoulli(params.theta)).item() == 1 + propertyIsPresent = ( + pyro.sample("propertyIsPresent", dist.Bernoulli(params.theta)).item() == 1 + ) if propertyIsPresent: # approximately integrate over a beta by enumerating over bins beta_bins = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99] - ix = pyro.sample("bin", dist.Categorical( - probs=discretize_beta_pdf(beta_bins, params.gamma, params.delta))) + ix = pyro.sample( + "bin", + dist.Categorical( + probs=discretize_beta_pdf(beta_bins, params.gamma, params.delta) + ), + ) return beta_bins[ix] else: return 0 def threshold_prior(): - threshold_bins = [0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - ix = pyro.sample("threshold", dist.Categorical(logits=torch.zeros(len(threshold_bins)))) + threshold_bins = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + ix = pyro.sample( + "threshold", dist.Categorical(logits=torch.zeros(len(threshold_bins))) + ) return threshold_bins[ix] @@ -93,13 +106,13 @@ def meaning(utterance, state, threshold): def listener0(utterance, threshold, prior): state = pyro.sample("state", prior) m = meaning(utterance, state, threshold) - pyro.factor("listener0_true", 0. if m else -99999.) + pyro.factor("listener0_true", 0.0 if m else -99999.0) return state @Marginal def speaker1(state, threshold, prior): - s1Optimality = 5. 
+ s1Optimality = 5.0 utterance = utterance_prior() L0 = listener0(utterance, threshold, prior) with poutine.scale(scale=torch.tensor(s1Optimality)): @@ -125,18 +138,22 @@ def speaker2(prevalence, prior): def main(args): - hasWingsERP = structured_prior_model(Params(theta=0.5, gamma=0.99, delta=10.)) - laysEggsERP = structured_prior_model(Params(theta=0.5, gamma=0.5, delta=10.)) - carriesMalariaERP = structured_prior_model(Params(theta=0.1, gamma=0.01, delta=2.)) - areFemaleERP = structured_prior_model(Params(theta=0.99, gamma=0.5, delta=50.)) + hasWingsERP = structured_prior_model(Params(theta=0.5, gamma=0.99, delta=10.0)) + laysEggsERP = structured_prior_model(Params(theta=0.5, gamma=0.5, delta=10.0)) + carriesMalariaERP = structured_prior_model(Params(theta=0.1, gamma=0.01, delta=2.0)) + areFemaleERP = structured_prior_model(Params(theta=0.99, gamma=0.5, delta=50.0)) # listener interpretation of generics wingsPosterior = listener1("generic is true", hasWingsERP) malariaPosterior = listener1("generic is true", carriesMalariaERP) eggsPosterior = listener1("generic is true", laysEggsERP) femalePosterior = listener1("generic is true", areFemaleERP) - listeners = {"wings": wingsPosterior, "malaria": malariaPosterior, - "eggs": eggsPosterior, "female": femalePosterior} + listeners = { + "wings": wingsPosterior, + "malaria": malariaPosterior, + "eggs": eggsPosterior, + "female": femalePosterior, + } for name, listener in listeners.items(): for elt in listener.enumerate_support(): @@ -147,8 +164,12 @@ def main(args): eggSpeaker = speaker2(0.6, laysEggsERP) femaleSpeaker = speaker2(0.5, areFemaleERP) lionSpeaker = speaker2(0.01, laysEggsERP) - speakers = {"malaria": malariaSpeaker, "egg": eggSpeaker, - "female": femaleSpeaker, "lion": lionSpeaker} + speakers = { + "malaria": malariaSpeaker, + "egg": eggSpeaker, + "female": femaleSpeaker, + "lion": lionSpeaker, + } for name, speaker in speakers.items(): for elt in speaker.enumerate_support(): @@ -156,8 +177,8 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=10, type=int) + parser.add_argument("-n", "--num-samples", default=10, type=int) args = parser.parse_args() main(args) diff --git a/examples/rsa/hyperbole.py b/examples/rsa/hyperbole.py index f5f1ff4b58..22236e334d 100644 --- a/examples/rsa/hyperbole.py +++ b/examples/rsa/hyperbole.py @@ -34,15 +34,17 @@ def Marginal(fn): def approx(x, b=None): if b is None: - b = 10. 
- div = float(x)/b + b = 10.0 + div = float(x) / b rounded = int(div) + 1 if div - float(int(div)) >= 0.5 else int(div) return int(b) * rounded def price_prior(): values = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001] - probs = torch.tensor([0.4205, 0.3865, 0.0533, 0.0538, 0.0223, 0.0211, 0.0112, 0.0111, 0.0083, 0.0120]) + probs = torch.tensor( + [0.4205, 0.3865, 0.0533, 0.0538, 0.0223, 0.0211, 0.0112, 0.0111, 0.0083, 0.0120] + ) ix = pyro.sample("price", dist.Categorical(probs=probs)) return values[ix] @@ -58,7 +60,7 @@ def valence_prior(price): 5000: 0.9524, 5001: 0.9524, 10000: 0.9864, - 10001: 0.9864 + 10001: 0.9864, } return pyro.sample("valence", dist.Bernoulli(probs=probs[price])).item() == 1 @@ -72,25 +74,30 @@ def meaning(utterance, price): "valence": lambda state: State(price=None, valence=state.valence), "priceValence": lambda state: State(price=state.price, valence=state.valence), "approxPrice": lambda state: State(price=approx(state.price), valence=None), - "approxPriceValence": lambda state: State(price=approx(state.price), valence=state.valence), + "approxPriceValence": lambda state: State( + price=approx(state.price), valence=state.valence + ), } def qud_prior(): values = ["price", "valence", "priceValence", "approxPrice", "approxPriceValence"] - ix = pyro.sample("qud", dist.Categorical(probs=torch.ones(len(values)) / len(values))) + ix = pyro.sample( + "qud", dist.Categorical(probs=torch.ones(len(values)) / len(values)) + ) return values[ix] def utterance_cost(numberUtt): - preciseNumberCost = 1. - return 0. if approx(numberUtt) == numberUtt else preciseNumberCost + preciseNumberCost = 1.0 + return 0.0 if approx(numberUtt) == numberUtt else preciseNumberCost def utterance_prior(): utterances = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001] - utteranceLogits = -torch.tensor(list(map(utterance_cost, utterances)), - dtype=torch.float64) + utteranceLogits = -torch.tensor( + list(map(utterance_cost, utterances)), dtype=torch.float64 + ) ix = pyro.sample("utterance", dist.Categorical(logits=utteranceLogits)) return utterances[ix] @@ -99,13 +106,13 @@ def utterance_prior(): def literal_listener(utterance, qud): price = price_prior() state = State(price=price, valence=valence_prior(price)) - pyro.factor("literal_meaning", 0. if meaning(utterance, price) else -999999.) + pyro.factor("literal_meaning", 0.0 if meaning(utterance, price) else -999999.0) return qud_fns[qud](state) @Marginal def speaker(qudValue, qud): - alpha = 1. 
+ alpha = 1.0 utterance = utterance_prior() literal_marginal = literal_listener(utterance, qud) with poutine.scale(scale=torch.tensor(alpha)): @@ -130,15 +137,68 @@ def pragmatic_listener(utterance): def test_truth(): true_vals = { - "probs": torch.tensor([0.0018655171404222354,0.1512643329444101,0.0030440475496016296,0.23182161303428897,0.00003854830096338984,0.01502495595927897,0.00003889558295405101,0.015160315922876075,0.00016425635615857924,0.026788637869123822,0.00017359794987375924,0.028312162297699582,0.0008164336950199063,0.060558944822420434,0.0008088460212743665,0.05999612935009309,0.01925106279557206,0.17429720083660782,0.02094455861717477,0.18962994295418778]), # noqa: E231,E501 - "support": list(map(lambda d: State(**d), [{"price":10001,"valence":False},{"price":10001,"valence":True},{"price":10000,"valence":False},{"price":10000,"valence":True},{"price":5001,"valence":False},{"price":5001,"valence":True},{"price":5000,"valence":False},{"price":5000,"valence":True},{"price":1001,"valence":False},{"price":1001,"valence":True},{"price":1000,"valence":False},{"price":1000,"valence":True},{"price":501,"valence":False},{"price":501,"valence":True},{"price":500,"valence":False},{"price":500,"valence":True},{"price":51,"valence":False},{"price":51,"valence":True},{"price":50,"valence":False},{"price":50,"valence":True}])) # noqa: E231,E501 + "probs": torch.tensor( + [ + 0.0018655171404222354, + 0.1512643329444101, + 0.0030440475496016296, + 0.23182161303428897, + 0.00003854830096338984, + 0.01502495595927897, + 0.00003889558295405101, + 0.015160315922876075, + 0.00016425635615857924, + 0.026788637869123822, + 0.00017359794987375924, + 0.028312162297699582, + 0.0008164336950199063, + 0.060558944822420434, + 0.0008088460212743665, + 0.05999612935009309, + 0.01925106279557206, + 0.17429720083660782, + 0.02094455861717477, + 0.18962994295418778, + ] + ), # noqa: E231,E501 + "support": list( + map( + lambda d: State(**d), + [ + {"price": 10001, "valence": False}, + {"price": 10001, "valence": True}, + {"price": 10000, "valence": False}, + {"price": 10000, "valence": True}, + {"price": 5001, "valence": False}, + {"price": 5001, "valence": True}, + {"price": 5000, "valence": False}, + {"price": 5000, "valence": True}, + {"price": 1001, "valence": False}, + {"price": 1001, "valence": True}, + {"price": 1000, "valence": False}, + {"price": 1000, "valence": True}, + {"price": 501, "valence": False}, + {"price": 501, "valence": True}, + {"price": 500, "valence": False}, + {"price": 500, "valence": True}, + {"price": 51, "valence": False}, + {"price": 51, "valence": True}, + {"price": 50, "valence": False}, + {"price": 50, "valence": True}, + ], + ) + ), # noqa: E231,E501 } pragmatic_marginal = pragmatic_listener(10000) for i, elt in enumerate(true_vals["support"]): - print("{}: true prob {} pyro prob {}".format( - elt, true_vals["probs"][i].item(), - pragmatic_marginal.log_prob(elt).exp().item())) + print( + "{}: true prob {} pyro prob {}".format( + elt, + true_vals["probs"][i].item(), + pragmatic_marginal.log_prob(elt).exp().item(), + ) + ) def main(args): @@ -148,14 +208,18 @@ def main(args): pragmatic_marginal = pragmatic_listener(args.price) pd, pv = pragmatic_marginal._dist_and_values() - print([(s, pragmatic_marginal.log_prob(s).exp().item()) - for s in pragmatic_marginal.enumerate_support()]) + print( + [ + (s, pragmatic_marginal.log_prob(s).exp().item()) + for s in pragmatic_marginal.enumerate_support() + ] + ) if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + 
assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=10, type=int) - parser.add_argument('--price', default=10000, type=int) + parser.add_argument("-n", "--num-samples", default=10, type=int) + parser.add_argument("--price", default=10000, type=int) args = parser.parse_args() main(args) diff --git a/examples/rsa/schelling.py b/examples/rsa/schelling.py index 5e658cca07..64628f397c 100644 --- a/examples/rsa/schelling.py +++ b/examples/rsa/schelling.py @@ -68,19 +68,21 @@ def main(args): # draw num_samples samples from Bob's decision process # and use those to estimate the marginal probability # that Bob chooses their preferred location - bob_prob = sum([bob_decision() - for i in range(num_samples)]) / float(num_samples) + bob_prob = sum([bob_decision() for i in range(num_samples)]) / float(num_samples) - print("Empirical frequency of Bob choosing their favored location " + - "given preference {} and recursion depth {}: {}" - .format(shared_preference, bob_depth, bob_prob)) + print( + "Empirical frequency of Bob choosing their favored location " + + "given preference {} and recursion depth {}: {}".format( + shared_preference, bob_depth, bob_prob + ) + ) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=10, type=int) - parser.add_argument('--depth', default=2, type=int) - parser.add_argument('--preference', default=0.6, type=float) + parser.add_argument("-n", "--num-samples", default=10, type=int) + parser.add_argument("--depth", default=2, type=int) + parser.add_argument("--preference", default=0.6, type=float) args = parser.parse_args() main(args) diff --git a/examples/rsa/schelling_false.py b/examples/rsa/schelling_false.py index cd51ed55f5..fd3c770796 100644 --- a/examples/rsa/schelling_false.py +++ b/examples/rsa/schelling_false.py @@ -38,7 +38,7 @@ def alice_fb(preference, depth): """ alice_prior = location(preference) with poutine.block(): - bob_marginal = HashingMarginal(Search(bob).run(preference, depth-1)) + bob_marginal = HashingMarginal(Search(bob).run(preference, depth - 1)) pyro.sample("bob_choice", bob_marginal, obs=alice_prior) return 1 - alice_prior @@ -76,24 +76,30 @@ def main(args): # We sample Alice's true choice of location # by marginalizing over her decision process - alice_decision = HashingMarginal(Search(alice_fb).run(shared_preference, alice_depth)) + alice_decision = HashingMarginal( + Search(alice_fb).run(shared_preference, alice_depth) + ) # draw num_samples samples from Alice's decision process # and use those to estimate the marginal probability # that Alice chooses their preferred location - alice_prob = sum([alice_decision() - for i in range(num_samples)]) / float(num_samples) + alice_prob = sum([alice_decision() for i in range(num_samples)]) / float( + num_samples + ) - print("Empirical frequency of Alice choosing their favored location " + - "given preference {} and recursion depth {}: {}" - .format(shared_preference, alice_depth, alice_prob)) + print( + "Empirical frequency of Alice choosing their favored location " + + "given preference {} and recursion depth {}: {}".format( + shared_preference, alice_depth, alice_prob + ) + ) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert 
pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=10, type=int) - parser.add_argument('--depth', default=3, type=int) - parser.add_argument('--preference', default=0.55, type=float) + parser.add_argument("-n", "--num-samples", default=10, type=int) + parser.add_argument("--depth", default=3, type=int) + parser.add_argument("--preference", default=0.55, type=float) args = parser.parse_args() main(args) diff --git a/examples/rsa/search_inference.py b/examples/rsa/search_inference.py index 7e2cb8e142..e61e2ebbac 100644 --- a/examples/rsa/search_inference.py +++ b/examples/rsa/search_inference.py @@ -33,15 +33,16 @@ class HashingMarginal(dist.Distribution): Turns a TracePosterior object into a Distribution over the return values of the TracePosterior's model. """ + def __init__(self, trace_dist, sites=None): - assert isinstance(trace_dist, TracePosterior), \ - "trace_dist must be trace posterior distribution object" + assert isinstance( + trace_dist, TracePosterior + ), "trace_dist must be trace posterior distribution object" if sites is None: sites = "_RETURN" - assert isinstance(sites, (str, list)), \ - "sites must be either '_RETURN' or list" + assert isinstance(sites, (str, list)), "sites must be either '_RETURN' or list" self.sites = sites super().__init__() @@ -53,8 +54,7 @@ def __init__(self, trace_dist, sites=None): def _dist_and_values(self): # XXX currently this whole object is very inefficient values_map, logits = collections.OrderedDict(), collections.OrderedDict() - for tr, logit in zip(self.trace_dist.exec_traces, - self.trace_dist.log_weights): + for tr, logit in zip(self.trace_dist.exec_traces, self.trace_dist.log_weights): if isinstance(self.sites, str): value = tr.nodes[self.sites]["value"] else: @@ -70,7 +70,9 @@ def _dist_and_values(self): value_hash = hash(value) if value_hash in logits: # Value has already been seen. 
- logits[value_hash] = dist.util.logsumexp(torch.stack([logits[value_hash], logit]), dim=-1) + logits[value_hash] = dist.util.logsumexp( + torch.stack([logits[value_hash], logit]), dim=-1 + ) else: logits[value_hash] = logit values_map[value_hash] = value @@ -132,10 +134,12 @@ def variance(self): # Exact Search inference ######################## + class Search(TracePosterior): """ Exact inference by enumerating over all possible executions """ + def __init__(self, model, max_tries=int(1e6), **kwargs): self.model = model self.max_tries = max_tries @@ -144,8 +148,7 @@ def __init__(self, model, max_tries=int(1e6), **kwargs): def _traces(self, *args, **kwargs): q = queue.Queue() q.put(poutine.Trace()) - p = poutine.trace( - poutine.queue(self.model, queue=q, max_tries=self.max_tries)) + p = poutine.trace(poutine.queue(self.model, queue=q, max_tries=self.max_tries)) while not q.empty(): tr = p.get_trace(*args, **kwargs) yield tr, tr.log_prob_sum() @@ -157,30 +160,38 @@ def _traces(self, *args, **kwargs): def pqueue(fn, queue): - def sample_escape(tr, site): - return (site["name"] not in tr) and \ - (site["type"] == "sample") and \ - (not site["is_observed"]) + return ( + (site["name"] not in tr) + and (site["type"] == "sample") + and (not site["is_observed"]) + ) def _fn(*args, **kwargs): for i in range(int(1e6)): - assert not queue.empty(), \ - "trying to get() from an empty queue will deadlock" + assert ( + not queue.empty() + ), "trying to get() from an empty queue will deadlock" priority, next_trace = queue.get() try: - ftr = poutine.trace(poutine.escape(poutine.replay(fn, next_trace), - functools.partial(sample_escape, - next_trace))) + ftr = poutine.trace( + poutine.escape( + poutine.replay(fn, next_trace), + functools.partial(sample_escape, next_trace), + ) + ) return ftr(*args, **kwargs) except NonlocalExit as site_container: site_container.reset_stack() - for tr in poutine.util.enum_extend(ftr.trace.copy(), - site_container.site): + for tr in poutine.util.enum_extend( + ftr.trace.copy(), site_container.site + ): # add a little bit of noise to the priority to break ties... - queue.put((tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr)) + queue.put( + (tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr) + ) raise ValueError("max tries ({}) exceeded".format(str(1e6))) @@ -192,6 +203,7 @@ class BestFirstSearch(TracePosterior): Inference by enumerating executions ordered by their probabilities. Exact (and results equivalent to Search) if all executions are enumerated. 
""" + def __init__(self, model, num_samples=None, **kwargs): if num_samples is None: num_samples = 100 diff --git a/examples/rsa/semantic_parsing.py b/examples/rsa/semantic_parsing.py index b9dea031db..d8b656063a 100644 --- a/examples/rsa/semantic_parsing.py +++ b/examples/rsa/semantic_parsing.py @@ -22,13 +22,16 @@ def Marginal(fn=None, **kwargs): if fn is None: return lambda _fn: Marginal(_fn, **kwargs) - return memoize(lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args))) + return memoize( + lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args)) + ) ################################################################### # Lexical semantics ################################################################### + def flip(name, p): return pyro.sample(name, dist.Bernoulli(p)).item() == 1 @@ -38,10 +41,12 @@ def flip(name, p): def Obj(name): - return obj(name=name, - blond=flip(name + "_blond", 0.5), - nice=flip(name + "_nice", 0.5), - tall=flip(name + "_tall", 0.5)) + return obj( + name=name, + blond=flip(name + "_blond", 0.5), + nice=flip(name + "_nice", 0.5), + tall=flip(name + "_tall", 0.5), + ) class Meaning: @@ -99,6 +104,7 @@ def sem(self, world): def f1(P): def f2(Q): return len(list(filter(Q, filter(P, world)))) > 0 + return f2 return f1 @@ -110,8 +116,8 @@ def syn(self): "out": { "dir": "R", "int": {"dir": "L", "int": "NP", "out": "S"}, - "out": "S" - } + "out": "S", + }, } @@ -119,8 +125,10 @@ class AllMeaning(Meaning): def sem(self, world): def f1(P): def f2(Q): - return len(list(filter(lambda *args: not Q(*args), - filter(P, world)))) == 0 + return ( + len(list(filter(lambda *args: not Q(*args), filter(P, world)))) == 0 + ) + return f2 return f1 @@ -132,8 +140,8 @@ def syn(self): "out": { "dir": "R", "int": {"dir": "L", "int": "NP", "out": "S"}, - "out": "S" - } + "out": "S", + }, } @@ -142,6 +150,7 @@ def sem(self, world): def f1(P): def f2(Q): return len(list(filter(Q, filter(P, world)))) == 0 + return f2 return f1 @@ -153,8 +162,8 @@ def syn(self): "out": { "dir": "R", "int": {"dir": "L", "int": "NP", "out": "S"}, - "out": "S" - } + "out": "S", + }, } @@ -174,14 +183,15 @@ def syn(self): # Compositional semantics ################################################################### + def heuristic(is_good): if is_good: - return torch.tensor(0.) + return torch.tensor(0.0) return torch.tensor(-100.0) def world_prior(num_objs, meaning_fn): - prev_factor = torch.tensor(0.) 
+ prev_factor = torch.tensor(0.0) world = [] for i in range(num_objs): world.append(Obj("obj_{}".format(i))) @@ -200,7 +210,7 @@ def lexical_meaning(word): "Bob": BobMeaning, "some": SomeMeaning, "none": NoneMeaning, - "all": AllMeaning + "all": AllMeaning, } if word in meanings: return meanings[word]() @@ -214,9 +224,11 @@ def apply_world_passing(f, a): def syntax_match(s, t): if "dir" in s and "dir" in t: - return (s["dir"] and t["dir"]) and \ - syntax_match(s["int"], t["int"]) and \ - syntax_match(s["out"], t["out"]) + return ( + (s["dir"] and t["dir"]) + and syntax_match(s["int"], t["int"]) + and syntax_match(s["out"], t["out"]) + ) else: return s == t @@ -228,9 +240,9 @@ def can_apply(meanings): s = meaning.syn() if "dir" in s: if s["dir"] == "L": - applies = syntax_match(s["int"], meanings[i-1].syn()) + applies = syntax_match(s["int"], meanings[i - 1].syn()) elif s["dir"] == "R": - applies = syntax_match(s["int"], meanings[i+1].syn()) + applies = syntax_match(s["int"], meanings[i + 1].syn()) else: applies = False @@ -243,34 +255,32 @@ def can_apply(meanings): def combine_meaning(meanings, c): possible_combos = can_apply(meanings) N = len(possible_combos) - ix = pyro.sample("ix_{}".format(c), - dist.Categorical(torch.ones(N) / N)) + ix = pyro.sample("ix_{}".format(c), dist.Categorical(torch.ones(N) / N)) i = possible_combos[ix] s = meanings[i].syn() if s["dir"] == "L": f = meanings[i].sem - a = meanings[i-1].sem - new_meaning = CompoundMeaning(sem=apply_world_passing(f, a), - syn=s["out"]) - return meanings[0:i-1] + [new_meaning] + meanings[i+1:] + a = meanings[i - 1].sem + new_meaning = CompoundMeaning(sem=apply_world_passing(f, a), syn=s["out"]) + return meanings[0 : i - 1] + [new_meaning] + meanings[i + 1 :] if s["dir"] == "R": f = meanings[i].sem - a = meanings[i+1].sem - new_meaning = CompoundMeaning(sem=apply_world_passing(f, a), - syn=s["out"]) - return meanings[0:i] + [new_meaning] + meanings[i+2:] + a = meanings[i + 1].sem + new_meaning = CompoundMeaning(sem=apply_world_passing(f, a), syn=s["out"]) + return meanings[0:i] + [new_meaning] + meanings[i + 2 :] def combine_meanings(meanings, c=0): if len(meanings) == 1: return meanings[0].sem else: - return combine_meanings(combine_meaning(meanings, c), c=c+1) + return combine_meanings(combine_meaning(meanings, c), c=c + 1) def meaning(utterance): - defined = filter(lambda w: "" != w.syn(), - list(map(lexical_meaning, utterance.split(" ")))) + defined = filter( + lambda w: "" != w.syn(), list(map(lexical_meaning, utterance.split(" "))) + ) return combine_meanings(list(defined)) @@ -283,9 +293,11 @@ def literal_listener(utterance): def utterance_prior(): - utterances = ["some of the blond people are nice", - "all of the blond people are nice", - "none of the blond people are nice"] + utterances = [ + "some of the blond people are nice", + "all of the blond people are nice", + "none of the blond people are nice", + ] ix = pyro.sample("utterance", dist.Categorical(torch.ones(3) / 3.0)) return utterances[ix] @@ -339,8 +351,8 @@ def is_all_qud(world): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=10, type=int) + parser.add_argument("-n", "--num-samples", default=10, type=int) args = parser.parse_args() main(args) diff --git a/examples/scanvi/data.py b/examples/scanvi/data.py index 429883d1a3..ae291e62f0 100644 --- a/examples/scanvi/data.py +++ 
b/examples/scanvi/data.py @@ -20,6 +20,7 @@ class BatchDataLoader(object): This custom DataLoader serves mini-batches that are either fully-observed (i.e. labeled) or partially-observed (i.e. unlabeled) but never mixed. """ + def __init__(self, data_x, data_y, batch_size, num_classes=4, missing_label=-1): super().__init__() self.data_x = data_x @@ -53,11 +54,11 @@ def _sample_batch_indices(self): slices = [] for i in range(self.num_unlabeled_batches): - _slice = unlabeled_idx[i * self.batch_size: (i + 1) * self.batch_size] + _slice = unlabeled_idx[i * self.batch_size : (i + 1) * self.batch_size] slices.append((_slice, False)) for i in range(self.num_labeled_batches): - _slice = labeled_idx[i * self.batch_size: (i + 1) * self.batch_size] + _slice = labeled_idx[i * self.batch_size : (i + 1) * self.batch_size] slices.append((_slice, True)) return slices, batch_order @@ -69,8 +70,9 @@ def __iter__(self): _slice = slices[batch_order[i]] if _slice[1]: # labeled - yield self.data_x[_slice[0]], \ - nn.functional.one_hot(self.data_y[_slice[0]], num_classes=self.num_classes) + yield self.data_x[_slice[0]], nn.functional.one_hot( + self.data_y[_slice[0]], num_classes=self.num_classes + ) else: # unlabeled yield self.data_x[_slice[0]], None @@ -81,10 +83,10 @@ def _get_score(normalized_adata, gene_set): Returns the score per cell given a dictionary of + and - genes """ score = np.zeros(normalized_adata.n_obs) - for gene in gene_set['positive']: + for gene in gene_set["positive"]: expression = np.array(normalized_adata[:, gene].X) score += expression.flatten() - for gene in gene_set['negative']: + for gene in gene_set["negative"]: expression = np.array(normalized_adata[:, gene].X) score -= expression.flatten() return score @@ -106,13 +108,15 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): """ Does the necessary preprocessing and returns a BatchDataLoader for the PBMC dataset. 
""" - assert dataset in ['pbmc', 'mock'] + assert dataset in ["pbmc", "mock"] # create mock dataset for CI - if dataset == 'mock': + if dataset == "mock": num_genes = 17 num_data = 200 - X = torch.distributions.Poisson(rate=10.0).sample(sample_shape=(num_data, num_genes)) + X = torch.distributions.Poisson(rate=10.0).sample( + sample_shape=(num_data, num_genes) + ) Y = torch.zeros(num_data, dtype=torch.long) Y[50:100] = 1 Y[100:] = -1 @@ -124,11 +128,25 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): import scanpy as sc import scvi - adata = scvi.data.purified_pbmc_dataset(subset_datasets=["regulatory_t", "naive_t", - "memory_t", "naive_cytotoxic"]) - gene_subset = ["CD4", "FOXP3", "TNFRSF18", "IL2RA", "CTLA4", "CD44", "TCF7", - "CD8B", "CCR7", "CD69", "PTPRC", "S100A4"] + adata = scvi.data.purified_pbmc_dataset( + subset_datasets=["regulatory_t", "naive_t", "memory_t", "naive_cytotoxic"] + ) + + gene_subset = [ + "CD4", + "FOXP3", + "TNFRSF18", + "IL2RA", + "CTLA4", + "CD44", + "TCF7", + "CD8B", + "CCR7", + "CD69", + "PTPRC", + "S100A4", + ] normalized = adata.copy() sc.pp.normalize_total(normalized, target_sum=1e4) @@ -138,10 +156,19 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): sc.pp.scale(normalized) # hand curated list of genes for identifying ground truth - cd4_reg_geneset = {"positive": ["TNFRSF18", "CTLA4", "FOXP3", "IL2RA"], "negative": ["S100A4", "PTPRC", "CD8B"]} + cd4_reg_geneset = { + "positive": ["TNFRSF18", "CTLA4", "FOXP3", "IL2RA"], + "negative": ["S100A4", "PTPRC", "CD8B"], + } cd8_naive_geneset = {"positive": ["CD8B", "CCR7"], "negative": ["CD4"]} - cd4_naive_geneset = {"positive": ["CCR7", "CD4"], "negative": ["S100A4", "PTPRC", "FOXP3", "IL2RA", "CD69"]} - cd4_mem_geneset = {"positive": ["S100A4"], "negative": ["IL2RA", "FOXP3", "TNFRSF18", "CCR7"]} + cd4_naive_geneset = { + "positive": ["CCR7", "CD4"], + "negative": ["S100A4", "PTPRC", "FOXP3", "IL2RA", "CD69"], + } + cd4_mem_geneset = { + "positive": ["S100A4"], + "negative": ["IL2RA", "FOXP3", "TNFRSF18", "CCR7"], + } cd4_reg_mask = _get_cell_mask(normalized, cd4_reg_geneset) cd8_naive_mask = _get_cell_mask(normalized, cd8_naive_geneset) @@ -152,27 +179,27 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): seed_labels = -np.ones(cd4_mem_mask.shape[0]) seed_labels[cd8_naive_mask] = 0 # "CD8 Naive T cell" seed_labels[cd4_naive_mask] = 1 # "CD4 Naive T cell" - seed_labels[cd4_mem_mask] = 2 # "CD4 Memory T cell" - seed_labels[cd4_reg_mask] = 3 # "CD4 Regulatory T cell" + seed_labels[cd4_mem_mask] = 2 # "CD4 Memory T cell" + seed_labels[cd4_reg_mask] = 3 # "CD4 Regulatory T cell" # this metadata will be used for plotting - seed_colors = ['lightgray'] * seed_labels.shape[0] + seed_colors = ["lightgray"] * seed_labels.shape[0] seed_sizes = [0.05] * seed_labels.shape[0] for i in range(len(seed_colors)): if seed_labels[i] == 0: - seed_colors[i] = 'lightcoral' + seed_colors[i] = "lightcoral" elif seed_labels[i] == 1: - seed_colors[i] = 'limegreen' + seed_colors[i] = "limegreen" elif seed_labels[i] == 2: - seed_colors[i] = 'deepskyblue' + seed_colors[i] = "deepskyblue" elif seed_labels[i] == 3: - seed_colors[i] = 'mediumorchid' + seed_colors[i] = "mediumorchid" if seed_labels[i] != -1: seed_sizes[i] = 25 - adata.obs['seed_labels'] = seed_labels - adata.obs['seed_colors'] = seed_colors - adata.obs['seed_marker_sizes'] = seed_sizes + adata.obs["seed_labels"] = seed_labels + adata.obs["seed_colors"] = seed_colors + adata.obs["seed_marker_sizes"] = seed_sizes Y = 
torch.from_numpy(seed_labels).long() X = torch.from_numpy(sparse.csr_matrix.todense(adata.X)).float() @@ -197,4 +224,10 @@ def get_data(dataset="pbmc", batch_size=100, cuda=False): adata = adata[idx.data.cpu().numpy()] adata.raw = adata - return BatchDataLoader(X[idx], Y[idx], batch_size), num_genes, l_mean, l_scale, adata + return ( + BatchDataLoader(X[idx], Y[idx], batch_size), + num_genes, + l_mean, + l_scale, + adata, + ) diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 6028bc1b08..05c5e1203d 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -146,7 +146,16 @@ def forward(self, x): # Encompasses the scANVI model and guide as a PyTorch nn.Module class SCANVI(nn.Module): - def __init__(self, num_genes, num_labels, l_loc, l_scale, latent_dim=10, alpha=0.01, scale_factor=1.0): + def __init__( + self, + num_genes, + num_labels, + l_loc, + l_scale, + latent_dim=10, + alpha=0.01, + scale_factor=1.0, + ): assert isinstance(num_genes, int) self.num_genes = num_genes @@ -173,13 +182,27 @@ def __init__(self, num_genes, num_labels, l_loc, l_scale, latent_dim=10, alpha=0 super().__init__() # Setup the various neural networks used in the model and guide - self.z2_decoder = Z2Decoder(z1_dim=self.latent_dim, y_dim=self.num_labels, - z2_dim=self.latent_dim, hidden_dims=[50]) - self.x_decoder = XDecoder(num_genes=num_genes, hidden_dims=[100], z2_dim=self.latent_dim) - self.z2l_encoder = Z2LEncoder(num_genes=num_genes, z2_dim=self.latent_dim, hidden_dims=[100]) - self.classifier = Classifier(z2_dim=self.latent_dim, hidden_dims=[50], num_labels=num_labels) - self.z1_encoder = Z1Encoder(num_labels=num_labels, z1_dim=self.latent_dim, - z2_dim=self.latent_dim, hidden_dims=[50]) + self.z2_decoder = Z2Decoder( + z1_dim=self.latent_dim, + y_dim=self.num_labels, + z2_dim=self.latent_dim, + hidden_dims=[50], + ) + self.x_decoder = XDecoder( + num_genes=num_genes, hidden_dims=[100], z2_dim=self.latent_dim + ) + self.z2l_encoder = Z2LEncoder( + num_genes=num_genes, z2_dim=self.latent_dim, hidden_dims=[100] + ) + self.classifier = Classifier( + z2_dim=self.latent_dim, hidden_dims=[50], num_labels=num_labels + ) + self.z1_encoder = Z1Encoder( + num_labels=num_labels, + z1_dim=self.latent_dim, + z2_dim=self.latent_dim, + hidden_dims=[50], + ) self.epsilon = 5.0e-3 @@ -188,17 +211,23 @@ def model(self, x, y=None): pyro.module("scanvi", self) # This gene-level parameter modulates the variance of the observation distribution - theta = pyro.param("inverse_dispersion", 10.0 * x.new_ones(self.num_genes), - constraint=constraints.positive) + theta = pyro.param( + "inverse_dispersion", + 10.0 * x.new_ones(self.num_genes), + constraint=constraints.positive, + ) # We scale all sample statements by scale_factor so that the ELBO is normalized # wrt the number of datapoints and genes with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor): - z1 = pyro.sample("z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1)) + z1 = pyro.sample( + "z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1) + ) # Note that if y is None (i.e. y is unobserved) then y will be sampled; # otherwise y will be treated as observed. 
- y = pyro.sample("y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)), - obs=y) + y = pyro.sample( + "y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)), obs=y + ) z2_loc, z2_scale = self.z2_decoder(z1, y) z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1)) @@ -213,8 +242,9 @@ def model(self, x, y=None): # from failure to success parametrization; # see https://github.com/pytorch/pytorch/issues/42449 nb_logits = (l * mu + self.epsilon).log() - (theta + self.epsilon).log() - x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, total_count=theta, - logits=nb_logits) + x_dist = dist.ZeroInflatedNegativeBinomial( + gate_logits=gate_logits, total_count=theta, logits=nb_logits + ) # Observe the datapoint x using the observation distribution x_dist pyro.sample("x", x_dist.to_event(1), obs=x) @@ -249,12 +279,18 @@ def main(args): # Enable optional validation warnings # Load and pre-process data - dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset=args.dataset, batch_size=args.batch_size, - cuda=args.cuda) + dataloader, num_genes, l_mean, l_scale, anndata = get_data( + dataset=args.dataset, batch_size=args.batch_size, cuda=args.cuda + ) # Instantiate instance of model/guide and various neural networks - scanvi = SCANVI(num_genes=num_genes, num_labels=4, l_loc=l_mean, l_scale=l_scale, - scale_factor=1.0 / (args.batch_size * num_genes)) + scanvi = SCANVI( + num_genes=num_genes, + num_labels=4, + l_loc=l_mean, + l_scale=l_scale, + scale_factor=1.0 / (args.batch_size * num_genes), + ) if args.cuda: scanvi.cuda() @@ -262,10 +298,14 @@ def main(args): # Setup an optimizer (Adam) and learning rate scheduler. # By default we start with a moderately high learning rate (0.005) # and reduce by a factor of 5 after 20 epochs. 
- scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': args.learning_rate}, - 'milestones': [20], - 'gamma': 0.2}) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": args.learning_rate}, + "milestones": [20], + "gamma": 0.2, + } + ) # Tell Pyro to enumerate out y when y is unobserved guide = config_enumerate(scanvi.guide, "parallel", expand=True) @@ -295,7 +335,7 @@ def main(args): scanvi.eval() # Now that we're done training we'll inspect the latent representations we've learned - if args.plot and args.dataset == 'pbmc': + if args.plot and args.dataset == "pbmc": import scanpy as sc # Compute latent representation (z2_loc) for each cell in the dataset @@ -310,55 +350,88 @@ def main(args): anndata.obsm["X_scANVI"] = latent_rep.data.cpu().numpy() sc.pp.neighbors(anndata, use_rep="X_scANVI") sc.tl.umap(anndata) - umap1, umap2 = anndata.obsm['X_umap'][:, 0], anndata.obsm['X_umap'][:, 1] + umap1, umap2 = anndata.obsm["X_umap"][:, 0], anndata.obsm["X_umap"][:, 1] # Construct plots; all plots are scatterplots depicting the two-dimensional UMAP embedding # and only differ in how points are colored # The topmost plot depicts the 200 hand-curated seed labels in our dataset fig, axes = plt.subplots(3, 2) - seed_marker_sizes = anndata.obs['seed_marker_sizes'] - axes[0, 0].scatter(umap1, umap2, s=seed_marker_sizes, c=anndata.obs['seed_colors'], marker='.', alpha=0.7) - axes[0, 0].set_title('Hand-Curated Seed Labels') - patch1 = Patch(color='lightcoral', label='CD8-Naive') - patch2 = Patch(color='limegreen', label='CD4-Naive') - patch3 = Patch(color='deepskyblue', label='CD4-Memory') - patch4 = Patch(color='mediumorchid', label='CD4-Regulatory') - axes[0, 1].legend(loc='center left', handles=[patch1, patch2, patch3, patch4]) + seed_marker_sizes = anndata.obs["seed_marker_sizes"] + axes[0, 0].scatter( + umap1, + umap2, + s=seed_marker_sizes, + c=anndata.obs["seed_colors"], + marker=".", + alpha=0.7, + ) + axes[0, 0].set_title("Hand-Curated Seed Labels") + patch1 = Patch(color="lightcoral", label="CD8-Naive") + patch2 = Patch(color="limegreen", label="CD4-Naive") + patch3 = Patch(color="deepskyblue", label="CD4-Memory") + patch4 = Patch(color="mediumorchid", label="CD4-Regulatory") + axes[0, 1].legend(loc="center left", handles=[patch1, patch2, patch3, patch4]) axes[0, 1].get_xaxis().set_visible(False) axes[0, 1].get_yaxis().set_visible(False) axes[0, 1].set_frame_on(False) # The remaining plots depict the inferred cell type probability for each of the four cell types - s10 = axes[1, 0].scatter(umap1, umap2, s=1, c=y_probs[:, 0], marker='.', alpha=0.7) - axes[1, 0].set_title('Inferred CD8-Naive probability') + s10 = axes[1, 0].scatter( + umap1, umap2, s=1, c=y_probs[:, 0], marker=".", alpha=0.7 + ) + axes[1, 0].set_title("Inferred CD8-Naive probability") fig.colorbar(s10, ax=axes[1, 0]) - s11 = axes[1, 1].scatter(umap1, umap2, s=1, c=y_probs[:, 1], marker='.', alpha=0.7) - axes[1, 1].set_title('Inferred CD4-Naive probability') + s11 = axes[1, 1].scatter( + umap1, umap2, s=1, c=y_probs[:, 1], marker=".", alpha=0.7 + ) + axes[1, 1].set_title("Inferred CD4-Naive probability") fig.colorbar(s11, ax=axes[1, 1]) - s20 = axes[2, 0].scatter(umap1, umap2, s=1, c=y_probs[:, 2], marker='.', alpha=0.7) - axes[2, 0].set_title('Inferred CD4-Memory probability') + s20 = axes[2, 0].scatter( + umap1, umap2, s=1, c=y_probs[:, 2], marker=".", alpha=0.7 + ) + axes[2, 0].set_title("Inferred CD4-Memory probability") fig.colorbar(s20, ax=axes[2, 0]) - s21 = axes[2, 1].scatter(umap1, umap2, 
s=1, c=y_probs[:, 3], marker='.', alpha=0.7) - axes[2, 1].set_title('Inferred CD4-Regulatory probability') + s21 = axes[2, 1].scatter( + umap1, umap2, s=1, c=y_probs[:, 3], marker=".", alpha=0.7 + ) + axes[2, 1].set_title("Inferred CD4-Regulatory probability") fig.colorbar(s21, ax=axes[2, 1]) fig.tight_layout() - plt.savefig('scanvi.pdf') + plt.savefig("scanvi.pdf") if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") # Parse command line arguments - parser = argparse.ArgumentParser(description="single-cell ANnotation using Variational Inference") - parser.add_argument('-s', '--seed', default=0, type=int, help='rng seed') - parser.add_argument('-n', '--num-epochs', default=60, type=int, help='number of training epochs') - parser.add_argument('-d', '--dataset', default='pbmc', type=str, - help='which dataset to use', choices=['pbmc', 'mock']) - parser.add_argument('-bs', '--batch-size', default=100, type=int, help='mini-batch size') - parser.add_argument('-lr', '--learning-rate', default=0.005, type=float, help='learning rate') - parser.add_argument('--cuda', action='store_true', default=False, help='whether to use cuda') - parser.add_argument('--plot', action='store_true', default=False, help='whether to make a plot') + parser = argparse.ArgumentParser( + description="single-cell ANnotation using Variational Inference" + ) + parser.add_argument("-s", "--seed", default=0, type=int, help="rng seed") + parser.add_argument( + "-n", "--num-epochs", default=60, type=int, help="number of training epochs" + ) + parser.add_argument( + "-d", + "--dataset", + default="pbmc", + type=str, + help="which dataset to use", + choices=["pbmc", "mock"], + ) + parser.add_argument( + "-bs", "--batch-size", default=100, type=int, help="mini-batch size" + ) + parser.add_argument( + "-lr", "--learning-rate", default=0.005, type=float, help="learning rate" + ) + parser.add_argument( + "--cuda", action="store_true", default=False, help="whether to use cuda" + ) + parser.add_argument( + "--plot", action="store_true", default=False, help="whether to make a plot" + ) args = parser.parse_args() main(args) diff --git a/examples/sir_hmc.py b/examples/sir_hmc.py index 43b96613a1..6acd511358 100644 --- a/examples/sir_hmc.py +++ b/examples/sir_hmc.py @@ -30,7 +30,7 @@ from pyro.ops.tensor_utils import convolve from pyro.util import warn_if_nan -logging.basicConfig(format='%(message)s', level=logging.INFO) +logging.basicConfig(format="%(message)s", level=logging.INFO) # A Discrete SIR Model @@ -59,9 +59,10 @@ # Binomial.log_prob() will error, whereas ExtendedBinomial.log_prob() will # return -inf. + def global_model(population): tau = args.recovery_time # Assume this can be measured exactly. - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Uniform(0, 1)) # Convert interpretable parameters to distribution parameters. @@ -76,27 +77,26 @@ def discrete_model(args, data): rate_s, prob_i, rho = global_model(args.population) # Sequentially sample time-local variables. - S = torch.tensor(args.population - 1.) - I = torch.tensor(1.) 
+ S = torch.tensor(args.population - 1.0) + I = torch.tensor(1.0) for t, datum in enumerate(data): - S2I = pyro.sample("S2I_{}".format(t), - dist.Binomial(S, -(rate_s * I).expm1())) - I2R = pyro.sample("I2R_{}".format(t), - dist.Binomial(I, prob_i)) + S2I = pyro.sample("S2I_{}".format(t), dist.Binomial(S, -(rate_s * I).expm1())) + I2R = pyro.sample("I2R_{}".format(t), dist.Binomial(I, prob_i)) S = pyro.deterministic("S_{}".format(t), S - S2I) I = pyro.deterministic("I_{}".format(t), I + S2I - I2R) - pyro.sample("obs_{}".format(t), - dist.ExtendedBinomial(S2I, rho), - obs=datum) + pyro.sample("obs_{}".format(t), dist.ExtendedBinomial(S2I, rho), obs=datum) # We can use this model to simulate data. We'll use poutine.condition to pin # parameter values and poutine.trace to record sample observations. + def generate_data(args): logging.info("Generating data...") - params = {"R0": torch.tensor(args.basic_reproduction_number), - "rho": torch.tensor(args.response_rate)} + params = { + "R0": torch.tensor(args.basic_reproduction_number), + "rho": torch.tensor(args.response_rate), + } empty_data = [None] * (args.duration + args.forecast) # We'll retry until we get an actual outbreak. @@ -106,24 +106,38 @@ def generate_data(args): discrete_model(args, empty_data) # Concatenate sequential time series into tensors. - obs = torch.stack([site["value"] - for name, site in tr.trace.nodes.items() - if re.match("obs_[0-9]+", name)]) - S2I = torch.stack([site["value"] - for name, site in tr.trace.nodes.items() - if re.match("S2I_[0-9]+", name)]) + obs = torch.stack( + [ + site["value"] + for name, site in tr.trace.nodes.items() + if re.match("obs_[0-9]+", name) + ] + ) + S2I = torch.stack( + [ + site["value"] + for name, site in tr.trace.nodes.items() + if re.match("S2I_[0-9]+", name) + ] + ) assert len(obs) == len(empty_data) - obs_sum = int(obs[:args.duration].sum()) - S2I_sum = int(S2I[:args.duration].sum()) + obs_sum = int(obs[: args.duration].sum()) + S2I_sum = int(S2I[: args.duration].sum()) if obs_sum >= args.min_observations: - logging.info("Observed {:d}/{:d} infections:\n{}".format( - obs_sum, S2I_sum, " ".join([str(int(x)) for x in obs[:args.duration]]))) + logging.info( + "Observed {:d}/{:d} infections:\n{}".format( + obs_sum, + S2I_sum, + " ".join([str(int(x)) for x in obs[: args.duration]]), + ) + ) return {"S2I": S2I, "obs": obs} - raise ValueError("Failed to generate {} observations. Try increasing " - "--population or decreasing --min-observations" - .format(args.min_observations)) + raise ValueError( + "Failed to generate {} observations. Try increasing " + "--population or decreasing --min-observations".format(args.min_observations) + ) # Inference @@ -147,36 +161,37 @@ def generate_data(args): # # The following model is equivalent to the discrete_model: + @config_enumerate def reparameterized_discrete_model(args, data): # Sample global parameters. rate_s, prob_i, rho = global_model(args.population) # Sequentially sample time-local variables. - S_curr = torch.tensor(args.population - 1.) - I_curr = torch.tensor(1.) + S_curr = torch.tensor(args.population - 1.0) + I_curr = torch.tensor(1.0) for t, datum in enumerate(data): # Sample reparameterizing variables. # When reparameterizing to a factor graph, we ignored density via # .mask(False). Thus distributions are used only for initialization. 
S_prev, I_prev = S_curr, I_curr - S_curr = pyro.sample("S_{}".format(t), - dist.Binomial(args.population, 0.5).mask(False)) - I_curr = pyro.sample("I_{}".format(t), - dist.Binomial(args.population, 0.5).mask(False)) + S_curr = pyro.sample( + "S_{}".format(t), dist.Binomial(args.population, 0.5).mask(False) + ) + I_curr = pyro.sample( + "I_{}".format(t), dist.Binomial(args.population, 0.5).mask(False) + ) # Now we reverse the computation. S2I = S_prev - S_curr I2R = I_prev - I_curr + S2I - pyro.sample("S2I_{}".format(t), - dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()), - obs=S2I) - pyro.sample("I2R_{}".format(t), - dist.ExtendedBinomial(I_prev, prob_i), - obs=I2R) - pyro.sample("obs_{}".format(t), - dist.ExtendedBinomial(S2I, rho), - obs=datum) + pyro.sample( + "S2I_{}".format(t), + dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()), + obs=S2I, + ) + pyro.sample("I2R_{}".format(t), dist.ExtendedBinomial(I_prev, prob_i), obs=I2R) + pyro.sample("obs_{}".format(t), dist.ExtendedBinomial(S2I, rho), obs=datum) # By reparameterizing, we have converted to coordinates that make the model @@ -190,6 +205,7 @@ def reparameterized_discrete_model(args, data): # # Here is an inference approach using an MCMC sampler. + def infer_hmc_enum(args, data): model = reparameterized_discrete_model return _infer_hmc(args, data, model) @@ -197,11 +213,14 @@ def infer_hmc_enum(args, data): def _infer_hmc(args, data, model, init_values={}): logging.info("Running inference...") - kernel = NUTS(model, - full_mass=[("R0", "rho")], - max_tree_depth=args.max_tree_depth, - init_strategy=init_to_value(values=init_values), - jit_compile=args.jit, ignore_jit_warnings=True) + kernel = NUTS( + model, + full_mass=[("R0", "rho")], + max_tree_depth=args.max_tree_depth, + init_strategy=init_to_value(values=init_values), + jit_compile=args.jit, + ignore_jit_warnings=True, + ) # We'll define a hook_fn to log potential energy values during inference. # This is helpful to diagnose whether the chain is mixing. @@ -213,13 +232,17 @@ def hook_fn(kernel, *unused): if args.verbose: logging.info("potential = {:0.6g}".format(e)) - mcmc = MCMC(kernel, hook_fn=hook_fn, - num_samples=args.num_samples, - warmup_steps=args.warmup_steps) + mcmc = MCMC( + kernel, + hook_fn=hook_fn, + num_samples=args.num_samples, + warmup_steps=args.warmup_steps, + ) mcmc.run(args, data) mcmc.summary() if args.plot: import matplotlib.pyplot as plt + plt.figure(figsize=(6, 3)) plt.plot(energies) plt.xlabel("MCMC step") @@ -240,6 +263,7 @@ def hook_fn(kernel, *unused): # # We first define a helper to create enumerated Categorical sites. + def quantize(name, x_real, min, max): """ Randomly quantize in a way that preserves probability mass. @@ -254,12 +278,18 @@ def quantize(name, x_real, min, max): ss = s * s t = 1 - s tt = t * t - probs = torch.stack([ - t * tt, - 4 + ss * (3 * s - 6), - 4 + tt * (3 * t - 6), - s * ss, - ], dim=-1) * (1/6) + probs = ( + torch.stack( + [ + t * tt, + 4 + ss * (3 * s - 6), + 4 + tt * (3 * t - 6), + s * ss, + ], + dim=-1, + ) + * (1 / 6) + ) q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real) x = lb + q - 1 @@ -271,22 +301,31 @@ def quantize(name, x_real, min, max): # Now we can define another equivalent model. + @config_enumerate def continuous_model(args, data): # Sample global parameters. rate_s, prob_i, rho = global_model(args.population) # Sample reparameterizing variables. 
- S_aux = pyro.sample("S_aux", - dist.Uniform(-0.5, args.population + 0.5) - .mask(False).expand(data.shape).to_event(1)) - I_aux = pyro.sample("I_aux", - dist.Uniform(-0.5, args.population + 0.5) - .mask(False).expand(data.shape).to_event(1)) + S_aux = pyro.sample( + "S_aux", + dist.Uniform(-0.5, args.population + 0.5) + .mask(False) + .expand(data.shape) + .to_event(1), + ) + I_aux = pyro.sample( + "I_aux", + dist.Uniform(-0.5, args.population + 0.5) + .mask(False) + .expand(data.shape) + .to_event(1), + ) # Sequentially sample time-local variables. - S_curr = torch.tensor(args.population - 1.) - I_curr = torch.tensor(1.) + S_curr = torch.tensor(args.population - 1.0) + I_curr = torch.tensor(1.0) for t, datum in poutine.markov(enumerate(data)): S_prev, I_prev = S_curr, I_curr S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population) @@ -295,15 +334,13 @@ def continuous_model(args, data): # Now we reverse the computation. S2I = S_prev - S_curr I2R = I_prev - I_curr + S2I - pyro.sample("S2I_{}".format(t), - dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()), - obs=S2I) - pyro.sample("I2R_{}".format(t), - dist.ExtendedBinomial(I_prev, prob_i), - obs=I2R) - pyro.sample("obs_{}".format(t), - dist.ExtendedBinomial(S2I, rho), - obs=datum) + pyro.sample( + "S2I_{}".format(t), + dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()), + obs=S2I, + ) + pyro.sample("I2R_{}".format(t), dist.ExtendedBinomial(I_prev, prob_i), obs=I2R) + pyro.sample("obs_{}".format(t), dist.ExtendedBinomial(S2I, rho), obs=datum) # Now all latent variables in the continuous_model are either continuous or @@ -312,18 +349,19 @@ def continuous_model(args, data): # hypothesis space that are infeasible (i.e. whose log_prob is -infinity). We # thus heuristically initialize to a feasible point. + def heuristic_init(args, data): """Heuristically initialize to a feasible point.""" # Start with a single infection. S0 = args.population - 1 # Assume 50% <= response rate <= 100%. - S2I = data * min(2., (S0 / data.sum()).sqrt()) + S2I = data * min(2.0, (S0 / data.sum()).sqrt()) S_aux = (S0 - S2I.cumsum(-1)).clamp(min=0.5) # Account for the single initial infection. S2I[0] += 1 # Assume infection lasts less than a month. - recovery = torch.arange(30.).div(args.recovery_time).neg().exp() - I_aux = convolve(S2I, recovery)[:len(data)].clamp(min=0.5) + recovery = torch.arange(30.0).div(args.recovery_time).neg().exp() + I_aux = convolve(S2I, recovery)[: len(data)].clamp(min=0.5) return { "R0": torch.tensor(2.0), @@ -344,6 +382,7 @@ def infer_hmc_cont(model, args, data): # with 4 * 4 = 16 states, and then manually perform variable elimination (the # factors here don't quite conform to DiscreteHMM's interface). + def quantize_enumerate(x_real, min, max): """ Randomly quantize in a way that preserves probability mass. @@ -358,14 +397,20 @@ def quantize_enumerate(x_real, min, max): ss = s * s t = 1 - s tt = t * t - probs = torch.stack([ - t * tt, - 4 + ss * (3 * s - 6), - 4 + tt * (3 * t - 6), - s * ss, - ], dim=-1) * (1/6) + probs = ( + torch.stack( + [ + t * tt, + 4 + ss * (3 * s - 6), + 4 + tt * (3 * t - 6), + s * ss, + ], + dim=-1, + ) + * (1 / 6) + ) logits = safe_log(probs) - q = torch.arange(-1., 3.) + q = torch.arange(-1.0, 3.0) x = lb.unsqueeze(-1) + q x = torch.max(x, 2 * min - 1 - x) @@ -378,18 +423,28 @@ def vectorized_model(args, data): rate_s, prob_i, rho = global_model(args.population) # Sample reparameterizing variables. 
- S_aux = pyro.sample("S_aux", - dist.Uniform(-0.5, args.population + 0.5) - .mask(False).expand(data.shape).to_event(1)) - I_aux = pyro.sample("I_aux", - dist.Uniform(-0.5, args.population + 0.5) - .mask(False).expand(data.shape).to_event(1)) + S_aux = pyro.sample( + "S_aux", + dist.Uniform(-0.5, args.population + 0.5) + .mask(False) + .expand(data.shape) + .to_event(1), + ) + I_aux = pyro.sample( + "I_aux", + dist.Uniform(-0.5, args.population + 0.5) + .mask(False) + .expand(data.shape) + .to_event(1), + ) # Manually enumerate. S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population) I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population) # Truncate final value from the right then pad initial value onto the left. - S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=args.population - 1) + S_prev = torch.nn.functional.pad( + S_curr[:-1], (0, 0, 1, 0), value=args.population - 1 + ) I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1) # Reshape to support broadcasting, similar to EnumMessenger. T = len(data) @@ -429,20 +484,24 @@ def vectorized_model(args, data): # After inference we have samples of all latent variables. Let's define a # helper to examine the inferred posterior distributions. + def evaluate(args, samples): # Print estimated values. - names = {"basic_reproduction_number": "R0", - "response_rate": "rho"} + names = {"basic_reproduction_number": "R0", "response_rate": "rho"} for name, key in names.items(): mean = samples[key].mean().item() std = samples[key].std().item() - logging.info("{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}" - .format(key, getattr(args, name), mean, std)) + logging.info( + "{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}".format( + key, getattr(args, name), mean, std + ) + ) # Optionally plot histograms. if args.plot: import matplotlib.pyplot as plt import seaborn as sns + fig, axes = plt.subplots(2, 1, figsize=(5, 5)) axes[0].set_title("Posterior parameter estimates") for ax, (name, key) in zip(axes, names.items()): @@ -469,6 +528,7 @@ def evaluate(args, samples): # forecast forward in time. Let's assume posterior samples have already been # generated via infer_hmc_cont(vectorized_model, ...). + @torch.no_grad() def predict(args, data, samples, truth=None): logging.info("Forecasting {} steps ahead...".format(args.forecast)) @@ -483,9 +543,11 @@ def predict(args, data, samples, truth=None): model = infer_discrete(model, first_available_dim=-2) with poutine.trace() as tr: model(args, data) - samples = OrderedDict((name, site["value"]) - for name, site in tr.trace.nodes.items() - if site["type"] == "sample") + samples = OrderedDict( + (name, site["value"]) + for name, site in tr.trace.nodes.items() + if site["type"] == "sample" + ) # Next we'll run the forward generative process in discrete_model. This # samples time steps [duration:duration+forecast]. Again we'll update the @@ -495,35 +557,39 @@ def predict(args, data, samples, truth=None): model = particle_plate(model) with poutine.trace() as tr: model(args, extended_data) - samples = OrderedDict((name, site["value"]) - for name, site in tr.trace.nodes.items() - if site["type"] == "sample") + samples = OrderedDict( + (name, site["value"]) + for name, site in tr.trace.nodes.items() + if site["type"] == "sample" + ) # Finally we'll concatenate the sequentially sampled values into contiguous # tensors. This operates on the entire time interval [0:duration+forecast]. 
for key in ("S", "I", "S2I", "I2R"): pattern = key + "_[0-9]+" - series = [value - for name, value in samples.items() - if re.match(pattern, name)] + series = [value for name, value in samples.items() if re.match(pattern, name)] assert len(series) == args.duration + args.forecast series[0] = series[0].expand(series[1].shape) samples[key] = torch.stack(series, dim=-1) S2I = samples["S2I"] median = S2I.median(dim=0).values - logging.info("Median prediction of new infections (starting on day 0):\n{}" - .format(" ".join(map(str, map(int, median))))) + logging.info( + "Median prediction of new infections (starting on day 0):\n{}".format( + " ".join(map(str, map(int, median))) + ) + ) # Optionally plot the latent and forecasted series of new infections. if args.plot: import matplotlib.pyplot as plt + plt.figure() time = torch.arange(args.duration + args.forecast) p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") plt.plot(time, median, "r-", label="median") - plt.plot(time[:args.duration], data, "k.", label="observed") + plt.plot(time[: args.duration], data, "k.", label="observed") if truth is not None: plt.plot(time, truth, "k--", label="truth") plt.axvline(args.duration - 0.5, color="gray", lw=1) @@ -547,11 +613,12 @@ def predict(args, data, samples, truth=None): # # python sir_hmc.py -p 10000 -d 60 -f 30 --plot + def main(args): pyro.set_rng_seed(args.rng_seed) dataset = generate_data(args) - obs = dataset["obs"][:args.duration] + obs = dataset["obs"][: args.duration] # Choose among inference methods. if args.enum: @@ -572,7 +639,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC") parser.add_argument("-p", "--population", default=10, type=int) parser.add_argument("-m", "--min-observations", default=3, type=int) @@ -581,10 +648,15 @@ def main(args): parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float) parser.add_argument("-tau", "--recovery-time", default=7.0, type=float) parser.add_argument("-rho", "--response-rate", default=0.5, type=float) - parser.add_argument("-e", "--enum", action="store_true", - help="use the full enumeration model") - parser.add_argument("-s", "--sequential", action="store_true", - help="use the sequential continuous model") + parser.add_argument( + "-e", "--enum", action="store_true", help="use the full enumeration model" + ) + parser.add_argument( + "-s", + "--sequential", + action="store_true", + help="use the sequential continuous model", + ) parser.add_argument("-n", "--num-samples", default=200, type=int) parser.add_argument("-w", "--warmup-steps", default=100, type=int) parser.add_argument("-t", "--max-tree-depth", default=5, type=int) @@ -608,4 +680,5 @@ def main(args): if args.plot: import matplotlib.pyplot as plt + plt.show() diff --git a/examples/smcfilter.py b/examples/smcfilter.py index 7ef6425608..0fe81b99c8 100644 --- a/examples/smcfilter.py +++ b/examples/smcfilter.py @@ -23,11 +23,9 @@ class SimpleHarmonicModel: - def __init__(self, process_noise, measurement_noise): - self.A = torch.tensor([[0., 1.], - [-1., 0.]]) - self.B = torch.tensor([3., 3.]) + self.A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) + self.B = torch.tensor([3.0, 3.0]) self.sigma_z = torch.tensor(process_noise) 
self.sigma_y = torch.tensor(measurement_noise) @@ -39,15 +37,15 @@ def step(self, state, y=None): self.t += 1 state["z"] = pyro.sample( "z_{}".format(self.t), - dist.Normal(state["z"].matmul(self.A), self.B*self.sigma_z).to_event(1)) - y = pyro.sample("y_{}".format(self.t), - dist.Normal(state["z"][..., 0], self.sigma_y), - obs=y) + dist.Normal(state["z"].matmul(self.A), self.B * self.sigma_z).to_event(1), + ) + y = pyro.sample( + "y_{}".format(self.t), dist.Normal(state["z"][..., 0], self.sigma_y), obs=y + ) return state["z"], y class SimpleHarmonicModel_Guide: - def __init__(self, model): self.model = model @@ -61,14 +59,17 @@ def step(self, state, y=None): # Proposal distribution pyro.sample( "z_{}".format(self.t), - dist.Normal(state["z"].matmul(self.model.A), torch.tensor([1., 1.])).to_event(1)) + dist.Normal( + state["z"].matmul(self.model.A), torch.tensor([1.0, 1.0]) + ).to_event(1), + ) def generate_data(args): model = SimpleHarmonicModel(args.process_noise, args.measurement_noise) state = {} - initial = torch.tensor([1., 0.]) + initial = torch.tensor([1.0, 0.0]) model.init(state, initial=initial) zs = [initial] ys = [None] @@ -93,7 +94,7 @@ def main(args): logging.info("Filtering") - smc.init(initial=torch.tensor([1., 0.])) + smc.init(initial=torch.tensor([1.0, 0.0])) for y in ys[1:]: smc.step(y) @@ -105,11 +106,13 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Simple Harmonic Oscillator w/ SMC Filtering Inference") + parser = argparse.ArgumentParser( + description="Simple Harmonic Oscillator w/ SMC Filtering Inference" + ) parser.add_argument("-n", "--num-timesteps", default=500, type=int) parser.add_argument("-p", "--num-particles", default=100, type=int) - parser.add_argument("--process-noise", default=1., type=float) - parser.add_argument("--measurement-noise", default=1., type=float) + parser.add_argument("--process-noise", default=1.0, type=float) + parser.add_argument("--measurement-noise", default=1.0, type=float) parser.add_argument("--seed", default=0, type=int) args = parser.parse_args() main(args) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index a85ae0fc88..9a729c7257 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -31,7 +31,7 @@ from pyro.infer import SVI, TraceMeanField_ELBO from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible -torch.set_default_tensor_type('torch.FloatTensor') +torch.set_default_tensor_type("torch.FloatTensor") pyro.util.set_rng_seed(0) @@ -72,25 +72,41 @@ def model(self, x): # sample the local latent random variables # (the plate encodes the fact that the z's for different datapoints are conditionally independent) with pyro.plate("data", x_size): - z_top = pyro.sample("z_top", Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1)) + z_top = pyro.sample( + "z_top", + Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1), + ) # note that we need to use matmul (batch matrix multiplication) as well as appropriate reshaping # to make sure our code is fully vectorized - w_top = w_top.reshape(self.top_width, self.mid_width) if w_top.dim() == 1 else \ - w_top.reshape(-1, self.top_width, self.mid_width) + w_top = ( + w_top.reshape(self.top_width, self.mid_width) + if w_top.dim() == 1 + else w_top.reshape(-1, self.top_width, self.mid_width) + ) mean_mid = torch.matmul(z_top, w_top) - z_mid = pyro.sample("z_mid", Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1)) - - w_mid = 
w_mid.reshape(self.mid_width, self.bottom_width) if w_mid.dim() == 1 else \ - w_mid.reshape(-1, self.mid_width, self.bottom_width) + z_mid = pyro.sample( + "z_mid", Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1) + ) + + w_mid = ( + w_mid.reshape(self.mid_width, self.bottom_width) + if w_mid.dim() == 1 + else w_mid.reshape(-1, self.mid_width, self.bottom_width) + ) mean_bottom = torch.matmul(z_mid, w_mid) - z_bottom = pyro.sample("z_bottom", Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1)) - - w_bottom = w_bottom.reshape(self.bottom_width, self.image_size) if w_bottom.dim() == 1 else \ - w_bottom.reshape(-1, self.bottom_width, self.image_size) + z_bottom = pyro.sample( + "z_bottom", Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1) + ) + + w_bottom = ( + w_bottom.reshape(self.bottom_width, self.image_size) + if w_bottom.dim() == 1 + else w_bottom.reshape(-1, self.bottom_width, self.image_size) + ) mean_obs = torch.matmul(z_bottom, w_bottom) # observe the data using a poisson likelihood - pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x) + pyro.sample("obs", Poisson(mean_obs).to_event(1), obs=x) # define our custom guide a.k.a. variational distribution. # (note the guide is mean field gamma) @@ -99,19 +115,29 @@ def guide(self, x): # define a helper function to sample z's for a single layer def sample_zs(name, width): - alpha_z_q = pyro.param("alpha_z_q_%s" % name, - lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init)) - mean_z_q = pyro.param("mean_z_q_%s" % name, - lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init)) + alpha_z_q = pyro.param( + "alpha_z_q_%s" % name, + lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init), + ) + mean_z_q = pyro.param( + "mean_z_q_%s" % name, + lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init), + ) alpha_z_q, mean_z_q = softplus(alpha_z_q), softplus(mean_z_q) - pyro.sample("z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1)) + pyro.sample( + "z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1) + ) # define a helper function to sample w's for a single layer def sample_ws(name, width): - alpha_w_q = pyro.param("alpha_w_q_%s" % name, - lambda: rand_tensor((width), self.alpha_init, self.sigma_init)) - mean_w_q = pyro.param("mean_w_q_%s" % name, - lambda: rand_tensor((width), self.mean_init, self.sigma_init)) + alpha_w_q = pyro.param( + "alpha_w_q_%s" % name, + lambda: rand_tensor((width), self.alpha_init, self.sigma_init), + ) + mean_w_q = pyro.param( + "mean_w_q_%s" % name, + lambda: rand_tensor((width), self.mean_init, self.sigma_init), + ) alpha_w_q, mean_w_q = softplus(alpha_w_q), softplus(mean_w_q) pyro.sample("w_%s" % name, Gamma(alpha_w_q, alpha_w_q / mean_w_q)) @@ -153,10 +179,14 @@ class MyEasyGuide(EasyGuide): def guide(self, x): # group all the latent weights into one large latent variable global_group = self.group(match="w_.*") - global_mean = pyro.param("w_mean", - lambda: rand_tensor(global_group.event_shape, 0.5, 0.1)) - global_scale = softplus(pyro.param("w_scale", - lambda: rand_tensor(global_group.event_shape, 0.0, 0.1))) + global_mean = pyro.param( + "w_mean", lambda: rand_tensor(global_group.event_shape, 0.5, 0.1) + ) + global_scale = softplus( + pyro.param( + "w_scale", lambda: rand_tensor(global_group.event_shape, 0.0, 0.1) + ) + ) # use a mean field Normal distribution on all the ws global_group.sample("ws", Normal(global_mean, global_scale).to_event(1)) @@ -165,19 +195,19 @@ def guide(self, x): x_shape = 
x.shape[:1] + local_group.event_shape with self.plate("data", x.size(0)): - local_mean = pyro.param("z_mean", - lambda: rand_tensor(x_shape, 0.5, 0.1)) - local_scale = softplus(pyro.param("z_scale", - lambda: rand_tensor(x_shape, 0.0, 0.1))) + local_mean = pyro.param("z_mean", lambda: rand_tensor(x_shape, 0.5, 0.1)) + local_scale = softplus( + pyro.param("z_scale", lambda: rand_tensor(x_shape, 0.0, 0.1)) + ) # use a mean field Normal distribution on all the zs local_group.sample("zs", Normal(local_mean, local_scale).to_event(1)) def main(args): # load data - print('loading training data...') + print("loading training data...") dataset_directory = get_data_directory(__file__) - dataset_path = os.path.join(dataset_directory, 'faces_training.csv') + dataset_path = os.path.join(dataset_directory, "faces_training.csv") if not os.path.exists(dataset_path): try: os.makedirs(dataset_directory) @@ -185,8 +215,11 @@ def main(args): if e.errno != errno.EEXIST: raise pass - wget.download('https://d2hg8soec8ck9v.cloudfront.net/datasets/faces_training.csv', dataset_path) - data = torch.tensor(np.loadtxt(dataset_path, delimiter=',')).float() + wget.download( + "https://d2hg8soec8ck9v.cloudfront.net/datasets/faces_training.csv", + dataset_path, + ) + data = torch.tensor(np.loadtxt(dataset_path, delimiter=",")).float() sparse_gamma_def = SparseGammaDEF() @@ -194,14 +227,14 @@ def main(args): # seems to be more amenable to higher learning rates. # Nevertheless, the easy guide performs the best (presumably because of numerical instabilities # related to the gamma distribution in the custom guide). - learning_rate = 0.2 if args.guide in ['auto', 'easy'] else 4.5 - momentum = 0.05 if args.guide in ['auto', 'easy'] else 0.1 + learning_rate = 0.2 if args.guide in ["auto", "easy"] else 4.5 + momentum = 0.05 if args.guide in ["auto", "easy"] else 0.1 opt = optim.AdagradRMSProp({"eta": learning_rate, "t": momentum}) # use one of our three different guide types - if args.guide == 'auto': + if args.guide == "auto": guide = AutoDiagonalNormal(sparse_gamma_def.model, init_loc_fn=init_to_feasible) - elif args.guide == 'easy': + elif args.guide == "easy": guide = MyEasyGuide(sparse_gamma_def.model) else: guide = sparse_gamma_def.guide @@ -212,16 +245,22 @@ def main(args): # we use svi_eval during evaluation; since we took care to write down our model in # a fully vectorized way, this computation can be done efficiently with large tensor ops - svi_eval = SVI(sparse_gamma_def.model, guide, opt, - loss=TraceMeanField_ELBO(num_particles=args.eval_particles, vectorize_particles=True)) + svi_eval = SVI( + sparse_gamma_def.model, + guide, + opt, + loss=TraceMeanField_ELBO( + num_particles=args.eval_particles, vectorize_particles=True + ), + ) - print('\nbeginning training with %s guide...' % args.guide) + print("\nbeginning training with %s guide..." 
% args.guide) # the training loop for k in range(args.num_epochs): loss = svi.step(data) # for the custom guide we clip parameters after each gradient step - if args.guide == 'custom': + if args.guide == "custom": clip_params() if k % args.eval_frequency == 0 and k > 0 or k == args.num_epochs - 1: @@ -229,17 +268,30 @@ def main(args): print("[epoch %04d] training elbo: %.4g" % (k, -loss)) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', default=1500, type=int, help='number of training epochs') - parser.add_argument('-ef', '--eval-frequency', default=25, type=int, - help='how often to evaluate elbo (number of epochs)') - parser.add_argument('-ep', '--eval-particles', default=20, type=int, - help='number of samples/particles to use during evaluation') - parser.add_argument('--guide', default='custom', type=str, - help='use a custom, auto, or easy guide') + parser.add_argument( + "-n", "--num-epochs", default=1500, type=int, help="number of training epochs" + ) + parser.add_argument( + "-ef", + "--eval-frequency", + default=25, + type=int, + help="how often to evaluate elbo (number of epochs)", + ) + parser.add_argument( + "-ep", + "--eval-particles", + default=20, + type=int, + help="number of samples/particles to use during evaluation", + ) + parser.add_argument( + "--guide", default="custom", type=str, help="use a custom, auto, or easy guide" + ) args = parser.parse_args() - assert args.guide in ['custom', 'auto', 'easy'] + assert args.guide in ["custom", "auto", "easy"] main(args) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 53521deb52..1bd2600e8d 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -41,7 +41,7 @@ """ -torch.set_default_tensor_type('torch.FloatTensor') +torch.set_default_tensor_type("torch.FloatTensor") def dot(X, Z): @@ -60,27 +60,34 @@ def kernel(X, Z, eta1, eta2, c): # Most of the model code is concerned with constructing the sparsity inducing prior. 
def model(X, Y, hypers, jitter=1.0e-4): - S, P, N = hypers['expected_sparsity'], X.size(1), X.size(0) + S, P, N = hypers["expected_sparsity"], X.size(1), X.size(0) - sigma = pyro.sample("sigma", dist.HalfNormal(hypers['alpha3'])) + sigma = pyro.sample("sigma", dist.HalfNormal(hypers["alpha3"])) phi = sigma * (S / math.sqrt(N)) / (P - S) eta1 = pyro.sample("eta1", dist.HalfCauchy(phi)) - msq = pyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1'])) - xisq = pyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2'])) + msq = pyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"])) + xisq = pyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"])) eta2 = eta1.pow(2.0) * xisq.sqrt() / msq - lam = pyro.sample("lambda", dist.HalfCauchy(torch.ones(P, device=X.device)).to_event(1)) + lam = pyro.sample( + "lambda", dist.HalfCauchy(torch.ones(P, device=X.device)).to_event(1) + ) kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt() kX = kappa * X # compute the kernel for the given hyperparameters - k = kernel(kX, kX, eta1, eta2, hypers['c']) + (sigma ** 2 + jitter) * torch.eye(N, device=X.device) + k = kernel(kX, kX, eta1, eta2, hypers["c"]) + (sigma ** 2 + jitter) * torch.eye( + N, device=X.device + ) # observe the outputs Y - pyro.sample("Y", dist.MultivariateNormal(torch.zeros(N, device=X.device), covariance_matrix=k), - obs=Y) + pyro.sample( + "Y", + dist.MultivariateNormal(torch.zeros(N, device=X.device), covariance_matrix=k), + obs=Y, + ) """ @@ -108,25 +115,39 @@ def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter=1.0e-4) kprobe = kprobe.reshape(-1, P) # compute various kernels - k_xx = kernel(kX, kX, eta1, eta2, c) + (jitter + sigma ** 2) * torch.eye(N, dtype=X.dtype, device=X.device) + k_xx = kernel(kX, kX, eta1, eta2, c) + (jitter + sigma ** 2) * torch.eye( + N, dtype=X.dtype, device=X.device + ) k_xx_inv = torch.inverse(k_xx) k_probeX = kernel(kprobe, kX, eta1, eta2, c) k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c) # compute mean and variance for singleton weights vec = torch.tensor([0.50, -0.50], dtype=X.dtype, device=X.device) - mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(P, 2) + mu = ( + torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)) + .squeeze(-1) + .reshape(P, 2) + ) mu = (mu * vec).sum(-1) var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t())) var = var.reshape(P, 2, P, 2).diagonal(dim1=-4, dim2=-2) # 2 2 P - std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt() + std = ( + ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)) + .sum(-2) + .clamp(min=0.0) + .sqrt() + ) # select active dimensions (those that are non-zero with sufficient statistical significance) active_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)).bool() active_dims = active_dims.nonzero(as_tuple=False).squeeze(-1) - print("Identified the following active dimensions:", active_dims.data.numpy().flatten()) + print( + "Identified the following active dimensions:", + active_dims.data.numpy().flatten(), + ) print("Mean estimate for active singleton weights:\n", mu[active_dims].data.numpy()) # if there are 0 or 1 active dimensions there are no quadratic weights to be found @@ -154,19 +175,39 @@ def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter=1.0e-4) # compute mean and covariance for a subset of weights theta_ij (namely those with # 'active' dimensions i and j) vec = 
torch.tensor([0.25, -0.25, -0.25, 0.25], dtype=X.dtype, device=X.device) - mu = torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)).squeeze(-1).reshape(left_dims.size(0), 4) + mu = ( + torch.matmul(k_probeX, torch.matmul(k_xx_inv, Y).unsqueeze(-1)) + .squeeze(-1) + .reshape(left_dims.size(0), 4) + ) mu = (mu * vec).sum(-1) var = k_prbprb - torch.matmul(k_probeX, torch.matmul(k_xx_inv, k_probeX.t())) - var = var.reshape(left_dims.size(0), 4, left_dims.size(0), 4).diagonal(dim1=-4, dim2=-2) - std = ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)).sum(-2).clamp(min=0.0).sqrt() - - active_quad_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)) & (mu.abs() > 1.0e-4).bool() + var = var.reshape(left_dims.size(0), 4, left_dims.size(0), 4).diagonal( + dim1=-4, dim2=-2 + ) + std = ( + ((var * vec.unsqueeze(-1)).sum(-2) * vec.unsqueeze(-1)) + .sum(-2) + .clamp(min=0.0) + .sqrt() + ) + + active_quad_dims = (((mu - 4.0 * std) > 0.0) | ((mu + 4.0 * std) < 0.0)) & ( + mu.abs() > 1.0e-4 + ).bool() active_quad_dims = active_quad_dims.nonzero(as_tuple=False) - active_quadratic_dims = np.stack([left_dims[active_quad_dims].data.numpy().flatten(), - right_dims[active_quad_dims].data.numpy().flatten()], axis=1) - active_quadratic_dims = np.split(active_quadratic_dims, active_quadratic_dims.shape[0]) + active_quadratic_dims = np.stack( + [ + left_dims[active_quad_dims].data.numpy().flatten(), + right_dims[active_quad_dims].data.numpy().flatten(), + ], + axis=1, + ) + active_quadratic_dims = np.split( + active_quadratic_dims, active_quadratic_dims.shape[0] + ) active_quadratic_dims = [tuple(a.tolist()[0]) for a in active_quadratic_dims] return active_dims.data.numpy(), active_quadratic_dims @@ -222,17 +263,24 @@ def init_loc_fn(site): def main(args): # setup hyperparameters for the model - hypers = {'expected_sparsity': max(1.0, args.num_dimensions / 10), - 'alpha1': 3.0, 'beta1': 1.0, 'alpha2': 3.0, 'beta2': 1.0, 'alpha3': 1.0, - 'c': 1.0} + hypers = { + "expected_sparsity": max(1.0, args.num_dimensions / 10), + "alpha1": 3.0, + "beta1": 1.0, + "alpha2": 3.0, + "beta2": 1.0, + "alpha3": 1.0, + "c": 1.0, + } P = args.num_dimensions S = args.active_dimensions Q = args.quadratic_dimensions # generate artificial dataset - X, Y, expected_thetas, expected_quad_dims = get_data(N=args.num_data, P=P, S=S, - Q=Q, sigma_obs=args.sigma) + X, Y, expected_thetas, expected_quad_dims = get_data( + N=args.num_data, P=P, S=S, Q=Q, sigma_obs=args.sigma + ) loss_fn = Trace_ELBO().differentiable_loss @@ -270,7 +318,7 @@ def main(args): # we manually reduce the learning rate according to this schedule if step in [100, 300, 700, 900]: - adam.param_groups[0]['lr'] *= 0.2 + adam.param_groups[0]["lr"] *= 0.2 if step % report_frequency == 0 or step == args.num_steps - 1: print("[step %04d] loss: %.5f" % (step, loss)) @@ -279,11 +327,16 @@ def main(args): # we do the final computation using double precision median = guide.median() # == mode for MAP inference - active_dims, active_quad_dims = \ - compute_posterior_stats(X.double(), Y.double(), median['msq'].double(), - median['lambda'].double(), median['eta1'].double(), - median['xisq'].double(), torch.tensor(hypers['c']).double(), - median['sigma'].double()) + active_dims, active_quad_dims = compute_posterior_stats( + X.double(), + Y.double(), + median["msq"].double(), + median["lambda"].double(), + median["eta1"].double(), + median["xisq"].double(), + torch.tensor(hypers["c"]).double(), + median["sigma"].double(), + ) expected_active_dims = 
np.arange(S).tolist() @@ -300,23 +353,27 @@ def main(args): # We report how well we did, i.e. did we recover the sparse set of coefficients # that we expected for our artificial dataset? print("[SUMMARY STATS]") - print("Singletons (true positive, false positive, false negative): " + - "(%d, %d, %d)" % singleton_stats) - print("Quadratic (true positive, false positive, false negative): " + - "(%d, %d, %d)" % quad_stats) - - -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='Krylov KIT') - parser.add_argument('--num-data', type=int, default=750) - parser.add_argument('--num-steps', type=int, default=1000) - parser.add_argument('--num-dimensions', type=int, default=100) - parser.add_argument('--num-restarts', type=int, default=10) - parser.add_argument('--sigma', type=float, default=0.05) - parser.add_argument('--active-dimensions', type=int, default=10) - parser.add_argument('--quadratic-dimensions', type=int, default=5) - parser.add_argument('--lr', type=float, default=0.3) + print( + "Singletons (true positive, false positive, false negative): " + + "(%d, %d, %d)" % singleton_stats + ) + print( + "Quadratic (true positive, false positive, false negative): " + + "(%d, %d, %d)" % quad_stats + ) + + +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser(description="Krylov KIT") + parser.add_argument("--num-data", type=int, default=750) + parser.add_argument("--num-steps", type=int, default=1000) + parser.add_argument("--num-dimensions", type=int, default=100) + parser.add_argument("--num-restarts", type=int, default=10) + parser.add_argument("--sigma", type=float, default=0.05) + parser.add_argument("--active-dimensions", type=int, default=10) + parser.add_argument("--quadratic-dimensions", type=int, default=5) + parser.add_argument("--lr", type=float, default=0.3) args = parser.parse_args() main(args) diff --git a/examples/svi_horovod.py b/examples/svi_horovod.py index 3ab9e37ee1..bb3aac9995 100644 --- a/examples/svi_horovod.py +++ b/examples/svi_horovod.py @@ -56,8 +56,7 @@ def forward(self, covariates, data=None): # identical subsamples. with pyro.plate("data", self.size, len(covariates)): loc = bias + coeff * covariates - return pyro.sample("obs", dist.Normal(loc, scale), - obs=data) + return pyro.sample("obs", dist.Normal(loc, scale), obs=data) # The following is a standard training loop. To emphasize the Horovod-specific @@ -73,6 +72,7 @@ def main(args): if args.horovod: # Initialize Horovod and set PyTorch globals. import horovod.torch as hvd + hvd.init() torch.set_num_threads(1) if args.cuda: @@ -100,7 +100,8 @@ def main(args): if args.horovod: # Horovod requires a distributed sampler. sampler = torch.utils.data.distributed.DistributedSampler( - dataset, hvd.size(), hvd.rank()) + dataset, hvd.size(), hvd.rank() + ) else: sampler = torch.utils.data.RandomSampler(dataset) config = {"batch_size": args.batch_size, "sampler": sampler} @@ -108,8 +109,11 @@ def main(args): config["num_workers"] = 1 config["pin_memory"] = True # Try to use forkserver to spawn workers instead of fork. 
- if (hasattr(mp, "_supports_context") and mp._supports_context and - "forkserver" in mp.get_all_start_methods()): + if ( + hasattr(mp, "_supports_context") + and mp._supports_context + and "forkserver" in mp.get_all_start_methods() + ): config["multiprocessing_context"] = "forkserver" dataloader = torch.utils.data.DataLoader(dataset, **config) @@ -150,7 +154,7 @@ def main(args): if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Distributed training via Horovod") parser.add_argument("-o", "--outfile") parser.add_argument("-s", "--size", default=1000000, type=int) diff --git a/examples/toy_mixture_model_discrete_enumeration.py b/examples/toy_mixture_model_discrete_enumeration.py index 657b49f7d3..9959392e02 100644 --- a/examples/toy_mixture_model_discrete_enumeration.py +++ b/examples/toy_mixture_model_discrete_enumeration.py @@ -47,51 +47,56 @@ def main(args): def generate_data(num_obs): # domain = [False, True] - prior = {'A': torch.tensor([1., 10.]), - 'B': torch.tensor([[10., 1.], - [1., 10.]]), - 'C': torch.tensor([[10., 1.], - [1., 10.]])} - CPDs = {'p_A': Beta(prior['A'][0], prior['A'][1]).sample(), - 'p_B': Beta(prior['B'][:, 0], prior['B'][:, 1]).sample(), - 'p_C': Beta(prior['C'][:, 0], prior['C'][:, 1]).sample(), - } - data = {'A': Bernoulli(torch.ones(num_obs) * CPDs['p_A']).sample()} - data['B'] = Bernoulli(torch.gather(CPDs['p_B'], 0, data['A'].type(torch.long))).sample() - data['C'] = Bernoulli(torch.gather(CPDs['p_C'], 0, data['B'].type(torch.long))).sample() + prior = { + "A": torch.tensor([1.0, 10.0]), + "B": torch.tensor([[10.0, 1.0], [1.0, 10.0]]), + "C": torch.tensor([[10.0, 1.0], [1.0, 10.0]]), + } + CPDs = { + "p_A": Beta(prior["A"][0], prior["A"][1]).sample(), + "p_B": Beta(prior["B"][:, 0], prior["B"][:, 1]).sample(), + "p_C": Beta(prior["C"][:, 0], prior["C"][:, 1]).sample(), + } + data = {"A": Bernoulli(torch.ones(num_obs) * CPDs["p_A"]).sample()} + data["B"] = Bernoulli( + torch.gather(CPDs["p_B"], 0, data["A"].type(torch.long)) + ).sample() + data["C"] = Bernoulli( + torch.gather(CPDs["p_C"], 0, data["B"].type(torch.long)) + ).sample() return prior, CPDs, data @pyro.infer.config_enumerate def model(prior, obs, num_obs): - p_A = pyro.sample('p_A', dist.Beta(1, 1)) - p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1)) - p_C = pyro.sample('p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1)) - with pyro.plate('data_plate', num_obs): - A = pyro.sample('A', dist.Bernoulli(p_A.expand(num_obs)), obs=obs['A']) + p_A = pyro.sample("p_A", dist.Beta(1, 1)) + p_B = pyro.sample("p_B", dist.Beta(torch.ones(2), torch.ones(2)).to_event(1)) + p_C = pyro.sample("p_C", dist.Beta(torch.ones(2), torch.ones(2)).to_event(1)) + with pyro.plate("data_plate", num_obs): + A = pyro.sample("A", dist.Bernoulli(p_A.expand(num_obs)), obs=obs["A"]) # Vindex used to ensure proper indexing into the enumerated sample sites - B = pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]), infer={"enumerate": "parallel"}) - pyro.sample('C', dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]), obs=obs['C']) + B = pyro.sample( + "B", + dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]), + infer={"enumerate": "parallel"}, + ) + pyro.sample("C", dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]), obs=obs["C"]) def guide(prior, obs, num_obs): - a = pyro.param('a', prior['A'], constraint=constraints.positive) - pyro.sample('p_A', dist.Beta(a[0], a[1])) - b = 
pyro.param('b', prior['B'], constraint=constraints.positive) - pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1)) - c = pyro.param('c', prior['C'], constraint=constraints.positive) - pyro.sample('p_C', dist.Beta(c[:, 0], c[:, 1]).to_event(1)) + a = pyro.param("a", prior["A"], constraint=constraints.positive) + pyro.sample("p_A", dist.Beta(a[0], a[1])) + b = pyro.param("b", prior["B"], constraint=constraints.positive) + pyro.sample("p_B", dist.Beta(b[:, 0], b[:, 1]).to_event(1)) + c = pyro.param("c", prior["C"], constraint=constraints.positive) + pyro.sample("p_C", dist.Beta(c[:, 0], c[:, 1]).to_event(1)) def train(prior, data, num_steps, num_obs): pyro.clear_param_store() # max_plate_nesting = 1 because there is a single plate in the model loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1) - svi = pyro.infer.SVI(model, - guide, - pyro.optim.Adam({'lr': .01}), - loss=loss_func - ) + svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), loss=loss_func) losses = [] for _ in tqdm(range(num_steps)): loss = svi.step(prior, data, num_obs) @@ -100,33 +105,35 @@ def train(prior, data, num_steps, num_obs): plt.plot(losses) plt.show() posterior_params = {k: np.array(v.data) for k, v in pyro.get_param_store().items()} - posterior_params['a'] = posterior_params['a'][None, :] # reshape to same as other variables + posterior_params["a"] = posterior_params["a"][ + None, : + ] # reshape to same as other variables return posterior_params def evaluate(CPDs, posterior_params): - true_p_A, pred_p_A = get_true_pred_CPDs(CPDs['p_A'], posterior_params['a']) - true_p_B, pred_p_B = get_true_pred_CPDs(CPDs['p_B'], posterior_params['b']) - true_p_C, pred_p_C = get_true_pred_CPDs(CPDs['p_C'], posterior_params['c']) - print('\np_A = True') - print('actual: ', true_p_A) - print('predicted:', pred_p_A) - print('\np_B = True | A = False/True') - print('actual: ', true_p_B) - print('predicted:', pred_p_B) - print('\np_C = True | B = False/True') - print('actual: ', true_p_C) - print('predicted:', pred_p_C) + true_p_A, pred_p_A = get_true_pred_CPDs(CPDs["p_A"], posterior_params["a"]) + true_p_B, pred_p_B = get_true_pred_CPDs(CPDs["p_B"], posterior_params["b"]) + true_p_C, pred_p_C = get_true_pred_CPDs(CPDs["p_C"], posterior_params["c"]) + print("\np_A = True") + print("actual: ", true_p_A) + print("predicted:", pred_p_A) + print("\np_B = True | A = False/True") + print("actual: ", true_p_B) + print("predicted:", pred_p_B) + print("\np_C = True | B = False/True") + print("actual: ", true_p_C) + print("predicted:", pred_p_C) def get_true_pred_CPDs(CPD, posterior_param): true_p = CPD.numpy() - pred_p = posterior_param[:, 0]/np.sum(posterior_param, axis=1) + pred_p = posterior_param[:, 0] / np.sum(posterior_param, axis=1) return true_p, pred_p if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="Toy mixture model") parser.add_argument("-n", "--num-steps", default=4000, type=int) parser.add_argument("-o", "--num-obs", default=10000, type=int) diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index a5a5b3c490..e70f3a84b0 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -40,8 +40,17 @@ class SSVAE(nn.Module): :param use_cuda: use GPUs for faster training :param aux_loss_multiplier: the multiplier to use with the auxiliary loss """ - def __init__(self, output_size=10, input_size=784, z_dim=50, hidden_layers=(500,), - config_enum=None, use_cuda=False, 
aux_loss_multiplier=None): + + def __init__( + self, + output_size=10, + input_size=784, + z_dim=50, + hidden_layers=(500,), + config_enum=None, + use_cuda=False, + aux_loss_multiplier=None, + ): super().__init__() @@ -50,7 +59,7 @@ def __init__(self, output_size=10, input_size=784, z_dim=50, hidden_layers=(500, self.input_size = input_size self.z_dim = z_dim self.hidden_layers = hidden_layers - self.allow_broadcast = config_enum == 'parallel' + self.allow_broadcast = config_enum == "parallel" self.use_cuda = use_cuda self.aux_loss_multiplier = aux_loss_multiplier @@ -67,29 +76,33 @@ def setup_networks(self): # these networks are MLPs (multi-layered perceptrons or simple feed-forward networks) # where the provided activation parameter is used on every linear layer except # for the output layer where we use the provided output_activation parameter - self.encoder_y = MLP([self.input_size] + hidden_sizes + [self.output_size], - activation=nn.Softplus, - output_activation=nn.Softmax, - allow_broadcast=self.allow_broadcast, - use_cuda=self.use_cuda) + self.encoder_y = MLP( + [self.input_size] + hidden_sizes + [self.output_size], + activation=nn.Softplus, + output_activation=nn.Softmax, + allow_broadcast=self.allow_broadcast, + use_cuda=self.use_cuda, + ) # a split in the final layer's size is used for multiple outputs # and potentially applying separate activation functions on them # e.g. in this network the final output is of size [z_dim,z_dim] # to produce loc and scale, and apply different activations [None,Exp] on them - self.encoder_z = MLP([self.input_size + self.output_size] + - hidden_sizes + [[z_dim, z_dim]], - activation=nn.Softplus, - output_activation=[None, Exp], - allow_broadcast=self.allow_broadcast, - use_cuda=self.use_cuda) - - self.decoder = MLP([z_dim + self.output_size] + - hidden_sizes + [self.input_size], - activation=nn.Softplus, - output_activation=nn.Sigmoid, - allow_broadcast=self.allow_broadcast, - use_cuda=self.use_cuda) + self.encoder_z = MLP( + [self.input_size + self.output_size] + hidden_sizes + [[z_dim, z_dim]], + activation=nn.Softplus, + output_activation=[None, Exp], + allow_broadcast=self.allow_broadcast, + use_cuda=self.use_cuda, + ) + + self.decoder = MLP( + [z_dim + self.output_size] + hidden_sizes + [self.input_size], + activation=nn.Softplus, + output_activation=nn.Sigmoid, + allow_broadcast=self.allow_broadcast, + use_cuda=self.use_cuda, + ) # using GPUs for faster training of the networks if self.use_cuda: @@ -122,7 +135,9 @@ def model(self, xs, ys=None): # if the label y (which digit to write) is supervised, sample from the # constant prior, otherwise, observe the value (i.e. score it against the constant prior) - alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size) + alpha_prior = torch.ones(batch_size, self.output_size, **options) / ( + 1.0 * self.output_size + ) ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys) # Finally, score the image (x) using the handwriting style (z) and @@ -131,8 +146,9 @@ def model(self, xs, ys=None): # where `decoder` is a neural network. We disable validation # since the decoder output is a relaxed Bernoulli value. 
loc = self.decoder.forward([zs, ys]) - pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), - obs=xs) + pyro.sample( + "x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs + ) # return the loc so we can visualize it later return loc @@ -219,8 +235,8 @@ def run_inference_for_epoch(data_loaders, losses, periodic_interval_batches): batches_per_epoch = sup_batches + unsup_batches # initialize variables to store loss values - epoch_losses_sup = [0.] * num_losses - epoch_losses_unsup = [0.] * num_losses + epoch_losses_sup = [0.0] * num_losses + epoch_losses_unsup = [0.0] * num_losses # setup the iterators for training data loaders sup_iter = iter(data_loaders["sup"]) @@ -271,7 +287,7 @@ def get_accuracy(data_loader, classifier_fn, batch_size): for pred, act in zip(predictions, actuals): for i in range(pred.size(0)): v = torch.sum(pred[i] == act[i]) - accurate_preds += (v.item() == 10) + accurate_preds += v.item() == 10 # calculate the accuracy between 0 and 1 accuracy = (accurate_preds * 1.0) / (len(predictions) * batch_size) @@ -299,11 +315,13 @@ def main(args): mkdir_p("./vae_results") # batch_size: number of images (and labels) to be considered in a batch - ss_vae = SSVAE(z_dim=args.z_dim, - hidden_layers=args.hidden_layers, - use_cuda=args.cuda, - config_enum=args.enum_discrete, - aux_loss_multiplier=args.aux_loss_multiplier) + ss_vae = SSVAE( + z_dim=args.z_dim, + hidden_layers=args.hidden_layers, + use_cuda=args.cuda, + config_enum=args.enum_discrete, + aux_loss_multiplier=args.aux_loss_multiplier, + ) # setup the optimizer adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)} @@ -322,19 +340,25 @@ def main(args): # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al) if args.aux_loss: elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() - loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo) + loss_aux = SVI( + ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo + ) losses.append(loss_aux) try: # setup the logger if a filename is provided logger = open(args.logfile, "w") if args.logfile else None - data_loaders = setup_data_loaders(MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num) + data_loaders = setup_data_loaders( + MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num + ) # how often would a supervised batch be encountered during inference # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised # until we have traversed through the all supervised batches - periodic_interval_batches = int(MNISTCached.train_data_size / (1.0 * args.sup_num)) + periodic_interval_batches = int( + MNISTCached.train_data_size / (1.0 * args.sup_num) + ) # number of unsupervised examples unsup_num = MNISTCached.train_data_size - args.sup_num @@ -348,8 +372,9 @@ def main(args): for i in range(0, args.num_epochs): # get the losses for an epoch - epoch_losses_sup, epoch_losses_unsup = \ - run_inference_for_epoch(data_loaders, losses, periodic_interval_batches) + epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch( + data_loaders, losses, periodic_interval_batches + ) # compute average epoch losses i.e. 
losses per example avg_epoch_losses_sup = map(lambda v: v / args.sup_num, epoch_losses_sup) @@ -359,14 +384,20 @@ def main(args): str_loss_sup = " ".join(map(str, avg_epoch_losses_sup)) str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup)) - str_print = "{} epoch: avg losses {}".format(i, "{} {}".format(str_loss_sup, str_loss_unsup)) + str_print = "{} epoch: avg losses {}".format( + i, "{} {}".format(str_loss_sup, str_loss_unsup) + ) - validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, args.batch_size) + validation_accuracy = get_accuracy( + data_loaders["valid"], ss_vae.classifier, args.batch_size + ) str_print += " validation accuracy {}".format(validation_accuracy) # this test accuracy is only for logging, this is not used # to make any decisions during training - test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size) + test_accuracy = get_accuracy( + data_loaders["test"], ss_vae.classifier, args.batch_size + ) str_print += " test accuracy {}".format(test_accuracy) # update the best validation accuracy and the corresponding @@ -377,9 +408,16 @@ def main(args): print_and_log(logger, str_print) - final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size) - print_and_log(logger, "best validation accuracy {} corresponding testing accuracy {} " - "last testing accuracy {}".format(best_valid_acc, corresponding_test_acc, final_test_accuracy)) + final_test_accuracy = get_accuracy( + data_loaders["test"], ss_vae.classifier, args.batch_size + ) + print_and_log( + logger, + "best validation accuracy {} corresponding testing accuracy {} " + "last testing accuracy {}".format( + best_valid_acc, corresponding_test_acc, final_test_accuracy + ), + ) # visualize the conditional samples visualize(ss_vae, viz, data_loaders["test"]) @@ -389,56 +427,120 @@ def main(args): logger.close() -EXAMPLE_RUN = "example run: python ss_vae_M2.py --seed 0 --cuda -n 2 --aux-loss -alm 46 -enum parallel " \ - "-sup 3000 -zd 50 -hl 500 -lr 0.00042 -b1 0.95 -bs 200 -log ./tmp.log" +EXAMPLE_RUN = ( + "example run: python ss_vae_M2.py --seed 0 --cuda -n 2 --aux-loss -alm 46 -enum parallel " + "-sup 3000 -zd 50 -hl 500 -lr 0.00042 -b1 0.95 -bs 200 -log ./tmp.log" +) if __name__ == "__main__": - assert pyro.__version__.startswith('1.6.0') + assert pyro.__version__.startswith("1.6.0") parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN)) - parser.add_argument('--cuda', action='store_true', - help="use GPU(s) to speed up training") - parser.add_argument('--jit', action='store_true', - help="use PyTorch jit to speed up training") - parser.add_argument('-n', '--num-epochs', default=50, type=int, - help="number of epochs to run") - parser.add_argument('--aux-loss', action="store_true", - help="whether to use the auxiliary loss from NIPS 14 paper " - "(Kingma et al). It is not used by default ") - parser.add_argument('-alm', '--aux-loss-multiplier', default=46, type=float, - help="the multiplier to use with the auxiliary loss") - parser.add_argument('-enum', '--enum-discrete', default="parallel", - help="parallel, sequential or none. uses parallel enumeration by default") - parser.add_argument('-sup', '--sup-num', default=3000, - type=float, help="supervised amount of the data i.e. 
" - "how many of the images have supervised labels") - parser.add_argument('-zd', '--z-dim', default=50, type=int, - help="size of the tensor representing the latent variable z " - "variable (handwriting style for our MNIST dataset)") - parser.add_argument('-hl', '--hidden-layers', nargs='+', default=[500], type=int, - help="a tuple (or list) of MLP layers to be used in the neural networks " - "representing the parameters of the distributions in our model") - parser.add_argument('-lr', '--learning-rate', default=0.00042, type=float, - help="learning rate for Adam optimizer") - parser.add_argument('-b1', '--beta-1', default=0.9, type=float, - help="beta-1 parameter for Adam optimizer") - parser.add_argument('-bs', '--batch-size', default=200, type=int, - help="number of images (and labels) to be considered in a batch") - parser.add_argument('-log', '--logfile', default="./tmp.log", type=str, - help="filename for logging the outputs") - parser.add_argument('--seed', default=None, type=int, - help="seed for controlling randomness in this example") - parser.add_argument('--visualize', action="store_true", - help="use a visdom server to visualize the embeddings") + parser.add_argument( + "--cuda", action="store_true", help="use GPU(s) to speed up training" + ) + parser.add_argument( + "--jit", action="store_true", help="use PyTorch jit to speed up training" + ) + parser.add_argument( + "-n", "--num-epochs", default=50, type=int, help="number of epochs to run" + ) + parser.add_argument( + "--aux-loss", + action="store_true", + help="whether to use the auxiliary loss from NIPS 14 paper " + "(Kingma et al). It is not used by default ", + ) + parser.add_argument( + "-alm", + "--aux-loss-multiplier", + default=46, + type=float, + help="the multiplier to use with the auxiliary loss", + ) + parser.add_argument( + "-enum", + "--enum-discrete", + default="parallel", + help="parallel, sequential or none. uses parallel enumeration by default", + ) + parser.add_argument( + "-sup", + "--sup-num", + default=3000, + type=float, + help="supervised amount of the data i.e. 
" + "how many of the images have supervised labels", + ) + parser.add_argument( + "-zd", + "--z-dim", + default=50, + type=int, + help="size of the tensor representing the latent variable z " + "variable (handwriting style for our MNIST dataset)", + ) + parser.add_argument( + "-hl", + "--hidden-layers", + nargs="+", + default=[500], + type=int, + help="a tuple (or list) of MLP layers to be used in the neural networks " + "representing the parameters of the distributions in our model", + ) + parser.add_argument( + "-lr", + "--learning-rate", + default=0.00042, + type=float, + help="learning rate for Adam optimizer", + ) + parser.add_argument( + "-b1", + "--beta-1", + default=0.9, + type=float, + help="beta-1 parameter for Adam optimizer", + ) + parser.add_argument( + "-bs", + "--batch-size", + default=200, + type=int, + help="number of images (and labels) to be considered in a batch", + ) + parser.add_argument( + "-log", + "--logfile", + default="./tmp.log", + type=str, + help="filename for logging the outputs", + ) + parser.add_argument( + "--seed", + default=None, + type=int, + help="seed for controlling randomness in this example", + ) + parser.add_argument( + "--visualize", + action="store_true", + help="use a visdom server to visualize the embeddings", + ) args = parser.parse_args() # some assertions to make sure that batching math assumptions are met assert args.sup_num % args.batch_size == 0, "assuming simplicity of batching math" - assert MNISTCached.validation_size % args.batch_size == 0, \ - "batch size should divide the number of validation examples" - assert MNISTCached.train_data_size % args.batch_size == 0, \ - "batch size doesn't divide total number of training data examples" - assert MNISTCached.test_size % args.batch_size == 0, "batch size should divide the number of test examples" + assert ( + MNISTCached.validation_size % args.batch_size == 0 + ), "batch size should divide the number of validation examples" + assert ( + MNISTCached.train_data_size % args.batch_size == 0 + ), "batch size doesn't divide total number of training data examples" + assert ( + MNISTCached.test_size % args.batch_size == 0 + ), "batch size should divide the number of test examples" main(args) diff --git a/examples/vae/utils/custom_mlp.py b/examples/vae/utils/custom_mlp.py index cc350a3d0f..0767d7f1b0 100644 --- a/examples/vae/utils/custom_mlp.py +++ b/examples/vae/utils/custom_mlp.py @@ -13,6 +13,7 @@ class Exp(nn.Module): """ a custom module for exponentiation of tensors """ + def __init__(self): super().__init__() @@ -24,6 +25,7 @@ class ConcatModule(nn.Module): """ a custom module for concatenation of tensors """ + def __init__(self, allow_broadcast=False): self.allow_broadcast = allow_broadcast super().__init__() @@ -50,6 +52,7 @@ class ListOutModule(nn.ModuleList): """ a custom module for outputting a list of tensors from a list of nn modules """ + def __init__(self, modules): super().__init__(modules) @@ -73,21 +76,32 @@ def call_nn_op(op): class MLP(nn.Module): - - def __init__(self, mlp_sizes, activation=nn.ReLU, output_activation=None, - post_layer_fct=lambda layer_ix, total_layers, layer: None, - post_act_fct=lambda layer_ix, total_layers, layer: None, - allow_broadcast=False, use_cuda=False): + def __init__( + self, + mlp_sizes, + activation=nn.ReLU, + output_activation=None, + post_layer_fct=lambda layer_ix, total_layers, layer: None, + post_act_fct=lambda layer_ix, total_layers, layer: None, + allow_broadcast=False, + use_cuda=False, + ): # init the module object super().__init__() 
assert len(mlp_sizes) >= 2, "Must have input and output layer sizes defined" # get our inputs, outputs, and hidden - input_size, hidden_sizes, output_size = mlp_sizes[0], mlp_sizes[1:-1], mlp_sizes[-1] + input_size, hidden_sizes, output_size = ( + mlp_sizes[0], + mlp_sizes[1:-1], + mlp_sizes[-1], + ) # assume int or list - assert isinstance(input_size, (int, list, tuple)), "input_size must be int, list, tuple" + assert isinstance( + input_size, (int, list, tuple) + ), "input_size must be int, list, tuple" # everything in MLP will be concatted if it's multiple arguments last_layer_size = input_size if type(input_size) == int else sum(input_size) @@ -114,7 +128,9 @@ def __init__(self, mlp_sizes, activation=nn.ReLU, output_activation=None, all_modules.append(cur_linear_layer) # handle post_linear - post_linear = post_layer_fct(layer_ix + 1, len(hidden_sizes), all_modules[-1]) + post_linear = post_layer_fct( + layer_ix + 1, len(hidden_sizes), all_modules[-1] + ) # if we send something back, add it to sequential # here we could return a batch norm for example @@ -125,7 +141,9 @@ def __init__(self, mlp_sizes, activation=nn.ReLU, output_activation=None, all_modules.append(activation()) # now handle after activation - post_activation = post_act_fct(layer_ix + 1, len(hidden_sizes), all_modules[-1]) + post_activation = post_act_fct( + layer_ix + 1, len(hidden_sizes), all_modules[-1] + ) # handle post_activation if not null # could add batch norm for example @@ -137,13 +155,18 @@ def __init__(self, mlp_sizes, activation=nn.ReLU, output_activation=None, # now we have all of our hidden layers # we handle outputs - assert isinstance(output_size, (int, list, tuple)), "output_size must be int, list, tuple" + assert isinstance( + output_size, (int, list, tuple) + ), "output_size must be int, list, tuple" if type(output_size) == int: all_modules.append(nn.Linear(last_layer_size, output_size)) if output_activation is not None: - all_modules.append(call_nn_op(output_activation) - if isclass(output_activation) else output_activation) + all_modules.append( + call_nn_op(output_activation) + if isclass(output_activation) + else output_activation + ) else: # we're going to have a bunch of separate layers we can spit out (a tuple of outputs) @@ -159,14 +182,18 @@ def __init__(self, mlp_sizes, activation=nn.ReLU, output_activation=None, split_layer.append(nn.Linear(last_layer_size, out_size)) # then we get our output activation (either we repeat all or we index into a same sized array) - act_out_fct = output_activation if not isinstance(output_activation, (list, tuple)) \ + act_out_fct = ( + output_activation + if not isinstance(output_activation, (list, tuple)) else output_activation[out_ix] + ) - if(act_out_fct): + if act_out_fct: # we check if it's a class. if so, instantiate the object # otherwise, use the object directly (e.g. pre-instaniated) - split_layer.append(call_nn_op(act_out_fct) - if isclass(act_out_fct) else act_out_fct) + split_layer.append( + call_nn_op(act_out_fct) if isclass(act_out_fct) else act_out_fct + ) # our outputs is just a sequential of the two out_layers.append(nn.Sequential(*split_layer)) diff --git a/examples/vae/utils/mnist_cached.py b/examples/vae/utils/mnist_cached.py index 3270ab46f0..d7ad578fda 100644 --- a/examples/vae/utils/mnist_cached.py +++ b/examples/vae/utils/mnist_cached.py @@ -19,7 +19,7 @@ # transformations for MNIST data def fn_x_mnist(x, use_cuda): # normalize pixel values of the image to be in [0,1] instead of [0,255] - xp = x * (1. 
/ 255) + xp = x * (1.0 / 255) # transform x to a linear tensor from bx * a1 * a2 * ... --> bs * A xp_1d_size = reduce(lambda a, b: a * b, xp.size()[1:]) @@ -65,7 +65,7 @@ def get_ss_indices_per_class(y, sup_per_class): for j in range(10): np.random.shuffle(idxs_per_class[j]) idxs_sup.extend(idxs_per_class[j][:sup_per_class]) - idxs_unsup.extend(idxs_per_class[j][sup_per_class:len(idxs_per_class[j])]) + idxs_unsup.extend(idxs_per_class[j][sup_per_class : len(idxs_per_class[j])]) return idxs_sup, idxs_unsup @@ -142,44 +142,62 @@ def target_transform(y): self.mode = mode - assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values" + assert mode in [ + "sup", + "unsup", + "test", + "valid", + ], "invalid train/test option values" if mode in ["sup", "unsup", "valid"]: # transform the training data if transformations are provided if transform is not None: - self.data = (transform(self.data.float())) + self.data = transform(self.data.float()) if target_transform is not None: - self.targets = (target_transform(self.targets)) + self.targets = target_transform(self.targets) if MNISTCached.train_data_sup is None: if sup_num is None: assert mode == "unsup" - MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = \ - self.data, self.targets + MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = ( + self.data, + self.targets, + ) else: - MNISTCached.train_data_sup, MNISTCached.train_labels_sup, \ - MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup, \ - MNISTCached.data_valid, MNISTCached.labels_valid = \ - split_sup_unsup_valid(self.data, self.targets, sup_num) + ( + MNISTCached.train_data_sup, + MNISTCached.train_labels_sup, + MNISTCached.train_data_unsup, + MNISTCached.train_labels_unsup, + MNISTCached.data_valid, + MNISTCached.labels_valid, + ) = split_sup_unsup_valid(self.data, self.targets, sup_num) if mode == "sup": - self.data, self.targets = MNISTCached.train_data_sup, MNISTCached.train_labels_sup + self.data, self.targets = ( + MNISTCached.train_data_sup, + MNISTCached.train_labels_sup, + ) elif mode == "unsup": self.data = MNISTCached.train_data_unsup # making sure that the unsupervised labels are not available to inference - self.targets = (torch.Tensor( - MNISTCached.train_labels_unsup.shape[0]).view(-1, 1)) * np.nan + self.targets = ( + torch.Tensor(MNISTCached.train_labels_unsup.shape[0]).view(-1, 1) + ) * np.nan else: - self.data, self.targets = MNISTCached.data_valid, MNISTCached.labels_valid + self.data, self.targets = ( + MNISTCached.data_valid, + MNISTCached.labels_valid, + ) else: # transform the testing data if transformations are provided if transform is not None: - self.data = (transform(self.data.float())) + self.data = transform(self.data.float()) if target_transform is not None: - self.targets = (target_transform(self.targets)) + self.targets = target_transform(self.targets) def __getitem__(self, index): """ @@ -195,7 +213,9 @@ def __getitem__(self, index): return img, target -def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, download=True, **kwargs): +def setup_data_loaders( + dataset, use_cuda, batch_size, sup_num=None, root=None, download=True, **kwargs +): """ helper function for setting up pytorch data loaders for a semi-supervised dataset :param dataset: the data to use @@ -210,8 +230,8 @@ def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, d # instantiate the dataset as training/testing sets if root is None: root = get_data_directory(__file__) - if 'num_workers' 
not in kwargs: - kwargs = {'num_workers': 0, 'pin_memory': False} + if "num_workers" not in kwargs: + kwargs = {"num_workers": 0, "pin_memory": False} cached_data = {} loaders = {} @@ -219,9 +239,12 @@ def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, d if sup_num is None and mode == "sup": # in this special case, we do not want "sup" and "valid" data loaders return loaders["unsup"], loaders["test"] - cached_data[mode] = dataset(root=root, mode=mode, download=download, - sup_num=sup_num, use_cuda=use_cuda) - loaders[mode] = DataLoader(cached_data[mode], batch_size=batch_size, shuffle=True, **kwargs) + cached_data[mode] = dataset( + root=root, mode=mode, download=download, sup_num=sup_num, use_cuda=use_cuda + ) + loaders[mode] = DataLoader( + cached_data[mode], batch_size=batch_size, shuffle=True, **kwargs + ) return loaders @@ -237,5 +260,5 @@ def mkdir_p(path): EXAMPLE_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))) -DATA_DIR = os.path.join(EXAMPLE_DIR, 'data') -RESULTS_DIR = os.path.join(EXAMPLE_DIR, 'results') +DATA_DIR = os.path.join(EXAMPLE_DIR, "data") +RESULTS_DIR = os.path.join(EXAMPLE_DIR, "results") diff --git a/examples/vae/utils/vae_plots.py b/examples/vae/utils/vae_plots.py index 3a7eaa4ce2..00a18f2e09 100644 --- a/examples/vae/utils/vae_plots.py +++ b/examples/vae/utils/vae_plots.py @@ -31,15 +31,18 @@ def plot_llk(train_elbo, test_elbo): import pandas as pd import scipy as sp import seaborn as sns + plt.figure(figsize=(30, 10)) sns.set_style("whitegrid") - data = np.concatenate([np.arange(len(test_elbo))[:, sp.newaxis], -test_elbo[:, sp.newaxis]], axis=1) - df = pd.DataFrame(data=data, columns=['Training Epoch', 'Test ELBO']) + data = np.concatenate( + [np.arange(len(test_elbo))[:, sp.newaxis], -test_elbo[:, sp.newaxis]], axis=1 + ) + df = pd.DataFrame(data=data, columns=["Training Epoch", "Test ELBO"]) g = sns.FacetGrid(df, size=10, aspect=1.5) g.map(plt.scatter, "Training Epoch", "Test ELBO") g.map(plt.plot, "Training Epoch", "Test ELBO") - plt.savefig('./vae_results/test_elbo_vae.png') - plt.close('all') + plt.savefig("./vae_results/test_elbo_vae.png") + plt.close("all") def plot_vae_samples(vae, visdom_session): @@ -59,7 +62,7 @@ def mnist_test_tsne(vae=None, test_loader=None): """ This is used to generate a t-sne embedding of the vae """ - name = 'VAE' + name = "VAE" data = test_loader.dataset.test_data.float() mnist_labels = test_loader.dataset.test_labels z_loc, z_scale = vae.encoder(data) @@ -71,7 +74,7 @@ def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None): This is used to generate a t-sne embedding of the ss-vae """ if name is None: - name = 'SS-VAE' + name = "SS-VAE" data = test_loader.dataset.test_data.float() mnist_labels = test_loader.dataset.test_labels z_loc, z_scale = ssvae.encoder_z([data, mnist_labels]) @@ -80,10 +83,12 @@ def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None): def plot_tsne(z_loc, classes, name): import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np from sklearn.manifold import TSNE + model_tsne = TSNE(n_components=2, random_state=0) z_states = z_loc.detach().cpu().numpy() z_embed = model_tsne.fit_transform(z_states) @@ -96,5 +101,5 @@ def plot_tsne(z_loc, classes, name): color = plt.cm.Set1(ic) plt.scatter(z_embed[ind_class, 0], z_embed[ind_class, 1], s=10, color=color) plt.title("Latent Variable T-SNE per Class") - fig.savefig('./vae_results/'+str(name)+'_embedding_'+str(ic)+'.png') - 
fig.savefig('./vae_results/'+str(name)+'_embedding.png') + fig.savefig("./vae_results/" + str(name) + "_embedding_" + str(ic) + ".png") + fig.savefig("./vae_results/" + str(name) + "_embedding.png") diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 3f5f21c71c..cca12d7a1c 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -93,10 +93,11 @@ def model(self, x): # decode the latent code z loc_img = self.decoder.forward(z) # score against actual images (with relaxed Bernoulli values) - pyro.sample("obs", - dist.Bernoulli(loc_img, validate_args=False) - .to_event(1), - obs=x.reshape(-1, 784)) + pyro.sample( + "obs", + dist.Bernoulli(loc_img, validate_args=False).to_event(1), + obs=x.reshape(-1, 784), + ) # return the loc so we can visualize it later return loc_img @@ -127,7 +128,9 @@ def main(args): # setup MNIST data loaders # train_loader, test_loader - train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda, batch_size=256) + train_loader, test_loader = setup_data_loaders( + MNIST, use_cuda=args.cuda, batch_size=256 + ) # setup the VAE vae = VAE(use_cuda=args.cuda) @@ -149,7 +152,7 @@ def main(args): # training loop for epoch in range(args.num_epochs): # initialize loss accumulator - epoch_loss = 0. + epoch_loss = 0.0 # do a training epoch over each mini-batch x returned # by the data loader for x, _ in train_loader: @@ -163,11 +166,14 @@ def main(args): normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) - print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) + print( + "[epoch %03d] average training loss: %.4f" + % (epoch, total_epoch_loss_train) + ) if epoch % args.test_frequency == 0: # initialize loss accumulator - test_loss = 0. 
+ test_loss = 0.0 # compute the loss over the entire test set for i, (x, _) in enumerate(test_loader): # if on GPU put mini-batch into CUDA memory @@ -185,16 +191,22 @@ def main(args): for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_img(test_img) - vis.image(test_img.reshape(28, 28).detach().cpu().numpy(), - opts={'caption': 'test image'}) - vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(), - opts={'caption': 'reconstructed image'}) + vis.image( + test_img.reshape(28, 28).detach().cpu().numpy(), + opts={"caption": "test image"}, + ) + vis.image( + reco_img.reshape(28, 28).detach().cpu().numpy(), + opts={"caption": "reconstructed image"}, + ) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) - print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) + print( + "[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test) + ) if epoch == args.tsne_iter: mnist_test_tsne(vae=vae, test_loader=test_loader) @@ -203,17 +215,42 @@ def main(args): return vae -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") # parse command line arguments parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-epochs', default=101, type=int, help='number of training epochs') - parser.add_argument('-tf', '--test-frequency', default=5, type=int, help='how often we evaluate the test set') - parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate') - parser.add_argument('--cuda', action='store_true', default=False, help='whether to use cuda') - parser.add_argument('--jit', action='store_true', default=False, help='whether to use PyTorch jit') - parser.add_argument('-visdom', '--visdom_flag', action="store_true", help='Whether plotting in visdom is desired') - parser.add_argument('-i-tsne', '--tsne_iter', default=100, type=int, help='epoch when tsne visualization runs') + parser.add_argument( + "-n", "--num-epochs", default=101, type=int, help="number of training epochs" + ) + parser.add_argument( + "-tf", + "--test-frequency", + default=5, + type=int, + help="how often we evaluate the test set", + ) + parser.add_argument( + "-lr", "--learning-rate", default=1.0e-3, type=float, help="learning rate" + ) + parser.add_argument( + "--cuda", action="store_true", default=False, help="whether to use cuda" + ) + parser.add_argument( + "--jit", action="store_true", default=False, help="whether to use PyTorch jit" + ) + parser.add_argument( + "-visdom", + "--visdom_flag", + action="store_true", + help="Whether plotting in visdom is desired", + ) + parser.add_argument( + "-i-tsne", + "--tsne_iter", + default=100, + type=int, + help="epoch when tsne visualization runs", + ) args = parser.parse_args() model = main(args) diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index f97712ef71..88f96e2360 100644 --- a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -26,8 +26,8 @@ Source: https://github.com/pytorch/examples/tree/master/vae """ -TRAIN = 'train' -TEST = 'test' +TRAIN = "train" +TEST = "test" OUTPUT_DIR = RESULTS_DIR @@ -115,8 +115,11 @@ def train(self, epoch): for batch_idx, (x, _) in enumerate(self.train_loader): loss = self.compute_loss_and_gradient(x) train_loss += loss - print('====> Epoch: {} \nTraining loss: {:.4f}'.format( - 
epoch, train_loss / len(self.train_loader.dataset))) + print( + "====> Epoch: {} \nTraining loss: {:.4f}".format( + epoch, train_loss / len(self.train_loader.dataset) + ) + ) def test(self, epoch): self.set_train(is_train=False) @@ -127,14 +130,17 @@ def test(self, epoch): test_loss += self.compute_loss_and_gradient(x) if i == 0: n = min(x.size(0), 8) - comparison = torch.cat([x[:n], - recon_x.reshape(self.args.batch_size, 1, 28, 28)[:n]]) - save_image(comparison.detach().cpu(), - os.path.join(OUTPUT_DIR, 'reconstruction_' + str(epoch) + '.png'), - nrow=n) + comparison = torch.cat( + [x[:n], recon_x.reshape(self.args.batch_size, 1, 28, 28)[:n]] + ) + save_image( + comparison.detach().cpu(), + os.path.join(OUTPUT_DIR, "reconstruction_" + str(epoch) + ".png"), + nrow=n, + ) test_loss /= len(self.test_loader.dataset) - print('Test set loss: {:.4f}'.format(test_loss)) + print("Test set loss: {:.4f}".format(test_loss)) class PyTorchVAEImpl(VAE): @@ -150,7 +156,9 @@ def __init__(self, *args, **kwargs): def compute_loss_and_gradient(self, x): self.optimizer.zero_grad() recon_x, z_mean, z_var = self.model_eval(x) - binary_cross_entropy = functional.binary_cross_entropy(recon_x, x.reshape(-1, 784)) + binary_cross_entropy = functional.binary_cross_entropy( + recon_x, x.reshape(-1, 784) + ) # Uses analytical KL divergence expression for D_kl(q(z|x) || p(z)) # Refer to Appendix B from VAE paper: # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 @@ -164,7 +172,9 @@ def compute_loss_and_gradient(self, x): return loss.item() def initialize_optimizer(self, lr=1e-3): - model_params = itertools.chain(self.vae_encoder.parameters(), self.vae_decoder.parameters()) + model_params = itertools.chain( + self.vae_encoder.parameters(), self.vae_decoder.parameters() + ) return torch.optim.Adam(model_params, lr) @@ -180,20 +190,22 @@ def __init__(self, *args, **kwargs): self.optimizer = self.initialize_optimizer(lr=1e-3) def model(self, data): - decoder = pyro.module('decoder', self.vae_decoder) + decoder = pyro.module("decoder", self.vae_decoder) z_mean, z_std = torch.zeros([data.size(0), 20]), torch.ones([data.size(0), 20]) - with pyro.plate('data', data.size(0)): - z = pyro.sample('latent', Normal(z_mean, z_std).to_event(1)) + with pyro.plate("data", data.size(0)): + z = pyro.sample("latent", Normal(z_mean, z_std).to_event(1)) img = decoder.forward(z) - pyro.sample('obs', - Bernoulli(img, validate_args=False).to_event(1), - obs=data.reshape(-1, 784)) + pyro.sample( + "obs", + Bernoulli(img, validate_args=False).to_event(1), + obs=data.reshape(-1, 784), + ) def guide(self, data): - encoder = pyro.module('encoder', self.vae_encoder) - with pyro.plate('data', data.size(0)): + encoder = pyro.module("encoder", self.vae_encoder) + with pyro.plate("data", data.size(0)): z_mean, z_var = encoder.forward(data) - pyro.sample('latent', Normal(z_mean, z_var.sqrt()).to_event(1)) + pyro.sample("latent", Normal(z_mean, z_var.sqrt()).to_event(1)) def compute_loss_and_gradient(self, x): if self.mode == TRAIN: @@ -204,23 +216,27 @@ def compute_loss_and_gradient(self, x): return loss def initialize_optimizer(self, lr): - optimizer = Adam({'lr': lr}) + optimizer = Adam({"lr": lr}) elbo = JitTrace_ELBO() if self.args.jit else Trace_ELBO() return SVI(self.model, self.guide, optimizer, loss=elbo) def setup(args): pyro.set_rng_seed(args.rng_seed) - train_loader = util.get_data_loader(dataset_name='MNIST', - data_dir=DATA_DIR, - batch_size=args.batch_size, - is_training_set=True, - shuffle=True) - test_loader = 
util.get_data_loader(dataset_name='MNIST', - data_dir=DATA_DIR, - batch_size=args.batch_size, - is_training_set=False, - shuffle=True) + train_loader = util.get_data_loader( + dataset_name="MNIST", + data_dir=DATA_DIR, + batch_size=args.batch_size, + is_training_set=True, + shuffle=True, + ) + test_loader = util.get_data_loader( + dataset_name="MNIST", + data_dir=DATA_DIR, + batch_size=args.batch_size, + is_training_set=False, + shuffle=True, + ) global OUTPUT_DIR OUTPUT_DIR = os.path.join(RESULTS_DIR, args.impl) if not os.path.exists(OUTPUT_DIR): @@ -231,29 +247,29 @@ def setup(args): def main(args): train_loader, test_loader = setup(args) - if args.impl == 'pyro': + if args.impl == "pyro": vae = PyroVAEImpl(args, train_loader, test_loader) - print('Running Pyro VAE implementation') - elif args.impl == 'pytorch': + print("Running Pyro VAE implementation") + elif args.impl == "pytorch": vae = PyTorchVAEImpl(args, train_loader, test_loader) - print('Running PyTorch VAE implementation') + print("Running PyTorch VAE implementation") else: - raise ValueError('Incorrect implementation specified: {}'.format(args.impl)) + raise ValueError("Incorrect implementation specified: {}".format(args.impl)) for i in range(args.num_epochs): vae.train(i) if not args.skip_eval: vae.test(i) -if __name__ == '__main__': - assert pyro.__version__.startswith('1.6.0') - parser = argparse.ArgumentParser(description='VAE using MNIST dataset') - parser.add_argument('-n', '--num-epochs', nargs='?', default=10, type=int) - parser.add_argument('--batch_size', nargs='?', default=128, type=int) - parser.add_argument('--rng_seed', nargs='?', default=0, type=int) - parser.add_argument('--impl', nargs='?', default='pyro', type=str) - parser.add_argument('--skip_eval', action='store_true') - parser.add_argument('--jit', action='store_true') +if __name__ == "__main__": + assert pyro.__version__.startswith("1.6.0") + parser = argparse.ArgumentParser(description="VAE using MNIST dataset") + parser.add_argument("-n", "--num-epochs", nargs="?", default=10, type=int) + parser.add_argument("--batch_size", nargs="?", default=128, type=int) + parser.add_argument("--rng_seed", nargs="?", default=0, type=int) + parser.add_argument("--impl", nargs="?", default="pyro", type=str) + parser.add_argument("--skip_eval", action="store_true") + parser.add_argument("--jit", action="store_true") parser.set_defaults(skip_eval=False) args = parser.parse_args() main(args) diff --git a/profiler/distributions.py b/profiler/distributions.py index a4f19ed6b1..e739b0b8de 100644 --- a/profiler/distributions.py +++ b/profiler/distributions.py @@ -27,51 +27,39 @@ def T(arr): return Variable(torch.DoubleTensor(arr)) -TOOL = 'timeit' +TOOL = "timeit" TOOL_CFG = {} DISTRIBUTIONS = { - 'Bernoulli': (Bernoulli, { - 'probs': T([0.3, 0.3, 0.3, 0.3]) - }), - 'Beta': (Beta, { - 'concentration1': T([2.4, 2.4, 2.4, 2.4]), - 'concentration0': T([3.2, 3.2, 3.2, 3.2]) - }), - 'Categorical': (Categorical, { - 'probs': T([0.1, 0.3, 0.4, 0.2]) - }), - 'OneHotCategorical': (OneHotCategorical, { - 'probs': T([0.1, 0.3, 0.4, 0.2]) - }), - 'Dirichlet': (Dirichlet, { - 'concentration': T([2.4, 3, 6, 6]) - }), - 'Normal': (Normal, { - 'loc': T([0.5, 0.5, 0.5, 0.5]), - 'scale': T([1.2, 1.2, 1.2, 1.2]) - }), - 'LogNormal': (LogNormal, { - 'loc': T([0.5, 0.5, 0.5, 0.5]), - 'scale': T([1.2, 1.2, 1.2, 1.2]) - }), - 'Cauchy': (Cauchy, { - 'loc': T([0.5, 0.5, 0.5, 0.5]), - 'scale': T([1.2, 1.2, 1.2, 1.2]) - }), - 'Exponential': (Exponential, { - 'rate': T([5.5, 3.2, 4.1, 5.6]) - }), - 
'Poisson': (Poisson, { - 'rate': T([5.5, 3.2, 4.1, 5.6]) - }), - 'Gamma': (Gamma, { - 'concentration': T([2.4, 2.4, 2.4, 2.4]), - 'rate': T([3.2, 3.2, 3.2, 3.2]) - }), - 'Uniform': (Uniform, { - 'low': T([0, 0, 0, 0]), - 'high': T([4, 4, 4, 4]) - }) + "Bernoulli": (Bernoulli, {"probs": T([0.3, 0.3, 0.3, 0.3])}), + "Beta": ( + Beta, + { + "concentration1": T([2.4, 2.4, 2.4, 2.4]), + "concentration0": T([3.2, 3.2, 3.2, 3.2]), + }, + ), + "Categorical": (Categorical, {"probs": T([0.1, 0.3, 0.4, 0.2])}), + "OneHotCategorical": (OneHotCategorical, {"probs": T([0.1, 0.3, 0.4, 0.2])}), + "Dirichlet": (Dirichlet, {"concentration": T([2.4, 3, 6, 6])}), + "Normal": ( + Normal, + {"loc": T([0.5, 0.5, 0.5, 0.5]), "scale": T([1.2, 1.2, 1.2, 1.2])}, + ), + "LogNormal": ( + LogNormal, + {"loc": T([0.5, 0.5, 0.5, 0.5]), "scale": T([1.2, 1.2, 1.2, 1.2])}, + ), + "Cauchy": ( + Cauchy, + {"loc": T([0.5, 0.5, 0.5, 0.5]), "scale": T([1.2, 1.2, 1.2, 1.2])}, + ), + "Exponential": (Exponential, {"rate": T([5.5, 3.2, 4.1, 5.6])}), + "Poisson": (Poisson, {"rate": T([5.5, 3.2, 4.1, 5.6])}), + "Gamma": ( + Gamma, + {"concentration": T([2.4, 2.4, 2.4, 2.4]), "rate": T([3.2, 3.2, 3.2, 3.2])}, + ), + "Uniform": (Uniform, {"low": T([0, 0, 0, 0]), "high": T([4, 4, 4, 4])}), } @@ -86,7 +74,11 @@ def get_tool_cfg(): @Profile( tool=get_tool, tool_cfg=get_tool_cfg, - fn_id=lambda dist, batch_size, *args, **kwargs: 'sample_' + dist.dist_class.__name__ + '_N=' + str(batch_size)) + fn_id=lambda dist, batch_size, *args, **kwargs: "sample_" + + dist.dist_class.__name__ + + "_N=" + + str(batch_size), +) def sample(dist, batch_size): return dist.sample(sample_shape=(batch_size,)) @@ -94,27 +86,33 @@ def sample(dist, batch_size): @Profile( tool=get_tool, tool_cfg=get_tool_cfg, - fn_id=lambda dist, batch, *args, **kwargs: # - 'log_prob_' + dist.dist_class.__name__ + '_N=' + str(batch.size()[0])) + fn_id=lambda dist, batch, *args, **kwargs: "log_prob_" # + + dist.dist_class.__name__ + + "_N=" + + str(batch.size()[0]), +) def log_prob(dist, batch): return dist.log_prob(batch) def run_with_tool(tool, dists, batch_sizes): column_widths, field_format, template = None, None, None - if tool == 'timeit': + if tool == "timeit": profile_cols = 2 * len(batch_sizes) column_widths = [14] * (profile_cols + 1) - field_format = [None] + ['{:.6f}'] * profile_cols - template = 'column' - elif tool == 'cprofile': + field_format = [None] + ["{:.6f}"] * profile_cols + template = "column" + elif tool == "cprofile": column_widths = [14, 80] - template = 'row' + template = "row" with profile_print(column_widths, field_format, template) as out: column_headers = [] for size in batch_sizes: - column_headers += ['SAMPLE (N=' + str(size) + ')', 'LOG_PROB (N=' + str(size) + ')'] - out.header(['DISTRIBUTION'] + column_headers) + column_headers += [ + "SAMPLE (N=" + str(size) + ")", + "LOG_PROB (N=" + str(size) + ")", + ] + out.header(["DISTRIBUTION"] + column_headers) for dist_name in dists: Dist, params = DISTRIBUTIONS[dist_name] result_row = [dist_name] @@ -130,45 +128,51 @@ def set_tool_cfg(args): global TOOL, TOOL_CFG TOOL = args.tool tool_cfg = {} - if args.tool == 'timeit': + if args.tool == "timeit": repeat = 5 if args.repeat is not None: repeat = args.repeat - tool_cfg = {'repeat': repeat} + tool_cfg = {"repeat": repeat} TOOL_CFG = tool_cfg def main(): - parser = argparse.ArgumentParser(description='Profiling distributions library using various' 'tools.') + parser = argparse.ArgumentParser( + description="Profiling distributions library using various" 
"tools." + ) parser.add_argument( - '--tool', - nargs='?', - default='timeit', - help='Profile using tool. One of following should be specified:' - ' ["timeit", "cprofile"]') + "--tool", + nargs="?", + default="timeit", + help="Profile using tool. One of following should be specified:" + ' ["timeit", "cprofile"]', + ) parser.add_argument( - '--batch_sizes', - nargs='*', + "--batch_sizes", + nargs="*", type=int, - help='Batch size of tensor - max of 4 values allowed. ' - 'Default = [10000, 100000]') + help="Batch size of tensor - max of 4 values allowed. " + "Default = [10000, 100000]", + ) parser.add_argument( - '--dists', - nargs='*', + "--dists", + nargs="*", type=str, - help='Run tests on distributions. One or more of following distributions ' + help="Run tests on distributions. One or more of following distributions " 'are supported: ["bernoulli, "beta", "categorical", "dirichlet", ' '"normal", "lognormal", "halfcauchy", "cauchy", "exponential", ' '"poisson", "one_hot_categorical", "gamma", "uniform"] ' - 'Default - Run profiling on all distributions') + "Default - Run profiling on all distributions", + ) parser.add_argument( - '--repeat', - nargs='?', + "--repeat", + nargs="?", default=5, type=int, help='When profiling using "timeit", the number of repetitions to ' - 'use for the profiled function. default=5. The minimum value ' - 'is reported.') + "use for the profiled function. default=5. The minimum value " + "is reported.", + ) args = parser.parse_args() set_tool_cfg(args) dists = args.dists @@ -182,5 +186,5 @@ def main(): run_with_tool(args.tool, dists, batch_sizes) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/profiler/hmm.py b/profiler/hmm.py index 1825c82c20..770f859d30 100644 --- a/profiler/hmm.py +++ b/profiler/hmm.py @@ -22,8 +22,11 @@ def main(args): configs = [] for model in args.model.split(","): for seed in args.seed.split(","): - config = ["--seed={}".format(seed), "--model={}".format(model), - "--num-steps={}".format(args.num_steps)] + config = [ + "--seed={}".format(seed), + "--model={}".format(model), + "--num-steps={}".format(args.num_steps), + ] if args.cuda: config.append("--cuda") if args.jit: @@ -41,8 +44,10 @@ def main(args): pass for config in configs: with timed() as t: - out = subprocess.check_output((sys.executable, "-O", abspath(join(EXAMPLES_DIR, "hmm.py"))) + config, - encoding="utf-8") + out = subprocess.check_output( + (sys.executable, "-O", abspath(join(EXAMPLES_DIR, "hmm.py"))) + config, + encoding="utf-8", + ) results[config] = t.elapsed if "--jit" in config: matched = re.search(r"time to compile: (\d+\.\d+)", out) @@ -61,8 +66,11 @@ def main(args): print("| Min (sec) | Mean (sec) | Max (sec) | python -O examples/hmm.py ... 
|") print("| -: | -: | -: | - |") for config, times in sorted(grouped.items()): - print("| {:0.1f} | {:0.1f} | {:0.1f} | {} |".format( - min(times), median(times), max(times), " ".join(config))) + print( + "| {:0.1f} | {:0.1f} | {:0.1f} | {} |".format( + min(times), median(times), max(times), " ".join(config) + ) + ) if __name__ == "__main__": diff --git a/profiler/profiling_utils.py b/profiler/profiling_utils.py index aee4f9b564..da083b89ce 100644 --- a/profiler/profiling_utils.py +++ b/profiler/profiling_utils.py @@ -12,20 +12,19 @@ from prettytable import ALL, PrettyTable FILE = os.path.abspath(__file__) -PROF_DIR = os.path.join(os.path.dirname(FILE), 'data') +PROF_DIR = os.path.join(os.path.dirname(FILE), "data") if not os.path.exists(PROF_DIR): os.makedirs(PROF_DIR) class ProfilePrinter: - - def __init__(self, column_widths=None, field_format=None, template='column'): - assert template in ('column', 'row') + def __init__(self, column_widths=None, field_format=None, template="column"): + assert template in ("column", "row") self._template = template self._column_widths = column_widths self._field_format = field_format self._header = None - if template == 'column': + if template == "column": self.table = PrettyTable(header=False, hrules=ALL) else: self.table = PrettyTable(header=False, hrules=ALL) @@ -33,7 +32,10 @@ def __init__(self, column_widths=None, field_format=None, template='column'): def _formatted_values(self, values): if self._field_format is not None: assert len(self._field_format) == len(values) - return [f.format(val) if f else str(val) for f, val in zip(self._field_format, values)] + return [ + f.format(val) if f else str(val) + for f, val in zip(self._field_format, values) + ] return values def _add_using_row_format(self, values): @@ -47,21 +49,21 @@ def _add_using_column_format(self, values): self.table.add_row(formatted_vals) def push(self, values): - if self._template == 'column': + if self._template == "column": self._add_using_column_format(values) else: self._add_using_row_format(values) def header(self, values): self._header = values - if self._template == 'column': + if self._template == "column": field_names = values self.table.add_row(values) else: - field_names = ['KEY', 'VALUE'] + field_names = ["KEY", "VALUE"] self.table.field_names = field_names for i in range(len(field_names)): - self.table.align[field_names[i]] = 'l' + self.table.align[field_names[i]] = "l" if self._column_widths: self.table.max_width[field_names[i]] = self._column_widths[i] @@ -70,7 +72,7 @@ def print(self): @contextmanager -def profile_print(column_widths=None, field_format=None, template='column'): +def profile_print(column_widths=None, field_format=None, template="column"): out_buffer = ProfilePrinter(column_widths, field_format, template) try: yield out_buffer @@ -89,12 +91,11 @@ def profile_cprofile(fn_callable, prof_file): prof.dump_stats(prof_file) prof_stats = StringIO() p = pstats.Stats(prof_file, stream=prof_stats) - p.strip_dirs().sort_stats('cumulative').print_stats(0.5) + p.strip_dirs().sort_stats("cumulative").print_stats(0.5) return ret, prof_stats.getvalue() class Profile: - def __init__(self, tool, tool_cfg, fn_id): self.tool = tool self.tool_cfg = tool_cfg @@ -107,16 +108,17 @@ def _set_decorator_params(self): self.tool_cfg = self.tool_cfg() def __call__(self, fn): - def wrapped_fn(*args, **kwargs): self._set_decorator_params() fn_callable = functools.partial(fn, *args, **kwargs) - if self.tool == 'timeit': + if self.tool == "timeit": return profile_timeit(fn_callable, 
**self.tool_cfg) - elif self.tool == 'cprofile': + elif self.tool == "cprofile": prof_file = os.path.join(PROF_DIR, self.fn_id(*args, **kwargs)) return profile_cprofile(fn_callable, prof_file=prof_file) else: - raise ValueError('Invalid profiling tool specified: {}.'.format(self.tool)) + raise ValueError( + "Invalid profiling tool specified: {}.".format(self.tool) + ) return wrapped_fn diff --git a/pyro/__init__.py b/pyro/__init__.py index e862294757..9cfd2a7d78 100644 --- a/pyro/__init__.py +++ b/pyro/__init__.py @@ -25,7 +25,7 @@ from pyro.util import set_rng_seed # After changing this, run scripts/update_version.py -version_prefix = '1.6.0' +version_prefix = "1.6.0" # Get the __version__ string from the auto-generated _version.py file, if exists. try: diff --git a/pyro/contrib/__init__.py b/pyro/contrib/__init__.py index abd412c6e4..6955329860 100644 --- a/pyro/contrib/__init__.py +++ b/pyro/contrib/__init__.py @@ -36,6 +36,7 @@ import funsor as funsor_ # noqa: F401 from pyro.contrib import funsor + __all__ += ["funsor"] except ImportError: pass diff --git a/pyro/contrib/autoguide.py b/pyro/contrib/autoguide.py index 42c5310e17..820b240782 100644 --- a/pyro/contrib/autoguide.py +++ b/pyro/contrib/autoguide.py @@ -5,6 +5,8 @@ from pyro.infer.autoguide import * # noqa F403 -warnings.warn("pyro.contrib.autoguide has moved to pyro.infer.autoguide. " - "The contrib alias will stop working in Pyro 0.5.", - DeprecationWarning) +warnings.warn( + "pyro.contrib.autoguide has moved to pyro.infer.autoguide. " + "The contrib alias will stop working in Pyro 0.5.", + DeprecationWarning, +) diff --git a/pyro/contrib/autoname/named.py b/pyro/contrib/autoname/named.py index d1a3cee225..64532dad9d 100644 --- a/pyro/contrib/autoname/named.py +++ b/pyro/contrib/autoname/named.py @@ -77,6 +77,7 @@ class Object: not be mutated or removed. Trying to mutate this data structure may result in silent errors. """ + def __init__(self, name): super().__setattr__("_name", name) super().__setattr__("_is_placeholder", True) @@ -91,7 +92,8 @@ def __getattribute__(self, key): name = "{}.{}".format(self, key) value = Object(name) super(Object, value).__setattr__( - "_set_value", lambda value: super(Object, self).__setattr__(key, value)) + "_set_value", lambda value: super(Object, self).__setattr__(key, value) + ) super().__setattr__(key, value) super().__setattr__("_is_placeholder", False) return value @@ -108,7 +110,9 @@ def __setattr__(self, key, value): @functools.wraps(pyro.sample) def sample_(self, fn, *args, **kwargs): if not self._is_placeholder: - raise RuntimeError("Cannot .sample_ an initialized named.Object {}".format(self)) + raise RuntimeError( + "Cannot .sample_ an initialized named.Object {}".format(self) + ) value = pyro.sample(str(self), fn, *args, **kwargs) self._set_value(value) return value @@ -139,6 +143,7 @@ class List(list): not be mutated or removed. Trying to mutate this data structure may result in silent errors. 
""" + def __init__(self, name=None): self._name = name @@ -160,18 +165,25 @@ def add(self): :rtype: named.Object """ if self._name is None: - raise RuntimeError("Cannot .add() to a named.List before storing it in a named.Object") + raise RuntimeError( + "Cannot .add() to a named.List before storing it in a named.Object" + ) i = len(self) value = Object("{}[{}]".format(self._name, i)) super(Object, value).__setattr__( - "_set_value", lambda value, i=i: self.__setitem__(i, value)) + "_set_value", lambda value, i=i: self.__setitem__(i, value) + ) self.append(value) return value def __setitem__(self, pos, value): name = "{}[{}]".format(self._name, pos) if isinstance(value, Object): - raise RuntimeError("Cannot store named.Object {} in named.Dict {}".format(value, self._name)) + raise RuntimeError( + "Cannot store named.Object {} in named.Dict {}".format( + value, self._name + ) + ) elif isinstance(value, (List, Dict)): value._set_name(name) old = self[pos] @@ -197,6 +209,7 @@ class Dict(dict): not be mutated or removed. Trying to mutate this data structure may result in silent errors. """ + def __init__(self, name=None): self._name = name @@ -218,7 +231,8 @@ def __getitem__(self, key): raise RuntimeError("Cannot access an unnamed named.Dict") from e value = Object("{}[{!r}]".format(self._name, key)) super(Object, value).__setattr__( - "_set_value", lambda value: self.__setitem__(key, value)) + "_set_value", lambda value: self.__setitem__(key, value) + ) super().__setitem__(key, value) return value @@ -229,7 +243,11 @@ def __setitem__(self, key, value): if not isinstance(old, Object) or not old._is_placeholder: raise RuntimeError("Cannot overwrite {}".format(name)) if isinstance(value, Object): - raise RuntimeError("Cannot store named.Object {} in named.Dict {}".format(value, self._name)) + raise RuntimeError( + "Cannot store named.Object {} in named.Dict {}".format( + value, self._name + ) + ) elif isinstance(value, (List, Dict)): value._set_name(name) super().__setitem__(key, value) diff --git a/pyro/contrib/autoname/scoping.py b/pyro/contrib/autoname/scoping.py index 06295e4014..b9b746d0ba 100644 --- a/pyro/contrib/autoname/scoping.py +++ b/pyro/contrib/autoname/scoping.py @@ -16,6 +16,7 @@ class NameCountMessenger(Messenger): """ ``NameCountMessenger`` is the implementation of :func:`pyro.contrib.autoname.name_count` """ + def __enter__(self): self._names = set() return super().__enter__() @@ -47,6 +48,7 @@ class ScopeMessenger(Messenger): """ ``ScopeMessenger`` is the implementation of :func:`pyro.contrib.autoname.scope` """ + def __init__(self, prefix=None, inner=None): super().__init__() self.prefix = prefix @@ -76,6 +78,7 @@ def __call__(self, fn): def _fn(*args, **kwargs): with type(self)(prefix=self.prefix, inner=self.inner): return fn(*args, **kwargs) + return _fn def _pyro_scope(self, msg): diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index cc97b051fa..43000e820e 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -57,14 +57,23 @@ class HiddenLayer(TorchDistribution): """ has_rsample = True - def __init__(self, X=None, A_mean=None, A_scale=None, non_linearity=F.relu, - KL_factor=1.0, A_prior_scale=1.0, include_hidden_bias=True, - weight_space_sampling=False): + def __init__( + self, + X=None, + A_mean=None, + A_scale=None, + non_linearity=F.relu, + KL_factor=1.0, + A_prior_scale=1.0, + include_hidden_bias=True, + weight_space_sampling=False, + ): self.X = X self.dim_X = X.size(-1) self.dim_H = A_mean.size(-1) - 
assert A_mean.size(0) == self.dim_X, \ - "The dimensions of X and A_mean and A_scale must match accordingly; see documentation" + assert ( + A_mean.size(0) == self.dim_X + ), "The dimensions of X and A_mean and A_scale must match accordingly; see documentation" self.A_mean = A_mean self.A_scale = A_scale self.non_linearity = non_linearity @@ -91,14 +100,20 @@ def KL(self): def rsample(self, sample_shape=torch.Size()): # note: weight space sampling is only meant for testing if self.weight_space_sampling: - A = self.A_mean + torch.randn(sample_shape + self.A_scale.shape).type_as(self.A_mean) * self.A_scale + A = ( + self.A_mean + + torch.randn(sample_shape + self.A_scale.shape).type_as(self.A_mean) + * self.A_scale + ) activation = torch.matmul(self.X, A) else: _mean = torch.matmul(self.X, self.A_mean) X_sqr = torch.pow(self.X, 2.0).unsqueeze(-1) A_scale_sqr = torch.pow(self.A_scale, 2.0) _std = (X_sqr * A_scale_sqr).sum(-2).sqrt() - activation = _mean + torch.randn(sample_shape + _std.shape).type_as(_std) * _std + activation = ( + _mean + torch.randn(sample_shape + _std.shape).type_as(_std) * _std + ) # apply non-linearity activation = self.non_linearity(activation) diff --git a/pyro/contrib/cevae/__init__.py b/pyro/contrib/cevae/__init__.py index ad3d648b3b..ee5483e582 100644 --- a/pyro/contrib/cevae/__init__.py +++ b/pyro/contrib/cevae/__init__.py @@ -43,6 +43,7 @@ class FullyConnected(nn.Sequential): """ Fully connected multi-layer network with ELU activations. """ + def __init__(self, sizes, final_activation=None): layers = [] for in_size, out_size in zip(sizes, sizes[1:]): @@ -62,6 +63,7 @@ class DistributionNet(nn.Module): """ Base class for distribution nets. """ + @staticmethod def get_class(dtype): """ @@ -88,6 +90,7 @@ class BernoulliNet(DistributionNet): logits, = net(z) t = net.make_dist(logits).sample() """ + def __init__(self, sizes): assert len(sizes) >= 1 super().__init__() @@ -95,7 +98,7 @@ def __init__(self, sizes): def forward(self, x): logits = self.fc(x).squeeze(-1).clamp(min=-10, max=10) - return logits, + return (logits,) @staticmethod def make_dist(logits): @@ -115,6 +118,7 @@ class ExponentialNet(DistributionNet): rate, = net(x) y = net.make_dist(rate).sample() """ + def __init__(self, sizes): assert len(sizes) >= 1 super().__init__() @@ -123,7 +127,7 @@ def __init__(self, sizes): def forward(self, x): scale = nn.functional.softplus(self.fc(x).squeeze(-1)).clamp(min=1e-3, max=1e6) rate = scale.reciprocal() - return rate, + return (rate,) @staticmethod def make_dist(rate): @@ -144,6 +148,7 @@ class LaplaceNet(DistributionNet): loc, scale = net(x) y = net.make_dist(loc, scale).sample() """ + def __init__(self, sizes): assert len(sizes) >= 1 super().__init__() @@ -174,6 +179,7 @@ class NormalNet(DistributionNet): loc, scale = net(x) y = net.make_dist(loc, scale).sample() """ + def __init__(self, sizes): assert len(sizes) >= 1 super().__init__() @@ -204,6 +210,7 @@ class StudentTNet(DistributionNet): df, loc, scale = net(x) y = net.make_dist(df, loc, scale).sample() """ + def __init__(self, sizes): assert len(sizes) >= 1 super().__init__() @@ -239,6 +246,7 @@ class DiagNormalNet(nn.Module): This is intended for the latent ``z`` distribution and the prewhitened ``x`` features, and conservatively clips ``loc`` and ``scale`` values. 
""" + def __init__(self, sizes): assert len(sizes) >= 2 self.dim = sizes[-1] @@ -247,8 +255,10 @@ def __init__(self, sizes): def forward(self, x): loc_scale = self.fc(x) - loc = loc_scale[..., :self.dim].clamp(min=-1e2, max=1e2) - scale = nn.functional.softplus(loc_scale[..., self.dim:]).add(1e-3).clamp(max=1e2) + loc = loc_scale[..., : self.dim].clamp(min=-1e2, max=1e2) + scale = ( + nn.functional.softplus(loc_scale[..., self.dim :]).add(1e-3).clamp(max=1e2) + ) return loc, scale @@ -256,12 +266,13 @@ class PreWhitener(nn.Module): """ Data pre-whitener. """ + def __init__(self, data): super().__init__() with torch.no_grad(): loc = data.mean(0) scale = data.std(0) - scale[~(scale > 0)] = 1. + scale[~(scale > 0)] = 1.0 self.register_buffer("loc", loc) self.register_buffer("inv_scale", scale.reciprocal()) @@ -286,18 +297,23 @@ class Model(PyroModule): :param dict config: A dict specifying ``feature_dim``, ``latent_dim``, ``hidden_dim``, ``num_layers``, and ``outcome_dist``. """ + def __init__(self, config): self.latent_dim = config["latent_dim"] super().__init__() - self.x_nn = DiagNormalNet([config["latent_dim"]] + - [config["hidden_dim"]] * config["num_layers"] + - [config["feature_dim"]]) + self.x_nn = DiagNormalNet( + [config["latent_dim"]] + + [config["hidden_dim"]] * config["num_layers"] + + [config["feature_dim"]] + ) OutcomeNet = DistributionNet.get_class(config["outcome_dist"]) # The y network is split between the two t values. - self.y0_nn = OutcomeNet([config["latent_dim"]] + - [config["hidden_dim"]] * config["num_layers"]) - self.y1_nn = OutcomeNet([config["latent_dim"]] + - [config["hidden_dim"]] * config["num_layers"]) + self.y0_nn = OutcomeNet( + [config["latent_dim"]] + [config["hidden_dim"]] * config["num_layers"] + ) + self.y1_nn = OutcomeNet( + [config["latent_dim"]] + [config["hidden_dim"]] * config["num_layers"] + ) self.t_nn = BernoulliNet([config["latent_dim"]]) def forward(self, x, t=None, y=None, size=None): @@ -333,7 +349,7 @@ def y_dist(self, t, z): return self.y0_nn.make_dist(*params) def t_dist(self, z): - logits, = self.t_nn(z) + (logits,) = self.t_nn(z) return dist.Bernoulli(logits=logits) @@ -354,6 +370,7 @@ class Guide(PyroModule): :param dict config: A dict specifying ``feature_dim``, ``latent_dim``, ``hidden_dim``, ``num_layers``, and ``outcome_dist``. """ + def __init__(self, config): self.latent_dim = config["latent_dim"] OutcomeNet = DistributionNet.get_class(config["outcome_dist"]) @@ -362,14 +379,18 @@ def __init__(self, config): # The y and z networks both follow an architecture where the first few # layers are shared for t in {0,1}, but the final layer is split # between the two t values. 
- self.y_nn = FullyConnected([config["feature_dim"]] + - [config["hidden_dim"]] * (config["num_layers"] - 1), - final_activation=nn.ELU()) + self.y_nn = FullyConnected( + [config["feature_dim"]] + + [config["hidden_dim"]] * (config["num_layers"] - 1), + final_activation=nn.ELU(), + ) self.y0_nn = OutcomeNet([config["hidden_dim"]]) self.y1_nn = OutcomeNet([config["hidden_dim"]]) - self.z_nn = FullyConnected([1 + config["feature_dim"]] + - [config["hidden_dim"]] * (config["num_layers"] - 1), - final_activation=nn.ELU()) + self.z_nn = FullyConnected( + [1 + config["feature_dim"]] + + [config["hidden_dim"]] * (config["num_layers"] - 1), + final_activation=nn.ELU(), + ) self.z0_nn = DiagNormalNet([config["hidden_dim"], config["latent_dim"]]) self.z1_nn = DiagNormalNet([config["hidden_dim"], config["latent_dim"]]) @@ -386,7 +407,7 @@ def forward(self, x, t=None, y=None, size=None): pyro.sample("z", self.z_dist(y, t, x)) def t_dist(self, x): - logits, = self.t_nn(x) + (logits,) = self.t_nn(x) return dist.Bernoulli(logits=logits) def y_dist(self, t, x): @@ -418,15 +439,20 @@ class TraceCausalEffect_ELBO(Trace_ELBO): -loss = ELBO + log q(t|x) + log q(y|t,x) """ + def _differentiable_loss_particle(self, model_trace, guide_trace): # Construct -ELBO part. - blocked_names = [name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and site["is_observed"]] + blocked_names = [ + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and site["is_observed"] + ] blocked_guide_trace = guide_trace.copy() for name in blocked_names: del blocked_guide_trace.nodes[name] loss, surrogate_loss = super()._differentiable_loss_particle( - model_trace, blocked_guide_trace) + model_trace, blocked_guide_trace + ) # Add log q terms. for name in blocked_names: @@ -482,11 +508,23 @@ class CEVAE(nn.Module): :param int num_samples: Default number of samples for the :meth:`ite` method. Defaults to 100. """ - def __init__(self, feature_dim, outcome_dist="bernoulli", - latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100): - config = dict(feature_dim=feature_dim, latent_dim=latent_dim, - hidden_dim=hidden_dim, num_layers=num_layers, - num_samples=num_samples) + + def __init__( + self, + feature_dim, + outcome_dist="bernoulli", + latent_dim=20, + hidden_dim=200, + num_layers=3, + num_samples=100, + ): + config = dict( + feature_dim=feature_dim, + latent_dim=latent_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + num_samples=num_samples, + ) for name, size in config.items(): if not (isinstance(size, int) and size > 0): raise ValueError("Expected {} > 0 but got {}".format(name, size)) @@ -498,13 +536,18 @@ def __init__(self, feature_dim, outcome_dist="bernoulli", self.model = Model(config) self.guide = Guide(config) - def fit(self, x, t, y, - num_epochs=100, - batch_size=100, - learning_rate=1e-3, - learning_rate_decay=0.1, - weight_decay=1e-4, - log_every=100): + def fit( + self, + x, + t, + y, + num_epochs=100, + batch_size=100, + learning_rate=1e-3, + learning_rate_decay=0.1, + weight_decay=1e-4, + log_every=100, + ): """ Train using :class:`~pyro.infer.svi.SVI` with the :class:`TraceCausalEffect_ELBO` loss. 
@@ -534,9 +577,13 @@ def fit(self, x, t, y, dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) logger.info("Training with {} minibatches per epoch".format(len(dataloader))) num_steps = num_epochs * len(dataloader) - optim = ClippedAdam({"lr": learning_rate, - "weight_decay": weight_decay, - "lrd": learning_rate_decay ** (1 / num_steps)}) + optim = ClippedAdam( + { + "lr": learning_rate, + "weight_decay": weight_decay, + "lrd": learning_rate_decay ** (1 / num_steps), + } + ) svi = SVI(self.model, self.guide, optim, TraceCausalEffect_ELBO()) losses = [] for epoch in range(num_epochs): @@ -544,7 +591,9 @@ def fit(self, x, t, y, x = self.whiten(x) loss = svi.step(x, t, y, size=len(dataset)) / len(dataset) if log_every and len(losses) % log_every == 0: - logger.debug("step {: >5d} loss = {:0.6g}".format(len(losses), loss)) + logger.debug( + "step {: >5d} loss = {:0.6g}".format(len(losses), loss) + ) assert not torch_isnan(loss) losses.append(loss) return losses diff --git a/pyro/contrib/conjugate/infer.py b/pyro/contrib/conjugate/infer.py index 0c815c0126..e1332eb652 100644 --- a/pyro/contrib/conjugate/infer.py +++ b/pyro/contrib/conjugate/infer.py @@ -26,6 +26,7 @@ def _make_cls(base, static_attrs, instance_attrs, parent_linkage=None): a reference to the distribution class. :return cls: dynamically generated class. """ + def _expand(self, batch_shape, _instance=None): new = self._get_checked_instance(cls, _instance) for attr in instance_attrs: @@ -43,11 +44,15 @@ def _expand(self, batch_shape, _instance=None): def _latent(base, parent): - return _make_cls(base, {"collapsible": True}, {"site_name": None, "parent": parent}, "_latent") + return _make_cls( + base, {"collapsible": True}, {"site_name": None, "parent": parent}, "_latent" + ) def _conditional(base, parent): - return _make_cls(base, {"marginalize_latent": True}, {"parent": parent}, "_conditional") + return _make_cls( + base, {"marginalize_latent": True}, {"parent": parent}, "_conditional" + ) def _compound(base, parent): @@ -76,14 +81,18 @@ def posterior(self, obs): # Raise exception if this isn't possible. total_count = sum_leftmost(total_count, reduce_dims) summed_obs = sum_leftmost(obs, reduce_dims) - return dist.Beta(concentration1 + summed_obs, - total_count + concentration0 - summed_obs, - validate_args=self._latent._validate_args) + return dist.Beta( + concentration1 + summed_obs, + total_count + concentration0 - summed_obs, + validate_args=self._latent._validate_args, + ) def compound(self): - return _compound(dist.BetaBinomial, parent=self)(concentration1=self._latent.concentration1, - concentration0=self._latent.concentration0, - total_count=self._conditional.total_count) + return _compound(dist.BetaBinomial, parent=self)( + concentration1=self._latent.concentration1, + concentration0=self._latent.concentration0, + total_count=self._conditional.total_count, + ) class GammaPoissonPair: @@ -108,8 +117,9 @@ def posterior(self, obs): return dist.Gamma(concentration + summed_obs, rate + num_obs) def compound(self): - return _compound(dist.GammaPoisson, parent=self)(concentration=self._latent.concentration, - rate=self._latent.rate) + return _compound(dist.GammaPoisson, parent=self)( + concentration=self._latent.concentration, rate=self._latent.rate + ) class UncollapseConjugateMessenger(Messenger): @@ -117,6 +127,7 @@ class UncollapseConjugateMessenger(Messenger): Replay regular sample sites in addition to uncollapsing any collapsed conjugate sites. 
""" + def __init__(self, trace): """ :param trace: a trace whose values should be reused @@ -137,8 +148,11 @@ def _pyro_sample(self, msg): if parent is not None and parent._latent.site_name == msg["name"]: conj_node = self.trace.nodes[site_name] break - assert conj_node is not None, "Collapsible latent site `{}` with no corresponding conjugate site."\ - .format(msg["name"]) + assert ( + conj_node is not None + ), "Collapsible latent site `{}` with no corresponding conjugate site.".format( + msg["name"] + ) msg["fn"] = parent.posterior(conj_node["value"]) msg["value"] = msg["fn"].sample() # regular replay behavior. @@ -208,16 +222,21 @@ def posterior_replay(model, posterior_samples, *args, **kwargs): """ posterior_samples = posterior_samples.copy() num_samples = kwargs.pop("num_samples", None) - assert posterior_samples or num_samples, "`num_samples` must be provided if `posterior_samples` is empty." + assert ( + posterior_samples or num_samples + ), "`num_samples` must be provided if `posterior_samples` is empty." if num_samples is None: num_samples = list(posterior_samples.values())[0].shape[0] return_samples = defaultdict(list) for i in range(num_samples): conditioned_nodes = {k: v[i] for k, v in posterior_samples.items()} - collapsed_trace = poutine.trace(poutine.condition(collapse_conjugate(model), conditioned_nodes))\ - .get_trace(*args, **kwargs) - trace = poutine.trace(uncollapse_conjugate(model, collapsed_trace)).get_trace(*args, **kwargs) + collapsed_trace = poutine.trace( + poutine.condition(collapse_conjugate(model), conditioned_nodes) + ).get_trace(*args, **kwargs) + trace = poutine.trace(uncollapse_conjugate(model, collapsed_trace)).get_trace( + *args, **kwargs + ) for name, site in trace.iter_stochastic_nodes(): if not site_is_subsample(site): return_samples[name].append(site["value"]) diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py index 55535ae72d..13b3718ad4 100644 --- a/pyro/contrib/easyguide/easyguide.py +++ b/pyro/contrib/easyguide/easyguide.py @@ -42,6 +42,7 @@ class EasyGuide(PyroModule, metaclass=_EasyGuideMeta): :param callable model: A Pyro model. """ + def __init__(self, model): super().__init__() self._pyro_name = type(self).__name__ @@ -58,12 +59,16 @@ def model(self): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure model = poutine.block(InitMessenger(self.init)(self.model), prototype_hide_fn) - self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs) + self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( + *args, **kwargs + ) for name, site in self.prototype_trace.iter_stochastic_nodes(): for frame in site["cond_indep_stack"]: if not frame.vectorized: - raise NotImplementedError("EasyGuide does not support sequential pyro.plate") + raise NotImplementedError( + "EasyGuide does not support sequential pyro.plate" + ) self.frames[frame.name] = frame @abstractmethod @@ -100,14 +105,18 @@ def forward(self, *args, **kwargs): self.plates.clear() return result - def plate(self, name, size=None, subsample_size=None, subsample=None, *args, **kwargs): + def plate( + self, name, size=None, subsample_size=None, subsample=None, *args, **kwargs + ): """ A wrapper around :class:`pyro.plate` to allow `EasyGuide` to automatically construct plates. You should use this rather than :class:`pyro.plate` inside your :meth:`guide` implementation. 
""" if name not in self.plates: - self.plates[name] = pyro.plate(name, size, subsample_size, subsample, *args, **kwargs) + self.plates[name] = pyro.plate( + name, size, subsample_size, subsample, *args, **kwargs + ) return self.plates[name] def group(self, match=".*"): @@ -119,12 +128,17 @@ def group(self, match=".*"): :rtype: Group """ if match not in self.groups: - sites = [site - for name, site in self.prototype_trace.iter_stochastic_nodes() - if re.match(match, name)] + sites = [ + site + for name, site in self.prototype_trace.iter_stochastic_nodes() + if re.match(match, name) + ] if not sites: - raise ValueError("EasyGuide.group() pattern {} matched no model sites" - .format(repr(match))) + raise ValueError( + "EasyGuide.group() pattern {} matched no model sites".format( + repr(match) + ) + ) self.groups[match] = Group(self, sites) return self.groups[match] @@ -171,6 +185,7 @@ class Group: :param EasyGuide guide: An easyguide instance. :param list sites: A list of model sites. """ + def __init__(self, guide, sites): assert isinstance(sites, list) assert sites @@ -181,10 +196,13 @@ def __init__(self, guide, sites): # A group is in a frame only if all its sample sites are in that frame. # Thus a group can be subsampled only if all its sites can be subsampled. - self.common_frames = frozenset.intersection(*( - frozenset(f for f in site["cond_indep_stack"] if f.vectorized) - for site in sites)) - rightmost_common_dim = -float('inf') + self.common_frames = frozenset.intersection( + *( + frozenset(f for f in site["cond_indep_stack"] if f.vectorized) + for site in sites + ) + ) + rightmost_common_dim = -float("inf") if self.common_frames: rightmost_common_dim = max(f.dim for f in self.common_frames) @@ -202,8 +220,10 @@ def __init__(self, guide, sites): if len(site_batch_shape) > -rightmost_common_dim: raise ValueError( "Group expects all per-site plates to be right of all common plates, " - "but found a per-site plate {} on left at site {}" - .format(-len(site_batch_shape), repr(site["name"]))) + "but found a per-site plate {} on left at site {}".format( + -len(site_batch_shape), repr(site["name"]) + ) + ) site_batch_shape = torch.Size(site_batch_shape) self._site_batch_shapes[site["name"]] = site_batch_shape self._site_sizes[site["name"]] = site_batch_shape.numel() * site_event_numel @@ -211,7 +231,7 @@ def __init__(self, guide, sites): def __getstate__(self): state = self.__dict__.copy() - state['_guide'] = state['_guide']() # weakref -> ref + state["_guide"] = state["_guide"]() # weakref -> ref return state def __setstate__(self, state): @@ -237,8 +257,11 @@ def sample(self, guide_name, fn, infer=None): """ # Sample a packed tensor. if fn.event_shape != self.event_shape: - raise ValueError("Invalid fn.event_shape for group: expected {}, actual {}" - .format(tuple(self.event_shape), tuple(fn.event_shape))) + raise ValueError( + "Invalid fn.event_shape for group: expected {}, actual {}".format( + tuple(self.event_shape), tuple(fn.event_shape) + ) + ) if infer is None: infer = {} infer["is_auxiliary"] = True @@ -253,8 +276,10 @@ def sample(self, guide_name, fn, infer=None): # Extract slice from packed sample. 
size = self._site_sizes[name] - batch_shape = broadcast_shape(common_batch_shape, self._site_batch_shapes[name]) - unconstrained_z = guide_z[..., pos: pos + size] + batch_shape = broadcast_shape( + common_batch_shape, self._site_batch_shapes[name] + ) + unconstrained_z = guide_z[..., pos : pos + size] unconstrained_z = unconstrained_z.reshape(batch_shape + fn.event_shape) pos += size @@ -262,7 +287,9 @@ def sample(self, guide_name, fn, infer=None): transform = biject_to(fn.support) z = transform(unconstrained_z) log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z) - log_density = sum_rightmost(log_density, log_density.dim() - z.dim() + fn.event_dim) + log_density = sum_rightmost( + log_density, log_density.dim() - z.dim() + fn.event_dim + ) delta_dist = dist.Delta(z, log_density=log_density, event_dim=fn.event_dim) # Replay model sample statement. @@ -282,8 +309,10 @@ def map_estimate(self): :return: A dict mapping model site name to sampled value. :rtype: dict """ - return {site["name"]: self.guide.map_estimate(site["name"]) - for site in self.prototype_sites} + return { + site["name"]: self.guide.map_estimate(site["name"]) + for site in self.prototype_sites + } def easy_guide(model): diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 865fe693a1..6531f00e3a 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -56,9 +56,11 @@ def _require_double_precision(): if torch.get_default_dtype() != torch.float64: - warnings.warn("CompartmentalModel is unstable for dtypes less than torch.float64; " - "try torch.set_default_dtype(torch.float64)", - RuntimeWarning) + warnings.warn( + "CompartmentalModel is unstable for dtypes less than torch.float64; " + "try torch.set_default_dtype(torch.float64)", + RuntimeWarning, + ) @contextmanager @@ -71,8 +73,9 @@ def _disallow_latent_variables(section_name): yield for name, site in tr.trace.nodes.items(): if site["type"] == "sample" and not site["is_observed"]: - raise NotImplementedError("{} contained latent variable {}" - .format(section_name, name)) + raise NotImplementedError( + "{} contained latent variable {}".format(section_name, name) + ) class CompartmentalModel(ABC): @@ -144,8 +147,7 @@ def transition(self, params, state, t): ... are continuous-valued with support ``(-0.5, population + 0.5)``. """ - def __init__(self, compartments, duration, population, *, - approximate=()): + def __init__(self, compartments, duration, population, *, approximate=()): super().__init__() assert isinstance(duration, int) @@ -183,8 +185,9 @@ def time_plate(self): A ``pyro.plate`` for the time dimension. 
""" if self._time_plate is None: - self._time_plate = pyro.plate("time", self.duration, - dim=-2 if self.is_regional else -1) + self._time_plate = pyro.plate( + "time", self.duration, dim=-2 if self.is_regional else -1 + ) return self._time_plate @property @@ -211,8 +214,13 @@ def full_mass(self): """ with torch.no_grad(), poutine.block(), poutine.trace() as tr: self.global_model() - return [tuple(name for name, site in tr.trace.iter_stochastic_nodes() - if not site_is_subsample(site))] + return [ + tuple( + name + for name, site in tr.trace.iter_stochastic_nodes() + if not site_is_subsample(site) + ) + ] @lazy_property def series(self): @@ -228,10 +236,12 @@ def series(self): curr = prev.copy() with poutine.trace() as tr: self.transition(params, curr, 0) - return frozenset(re.match("(.*)_0", name).group(1) - for name, site in tr.trace.nodes.items() - if site["type"] == "sample" - if not site_is_subsample(site)) + return frozenset( + re.match("(.*)_0", name).group(1) + for name, site in tr.trace.nodes.items() + if site["type"] == "sample" + if not site_is_subsample(site) + ) # Overridable attributes and methods ######################################## @@ -362,26 +372,31 @@ def generate(self, fixed={}): model = self._generative_model model = poutine.condition(model, fixed) trace = poutine.trace(model).get_trace() - samples = OrderedDict((name, site["value"]) - for name, site in trace.nodes.items() - if site["type"] == "sample") + samples = OrderedDict( + (name, site["value"]) + for name, site in trace.nodes.items() + if site["type"] == "sample" + ) self._concat_series(samples, trace) return samples - def fit_svi(self, *, - num_samples=100, - num_steps=2000, - num_particles=32, - learning_rate=0.1, - learning_rate_decay=0.01, - betas=(0.8, 0.99), - haar=True, - init_scale=0.01, - guide_rank=0, - jit=False, - log_every=200, - **options): + def fit_svi( + self, + *, + num_samples=100, + num_steps=2000, + num_particles=32, + learning_rate=0.1, + learning_rate_decay=0.01, + betas=(0.8, 0.99), + haar=True, + init_scale=0.01, + guide_rank=0, + jit=False, + log_every=200, + **options, + ): """ Runs stochastic variational inference to generate posterior samples. @@ -435,9 +450,11 @@ def fit_svi(self, *, haar = _HaarSplitReparam(0, self.duration, dims, supports) # Heuristically initialize to feasible latents. 
- heuristic_options = {k.replace("heuristic_", ""): options.pop(k) - for k in list(options) - if k.startswith("heuristic_")} + heuristic_options = { + k.replace("heuristic_", ""): options.pop(k) + for k in list(options) + if k.startswith("heuristic_") + } assert not options, "unrecognized options: {}".format(", ".join(options)) init_strategy = self._heuristic(haar, **heuristic_options) @@ -449,19 +466,29 @@ def fit_svi(self, *, if guide_rank == 0: guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale) elif guide_rank == "full": - guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy, - init_scale=init_scale) + guide = AutoMultivariateNormal( + model, init_loc_fn=init_strategy, init_scale=init_scale + ) elif guide_rank is None or isinstance(guide_rank, int): - guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy, - init_scale=init_scale, rank=guide_rank) + guide = AutoLowRankMultivariateNormal( + model, init_loc_fn=init_strategy, init_scale=init_scale, rank=guide_rank + ) else: raise ValueError("Invalid guide_rank: {}".format(guide_rank)) Elbo = JitTrace_ELBO if jit else Trace_ELBO - elbo = Elbo(max_plate_nesting=self.max_plate_nesting, - num_particles=num_particles, vectorize_particles=True, - ignore_jit_warnings=True) - optim = ClippedAdam({"lr": learning_rate, "betas": betas, - "lrd": learning_rate_decay ** (1 / num_steps)}) + elbo = Elbo( + max_plate_nesting=self.max_plate_nesting, + num_particles=num_particles, + vectorize_particles=True, + ignore_jit_warnings=True, + ) + optim = ClippedAdam( + { + "lr": learning_rate, + "betas": betas, + "lrd": learning_rate_decay ** (1 / num_steps), + } + ) svi = SVI(model, guide, optim, elbo) # Run inference. @@ -473,24 +500,33 @@ def fit_svi(self, *, logger.info("step {} loss = {:0.4g}".format(step, loss)) losses.append(loss) elapsed = default_timer() - start_time - logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec" - .format(elapsed, (1 + num_steps) / elapsed)) + logger.info( + "SVI took {:0.1f} seconds, {:0.1f} step/sec".format( + elapsed, (1 + num_steps) / elapsed + ) + ) # Draw posterior samples. with torch.no_grad(): - particle_plate = pyro.plate("particles", num_samples, - dim=-1 - self.max_plate_nesting) + particle_plate = pyro.plate( + "particles", num_samples, dim=-1 - self.max_plate_nesting + ) guide_trace = poutine.trace(particle_plate(guide)).get_trace() model_trace = poutine.trace( - poutine.replay(particle_plate(model), guide_trace)).get_trace() - self.samples = {name: site["value"] for name, site in model_trace.nodes.items() - if site["type"] == "sample" - if not site["is_observed"] - if not site_is_subsample(site)} + poutine.replay(particle_plate(model), guide_trace) + ).get_trace() + self.samples = { + name: site["value"] + for name, site in model_trace.nodes.items() + if site["type"] == "sample" + if not site["is_observed"] + if not site_is_subsample(site) + } if haar: haar.aux_to_user(self.samples) - assert all(v.size(0) == num_samples for v in self.samples.values()), \ - {k: tuple(v.shape) for k, v in self.samples.items()} + assert all(v.size(0) == num_samples for v in self.samples.values()), { + k: tuple(v.shape) for k, v in self.samples.items() + } return losses @@ -573,26 +609,31 @@ def fit_mcmc(self, **options): full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims)) # Heuristically initialize to feasible latents. 
- heuristic_options = {k.replace("heuristic_", ""): options.pop(k) - for k in list(options) - if k.startswith("heuristic_")} + heuristic_options = { + k.replace("heuristic_", ""): options.pop(k) + for k in list(options) + if k.startswith("heuristic_") + } init_strategy = init_to_generated( - generate=functools.partial(self._heuristic, haar, **heuristic_options)) + generate=functools.partial(self._heuristic, haar, **heuristic_options) + ) # Configure a kernel. logger.info("Running inference...") model = self._relaxed_model if self.relaxed else self._quantized_model if haar: model = haar.reparam(model) - kernel = NUTS(model, - full_mass=full_mass, - init_strategy=init_strategy, - max_plate_nesting=self.max_plate_nesting, - jit_compile=options.pop("jit_compile", False), - jit_options=options.pop("jit_options", None), - ignore_jit_warnings=options.pop("ignore_jit_warnings", True), - target_accept_prob=options.pop("target_accept_prob", 0.8), - max_tree_depth=options.pop("max_tree_depth", 5)) + kernel = NUTS( + model, + full_mass=full_mass, + init_strategy=init_strategy, + max_plate_nesting=self.max_plate_nesting, + jit_compile=options.pop("jit_compile", False), + jit_options=options.pop("jit_options", None), + ignore_jit_warnings=options.pop("ignore_jit_warnings", True), + target_accept_prob=options.pop("target_accept_prob", 0.8), + max_tree_depth=options.pop("max_tree_depth", 5), + ) if options.pop("arrowhead_mass", False): kernel.mass_matrix_adapter = ArrowheadMassMatrix() @@ -607,10 +648,12 @@ def fit_mcmc(self, **options): # Unsqueeze samples to align particle dim for use in poutine.condition. # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples(). model = self._relaxed_model if self.relaxed else self._quantized_model - self.samples = align_samples(self.samples, model, - particle_dim=-1 - self.max_plate_nesting) - assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \ - {k: tuple(v.shape) for k, v in self.samples.items()} + self.samples = align_samples( + self.samples, model, particle_dim=-1 - self.max_plate_nesting + ) + assert all( + v.size(0) == num_samples * num_chains for v in self.samples.values() + ), {k: tuple(v.shape) for k, v in self.samples.items()} return mcmc # E.g. so user can run mcmc.summary(). @@ -636,28 +679,35 @@ def predict(self, forecast=0): samples = self.samples num_samples = len(next(iter(samples.values()))) - particle_plate = pyro.plate("particles", num_samples, - dim=-1 - self.max_plate_nesting) + particle_plate = pyro.plate( + "particles", num_samples, dim=-1 - self.max_plate_nesting + ) # Sample discrete auxiliary variables conditioned on the continuous # variables sampled by _quantized_model. This samples only time steps # [0:duration]. Here infer_discrete runs a forward-filter # backward-sample algorithm. - logger.info("Predicting latent variables for {} time steps..." 
- .format(self.duration)) + logger.info( + "Predicting latent variables for {} time steps...".format(self.duration) + ) model = self._sequential_model model = poutine.condition(model, samples) model = particle_plate(model) if not self.relaxed: - model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting) + model = infer_discrete( + model, first_available_dim=-2 - self.max_plate_nesting + ) trace = poutine.trace(model).get_trace() - samples = OrderedDict((name, site["value"].expand(site["fn"].shape())) - for name, site in trace.nodes.items() - if site["type"] == "sample" - if not site_is_subsample(site) - if not site_is_factor(site)) - assert all(v.size(0) == num_samples for v in samples.values()), \ - {k: tuple(v.shape) for k, v in samples.items()} + samples = OrderedDict( + (name, site["value"].expand(site["fn"].shape())) + for name, site in trace.nodes.items() + if site["type"] == "sample" + if not site_is_subsample(site) + if not site_is_factor(site) + ) + assert all(v.size(0) == num_samples for v in samples.values()), { + k: tuple(v.shape) for k, v in samples.items() + } # Optionally forecast with the forward _generative_model. This samples # time steps [duration:duration+forecast]. @@ -667,15 +717,18 @@ def predict(self, forecast=0): model = poutine.condition(model, samples) model = particle_plate(model) trace = poutine.trace(model).get_trace(forecast) - samples = OrderedDict((name, site["value"]) - for name, site in trace.nodes.items() - if site["type"] == "sample" - if not site_is_subsample(site) - if not site_is_factor(site)) + samples = OrderedDict( + (name, site["value"]) + for name, site in trace.nodes.items() + if site["type"] == "sample" + if not site_is_subsample(site) + if not site_is_factor(site) + ) self._concat_series(samples, trace, forecast) - assert all(v.size(0) == num_samples for v in samples.values()), \ - {k: tuple(v.shape) for k, v in samples.items()} + assert all(v.size(0) == num_samples for v in samples.values()), { + k: tuple(v.shape) for k, v in samples.items() + } return samples @torch.no_grad() @@ -701,9 +754,13 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): model = _SMCModel(self) guide = _SMCGuide(self) for attempt in range(1, 1 + retries): - smc = SMCFilter(model, guide, num_particles=num_particles, - ess_threshold=ess_threshold, - max_plate_nesting=self.max_plate_nesting) + smc = SMCFilter( + model, + guide, + num_particles=num_particles, + ess_threshold=ess_threshold, + max_plate_nesting=self.max_plate_nesting, + ) try: smc.init() for t in range(1, self.duration): @@ -732,12 +789,16 @@ def _heuristic(self, haar, **options): with poutine.block(): init_values = self.heuristic(**options) assert isinstance(init_values, dict) - assert "auxiliary" in init_values, \ - ".heuristic() did not define auxiliary value" - logger.info("Heuristic init: {}".format(", ".join( - "{}={:0.3g}".format(k, v.item()) - for k, v in sorted(init_values.items()) - if v.numel() == 1))) + assert "auxiliary" in init_values, ".heuristic() did not define auxiliary value" + logger.info( + "Heuristic init: {}".format( + ", ".join( + "{}={:0.3g}".format(k, v.item()) + for k, v in sorted(init_values.items()) + if v.numel() == 1 + ) + ) + ) return init_to_value(values=init_values, fallback=None) def _concat_series(self, samples, trace, forecast=0): @@ -807,9 +868,13 @@ def _sample_auxiliary(self): # Sample the compartmental continuous reparameterizing variable. 
shape = (C, T) + R_shape - auxiliary = pyro.sample("auxiliary", - dist.Uniform(-0.5, self.population + 0.5) - .mask(False).expand(shape).to_event()) + auxiliary = pyro.sample( + "auxiliary", + dist.Uniform(-0.5, self.population + 0.5) + .mask(False) + .expand(shape) + .to_event(), + ) extra_dims = auxiliary.dim() - len(shape) # Sample any non-compartmental time series in batch. @@ -848,9 +913,10 @@ def _transition_bwd(self, params, prev, curr, t): if is_validation_enabled(): for key in self.compartments: if not torch.allclose(state[key], curr[key]): - raise ValueError("Incorrect state['{}'] update in .transition(), " - "check that .transition() matches .compute_flows()." - .format(key)) + raise ValueError( + "Incorrect state['{}'] update in .transition(), " + "check that .transition() matches .compute_flows().".format(key) + ) def _generative_model(self, forecast=0): """ @@ -861,8 +927,10 @@ def _generative_model(self, forecast=0): # Sample initial values. state = self.initialize(params) - state = {k: v if isinstance(v, torch.Tensor) else torch.tensor(float(v)) - for k, v in state.items()} + state = { + k: v if isinstance(v, torch.Tensor) else torch.tensor(float(v)) + for k, v in state.items() + } # Sequentially transition. for t in range(self.duration + forecast): @@ -871,7 +939,9 @@ def _generative_model(self, forecast=0): self.transition(params, state, t) with self.region_plate: for name in self.compartments: - pyro.deterministic("{}_{}".format(name, t), state[name], event_dim=0) + pyro.deterministic( + "{}_{}".format(name, t), state[name], event_dim=0 + ) self._clear_plates() @@ -892,8 +962,10 @@ def _sequential_model(self): auxiliary, non_compartmental = self._sample_auxiliary() # Reshape to accommodate the time_plate below. - assert auxiliary.shape == (num_samples, C, T) + R_shape, \ - (auxiliary.shape, (num_samples, C, T) + R_shape) + assert auxiliary.shape == (num_samples, C, T) + R_shape, ( + auxiliary.shape, + (num_samples, C, T) + R_shape, + ) aux = [aux.unbind(2) for aux in auxiliary.unsqueeze(1).unbind(2)] # Sequentially transition. @@ -904,13 +976,17 @@ def _sequential_model(self): # Extract any non-compartmental variables. for name, value in non_compartmental.items(): - curr[name] = value[:, t:t+1] + curr[name] = value[:, t : t + 1] # Extract and enumerate all compartmental variables. for c, name in enumerate(self.compartments): - curr[name] = quantize("{}_{}".format(name, t), aux[c][t], - min=0, max=self.population, - num_quant_bins=self.num_quant_bins) + curr[name] = quantize( + "{}_{}".format(name, t), + aux[c][t], + min=0, + max=self.population, + num_quant_bins=self.num_quant_bins, + ) # Enable approximate inference by using aux as a # non-enumerated proxy for enumerated compartment values. if name in self.approximate: @@ -936,8 +1012,9 @@ def _quantized_model(self): auxiliary, non_compartmental = self._sample_auxiliary() # Manually enumerate. - curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population, - num_quant_bins=self.num_quant_bins) + curr, logp = quantize_enumerate( + auxiliary, min=0, max=self.population, num_quant_bins=self.num_quant_bins + ) curr = OrderedDict(zip(self.compartments, curr.unbind(0))) logp = OrderedDict(zip(self.compartments, logp.unbind(0))) curr.update(non_compartmental) @@ -949,8 +1026,9 @@ def _quantized_model(self): if name in self.compartments: if isinstance(value, torch.Tensor): value = value[..., None] # Because curr is enumerated on the right. 
- prev[name] = cat2(value, curr[name][:-1], - dim=-3 if self.is_regional else -2) + prev[name] = cat2( + value, curr[name][:-1], dim=-3 if self.is_regional else -2 + ) else: # non-compartmental prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim()) @@ -973,8 +1051,9 @@ def enum_reshape(tensor, position): for name in self.approximate: aux = auxiliary[self.compartments.index(name)] curr[name + "_approx"] = aux - prev[name + "_approx"] = cat2(init[name], aux[:-1], - dim=-2 if self.is_regional else -1) + prev[name + "_approx"] = cat2( + init[name], aux[:-1], dim=-2 if self.is_regional else -1 + ) # Record transition factors. with poutine.block(), poutine.trace() as tr: @@ -990,14 +1069,18 @@ def enum_reshape(tensor, position): continue if self.is_regional and log_prob.shape[-1:] != R_shape: # Poor man's tensor variable elimination. - log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0] + log_prob = ( + log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0] + ) logp[name] = site["log_prob"] # Manually perform variable elimination. logp = reduce(operator.add, logp.values()) logp = logp.reshape(Q ** C, Q ** C, T, -1) # prev, curr, T, batch logp = logp.permute(3, 2, 0, 1).squeeze(0) # batch, T, prev, curr - logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) # batch, prev, curr + logp = pyro.distributions.hmm._sequential_logmatmulexp( + logp + ) # batch, prev, curr logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum() warn_if_nan(logp) pyro.factor("transition", logp) @@ -1056,6 +1139,7 @@ class _SMCModel: """ Helper to initialize a CompartmentalModel to a feasible initial state. """ + def __init__(self, model): assert isinstance(model, CompartmentalModel) self.model = model @@ -1097,6 +1181,7 @@ class _SMCGuide(_SMCModel): """ Like _SMCModel but does not update state and does not observe. """ + def init(self, state): super().init(state.copy()) @@ -1110,6 +1195,7 @@ class _HaarSplitReparam: Wrapper around ``HaarReparam`` and ``SplitReparam`` to additionally convert sample dicts between user-facing and auxiliary coordinates. """ + def __init__(self, split, duration, dims, supports): assert 0 <= split < duration self.split = split @@ -1147,10 +1233,13 @@ def aux_to_user(self, samples): if self.split: # Transform back from SplitReparam coordinates. for name, dim in self.dims.items(): - samples[name + "_haar"] = torch.cat([ - samples.pop(name + "_haar_split_0"), - samples.pop(name + "_haar_split_1"), - ], dim=dim) + samples[name + "_haar"] = torch.cat( + [ + samples.pop(name + "_haar_split_0"), + samples.pop(name + "_haar_split_1"), + ], + dim=dim, + ) # Transform back from Haar coordinates. for name, dim in self.dims.items(): diff --git a/pyro/contrib/epidemiology/distributions.py b/pyro/contrib/epidemiology/distributions.py index 7742a8641a..c312ef7f97 100644 --- a/pyro/contrib/epidemiology/distributions.py +++ b/pyro/contrib/epidemiology/distributions.py @@ -121,7 +121,8 @@ def _relaxed_beta_binomial(concentration1, concentration0, total_count): ``total_count`` and lower-bounding variance. 
""" concentration1, concentration0, total_count = broadcast_all( - concentration1, concentration0, total_count) + concentration1, concentration0, total_count + ) c = concentration1 + concentration0 beta_mean = concentration1 / c @@ -133,8 +134,7 @@ def _relaxed_beta_binomial(concentration1, concentration0, total_count): return dist.Normal(mean, scale) -def binomial_dist(total_count, probs, *, - overdispersion=0.): +def binomial_dist(total_count, probs, *, overdispersion=0.0): """ Returns a Beta-Binomial distribution that is an overdispersed version of a Binomial distribution, according to a parameter ``overdispersion``, @@ -191,8 +191,9 @@ def binomial_dist(total_count, probs, *, return dist.ExtendedBetaBinomial(concentration1, concentration0, total_count) -def beta_binomial_dist(concentration1, concentration0, total_count, *, - overdispersion=0.): +def beta_binomial_dist( + concentration1, concentration0, total_count, *, overdispersion=0.0 +): """ Returns a Beta-Binomial distribution that is an overdispersed version of a the usual Beta-Binomial distribution, according to an extra parameter @@ -226,28 +227,31 @@ def beta_binomial_dist(concentration1, concentration0, total_count, *, return dist.ExtendedBetaBinomial(concentration1, concentration0, total_count) -def poisson_dist(rate, *, overdispersion=0.): +def poisson_dist(rate, *, overdispersion=0.0): _validate_overdispersion(overdispersion) if _is_zero(overdispersion): return dist.Poisson(rate) raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson") -def negative_binomial_dist(concentration, probs=None, *, - logits=None, overdispersion=0.): +def negative_binomial_dist( + concentration, probs=None, *, logits=None, overdispersion=0.0 +): _validate_overdispersion(overdispersion) if _is_zero(overdispersion): return dist.NegativeBinomial(concentration, probs=probs, logits=logits) raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson") -def infection_dist(*, - individual_rate, - num_infectious, - num_susceptible=math.inf, - population=math.inf, - concentration=math.inf, - overdispersion=0.): +def infection_dist( + *, + individual_rate, + num_infectious, + num_susceptible=math.inf, + population=math.inf, + concentration=math.inf, + overdispersion=0.0 +): """ Create a :class:`~pyro.distributions.Distribution` over the number of new infections at a discrete time step. @@ -313,8 +317,9 @@ def infection_dist(*, # Return an overdispersed Negative-Binomial distribution. combined_k = k * I logits = torch.as_tensor(R / k).log() - return negative_binomial_dist(combined_k, logits=logits, - overdispersion=overdispersion) + return negative_binomial_dist( + combined_k, logits=logits, overdispersion=overdispersion + ) else: # Compute the probability that any given (susceptible, infectious) # pair of individuals results in an infection at this time step. diff --git a/pyro/contrib/epidemiology/models.py b/pyro/contrib/epidemiology/models.py index 393b94f0fd..ddfb515a9d 100644 --- a/pyro/contrib/epidemiology/models.py +++ b/pyro/contrib/epidemiology/models.py @@ -45,7 +45,7 @@ def __init__(self, population, recovery_time, data): def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, tau, rho @@ -57,13 +57,16 @@ def transition(self, params, state, t): R0, tau, rho = params # Sample flows between compartments. 
- S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0 / tau, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -71,9 +74,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho), + obs=self.data[t] if t_is_observed else None, + ) class SimpleSEIRModel(CompartmentalModel): @@ -116,7 +121,7 @@ def __init__(self, population, incubation_time, recovery_time, data): def global_model(self): tau_e = self.incubation_time tau_i = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, tau_e, tau_i, rho @@ -128,15 +133,17 @@ def transition(self, params, state, t): R0, tau_e, tau_i, rho = params # Sample flows between compartments. - S2E = pyro.sample("S2E_{}".format(t), - infection_dist(individual_rate=R0 / tau_i, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population)) - E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau_i)) + S2E = pyro.sample( + "S2E_{}".format(t), + infection_dist( + individual_rate=R0 / tau_i, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + ), + ) + E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e)) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i)) # Update compartments with flows. state["S"] = state["S"] - S2E @@ -145,9 +152,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2E, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2E, rho), + obs=self.data[t] if t_is_observed else None, + ) class SimpleSEIRDModel(CompartmentalModel): @@ -177,8 +186,9 @@ class SimpleSEIRDModel(CompartmentalModel): transitions. This allows false negative but no false positives. """ - def __init__(self, population, incubation_time, recovery_time, - mortality_rate, data): + def __init__( + self, population, incubation_time, recovery_time, mortality_rate, data + ): compartments = ("S", "E", "I", "D") # R is implicit. duration = len(data) super().__init__(compartments, duration, population) @@ -201,7 +211,7 @@ def global_model(self): tau_e = self.incubation_time tau_i = self.recovery_time mu = self.mortality_rate - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, tau_e, tau_i, mu, rho @@ -213,22 +223,25 @@ def transition(self, params, state, t): R0, tau_e, tau_i, mu, rho = params # Sample flows between compartments. 
- S2E = pyro.sample("S2E_{}".format(t), - infection_dist(individual_rate=R0 / tau_i, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population)) - E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e)) + S2E = pyro.sample( + "S2E_{}".format(t), + infection_dist( + individual_rate=R0 / tau_i, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + ), + ) + E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e)) # Of the 1/tau_i expected recoveries-or-deaths, a portion mu die and # the remaining recover. Alternatively we could model this with a # Multinomial distribution I2_ and extract the two components I2D and # I2R, however the Multinomial distribution does not currently # implement overdispersion or moment matching. - I2D = pyro.sample("I2D_{}".format(t), - binomial_dist(state["I"], mu / tau_i)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"] - I2D, 1 / tau_i)) + I2D = pyro.sample("I2D_{}".format(t), binomial_dist(state["I"], mu / tau_i)) + I2R = pyro.sample( + "I2R_{}".format(t), binomial_dist(state["I"] - I2D, 1 / tau_i) + ) # Update compartments with flows. state["S"] = state["S"] - S2E @@ -238,9 +251,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2E, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2E, rho), + obs=self.data[t] if t_is_observed else None, + ) def compute_flows(self, prev, curr, t): S2E = prev["S"] - curr["S"] # S can only go to E. @@ -309,7 +324,7 @@ def __init__(self, population, recovery_time, data): def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) od = pyro.sample("od", dist.Beta(2, 6)) return R0, tau, rho, od @@ -322,15 +337,19 @@ def transition(self, params, state, t): R0, tau, rho, od = params # Sample flows between compartments. - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0 / tau, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population, - overdispersion=od)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau, - overdispersion=od)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + overdispersion=od, + ), + ) + I2R = pyro.sample( + "I2R_{}".format(t), binomial_dist(state["I"], 1 / tau, overdispersion=od) + ) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -338,9 +357,11 @@ def transition(self, params, state, t): # Condition on observations. 
t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho, overdispersion=od), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho, overdispersion=od), + obs=self.data[t] if t_is_observed else None, + ) class OverdispersedSEIRModel(CompartmentalModel): @@ -402,7 +423,7 @@ def __init__(self, population, incubation_time, recovery_time, data): def global_model(self): tau_e = self.incubation_time tau_i = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) od = pyro.sample("od", dist.Beta(2, 6)) return R0, tau_e, tau_i, rho, od @@ -415,16 +436,22 @@ def transition(self, params, state, t): R0, tau_e, tau_i, rho, od = params # Sample flows between compartments. - S2E = pyro.sample("S2E_{}".format(t), - infection_dist(individual_rate=R0 / tau_i, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population, - overdispersion=od)) - E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e, overdispersion=od)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau_i, overdispersion=od)) + S2E = pyro.sample( + "S2E_{}".format(t), + infection_dist( + individual_rate=R0 / tau_i, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + overdispersion=od, + ), + ) + E2I = pyro.sample( + "E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e, overdispersion=od) + ) + I2R = pyro.sample( + "I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i, overdispersion=od) + ) # Update compartments with flows. state["S"] = state["S"] - S2E @@ -433,9 +460,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2E, rho, overdispersion=od), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2E, rho, overdispersion=od), + obs=self.data[t] if t_is_observed else None, + ) class SuperspreadingSIRModel(CompartmentalModel): @@ -490,8 +519,8 @@ def __init__(self, population, recovery_time, data): def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) - k = pyro.sample("k", dist.Exponential(1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) + k = pyro.sample("k", dist.Exponential(1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, k, tau, rho @@ -503,14 +532,17 @@ def transition(self, params, state, t): R0, k, tau, rho = params # Sample flows between compartments. - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population, - concentration=k)) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + concentration=k, + ), + ) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -518,9 +550,11 @@ def transition(self, params, state, t): # Condition on observations. 
t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho), + obs=self.data[t] if t_is_observed else None, + ) class SuperspreadingSEIRModel(CompartmentalModel): @@ -573,8 +607,16 @@ class SuperspreadingSEIRModel(CompartmentalModel): transitions. This allows false negative but no false positives. """ - def __init__(self, population, incubation_time, recovery_time, data, *, - leaf_times=None, coal_times=None): + def __init__( + self, + population, + incubation_time, + recovery_time, + data, + *, + leaf_times=None, + coal_times=None + ): compartments = ("S", "E", "I") # R is implicit. duration = len(data) super().__init__(compartments, duration, population) @@ -594,13 +636,14 @@ def __init__(self, population, incubation_time, recovery_time, data, *, self.coal_likelihood = None else: self.coal_likelihood = dist.CoalescentRateLikelihood( - leaf_times, coal_times, duration) + leaf_times, coal_times, duration + ) def global_model(self): tau_e = self.incubation_time tau_i = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) - k = pyro.sample("k", dist.Exponential(1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) + k = pyro.sample("k", dist.Exponential(1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, k, tau_e, tau_i, rho @@ -612,28 +655,35 @@ def transition(self, params, state, t): R0, k, tau_e, tau_i, rho = params # Sample flows between compartments. - E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau_i)) - S2E = pyro.sample("S2E_{}".format(t), - infection_dist(individual_rate=R0, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population, - concentration=k)) + E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e)) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i)) + S2E = pyro.sample( + "S2E_{}".format(t), + infection_dist( + individual_rate=R0, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + concentration=k, + ), + ) # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2E, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2E, rho), + obs=self.data[t] if t_is_observed else None, + ) if self.coal_likelihood is not None: R = R0 * state["S"] / self.population - coal_rate = R * (1. + 1. / k) / (tau_i * state["I"] + 1e-8) - pyro.factor("coalescent_{}".format(t), - self.coal_likelihood(coal_rate, t) - if t_is_observed else torch.tensor(0.)) + coal_rate = R * (1.0 + 1.0 / k) / (tau_i * state["I"] + 1e-8) + pyro.factor( + "coalescent_{}".format(t), + self.coal_likelihood(coal_rate, t) + if t_is_observed + else torch.tensor(0.0), + ) # Update compartements with flows. 
state["S"] = state["S"] - S2E @@ -674,7 +724,7 @@ def __init__(self, population, recovery_time, data): def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) # Let's consider a piecewise constant response rate, say low rate for # two weeks, then intermediate rate as testing capacity increases, and @@ -685,9 +735,14 @@ def global_model(self): rho1 = pyro.sample("rho1", dist.Beta(4, 4)) rho2 = pyro.sample("rho2", dist.Beta(8, 4)) # Later .transition() will index into this time series as rho[..., t]. - rho = torch.cat([rho0.unsqueeze(-1).expand(rho0.shape + (14,)), - rho1.unsqueeze(-1).expand(rho1.shape + (7,)), - rho2.unsqueeze(-1).expand(rho2.shape + (60,))], dim=-1) + rho = torch.cat( + [ + rho0.unsqueeze(-1).expand(rho0.shape + (14,)), + rho1.unsqueeze(-1).expand(rho1.shape + (7,)), + rho2.unsqueeze(-1).expand(rho2.shape + (60,)), + ], + dim=-1, + ) # We can also save the time series for output in self.samples. pyro.deterministic("rho", rho, event_dim=1) @@ -697,25 +752,29 @@ def initialize(self, params): R0, tau, rho = params # Start with a single infection. # We also store the initial beta value in the state dict. - return {"S": self.population - 1, "I": 1, "beta": torch.tensor(1.)} + return {"S": self.population - 1, "I": 1, "beta": torch.tensor(1.0)} def transition(self, params, state, t): R0, tau, rho = params # Sample heterogeneous variables. # This assumes beta slowly drifts via Brownian motion in log space. - beta = pyro.sample("beta_{}".format(t), - dist.LogNormal(state["beta"].log(), 0.1)) + beta = pyro.sample( + "beta_{}".format(t), dist.LogNormal(state["beta"].log(), 0.1) + ) Rt = pyro.deterministic("Rt_{}".format(t), R0 * beta) # Sample flows between compartments. - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=Rt / tau, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=Rt / tau, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) # Update compartments and heterogeneous variables. state["S"] = state["S"] - S2I @@ -726,9 +785,11 @@ def transition(self, params, state, t): # Note that, since rho may be batched over particles or samples, we # need to index it via rho[..., t] rather than a simple rho[t]. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho[..., t]), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho[..., t]), + obs=self.data[t] if t_is_observed else None, + ) class SparseSIRModel(CompartmentalModel): @@ -774,7 +835,7 @@ def __init__(self, population, recovery_time, data, mask): def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) rho = pyro.sample("rho", dist.Beta(10, 10)) return R0, tau, rho @@ -786,15 +847,17 @@ def transition(self, params, state, t): R0, tau, rho = params # Sample flows between compartments. 
- S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0 / tau, - num_susceptible=state["S"], - num_infectious=state["I"], - population=self.population)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) - S2O = pyro.sample("S2O_{}".format(t), - binomial_dist(S2I, rho)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) + S2O = pyro.sample("S2O_{}".format(t), binomial_dist(S2I, rho)) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -805,10 +868,12 @@ def transition(self, params, state, t): t_is_observed = isinstance(t, slice) or t < self.duration mask_t = self.mask[t] if t_is_observed else False data_t = self.data[t] if t_is_observed else None - pyro.sample("obs_{}".format(t), - # FIXME Delta is incompatible with relaxed inference. - dist.Delta(state["O"]).mask(mask_t), - obs=data_t) + pyro.sample( + "obs_{}".format(t), + # FIXME Delta is incompatible with relaxed inference. + dist.Delta(state["O"]).mask(mask_t), + obs=data_t, + ) def compute_flows(self, prev, curr, t): # Reverse the flow computation. @@ -868,14 +933,14 @@ def __init__(self, population, recovery_time, pre_obs_window, data): # Prepend data with zeros. if isinstance(data, list): - data = [0.] * self.pre_obs_window + data + data = [0.0] * self.pre_obs_window + data else: - data = pad(data, (self.pre_obs_window, 0), value=0.) + data = pad(data, (self.pre_obs_window, 0), value=0.0) self.data = data def global_model(self): tau = self.recovery_time - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) # Assume two different response rates: rho0 before any observations # were made (in pre_obs_window), followed by a higher response rate rho1 @@ -884,10 +949,13 @@ def global_model(self): rho1 = pyro.sample("rho1", dist.Beta(10, 10)) # Whereas each of rho0,rho1 are scalars (possibly batched over samples), # we construct a time series rho with an extra time dim on the right. - rho = torch.cat([ - rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_obs_window,)), - rho1.unsqueeze(-1).expand(rho1.shape + (self.post_obs_window,)), - ], dim=-1) + rho = torch.cat( + [ + rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_obs_window,)), + rho1.unsqueeze(-1).expand(rho1.shape + (self.post_obs_window,)), + ], + dim=-1, + ) # Model external infections as an infectious pseudo-individual added # to num_infectious when sampling S2I below. @@ -903,13 +971,16 @@ def transition(self, params, state, t): R0, X, tau, rho = params # Sample flows between compartments. - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0 / tau, - num_susceptible=state["S"], - num_infectious=state["I"] + X, - population=self.population)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=state["I"] + X, + population=self.population, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -922,9 +993,7 @@ def transition(self, params, state, t): data_t = self.data[t] if t_is_observed else None # Condition on observations. 
- pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho_t), - obs=data_t) + pyro.sample("obs_{}".format(t), binomial_dist(S2I, rho_t), obs=data_t) def predict(self, forecast=0): """ @@ -992,7 +1061,7 @@ class RegionalSIRModel(CompartmentalModel): def __init__(self, population, coupling, recovery_time, data): duration = len(data) - num_regions, = population.shape + (num_regions,) = population.shape assert coupling.shape == (num_regions, num_regions) assert (0 <= coupling).all() assert (coupling <= 1).all() @@ -1015,7 +1084,7 @@ def global_model(self): tau = self.recovery_time # Assume reproductive number is unknown but homogeneous. - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) # Assume response rate is heterogeneous and model it with a # hierarchical Gamma-Beta prior. @@ -1046,13 +1115,16 @@ def transition(self, params, state, t): with self.region_plate: # Sample flows between compartments. - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=R0 / tau, - num_susceptible=state["S"], - num_infectious=I_coupled, - population=pop_coupled)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=I_coupled, + population=pop_coupled, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) # Update compartments with flows. state["S"] = state["S"] - S2I @@ -1060,9 +1132,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho), + obs=self.data[t] if t_is_observed else None, + ) class HeterogeneousRegionalSIRModel(CompartmentalModel): @@ -1094,7 +1168,7 @@ class HeterogeneousRegionalSIRModel(CompartmentalModel): def __init__(self, population, coupling, recovery_time, data): duration = len(data) - num_regions, = population.shape + (num_regions,) = population.shape assert coupling.shape == (num_regions, num_regions) assert (0 <= coupling).all() assert (coupling <= 1).all() @@ -1116,13 +1190,13 @@ def global_model(self): tau = self.recovery_time # Assume reproductive number is heterogeneous but shared among regions. - R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) - R_drift = pyro.sample("R_drift", dist.LogNormal(-3., 1.)) + R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0)) + R_drift = pyro.sample("R_drift", dist.LogNormal(-3.0, 1.0)) # Assume response rate is heterogeneous in time and region. with self.region_plate: rho0 = pyro.sample("rho0", dist.Beta(4, 4)) - rho_drift = pyro.sample("rho_drift", dist.LogNormal(-3., 1.)) + rho_drift = pyro.sample("rho_drift", dist.LogNormal(-3.0, 1.0)) return tau, R0, R_drift, rho0, rho_drift @@ -1131,9 +1205,12 @@ def initialize(self, params): I = torch.zeros_like(self.population) I[0] += 1 S = self.population - I - return {"S": S, "I": I, - "R_factor": torch.tensor(1.), - "rho_shift": torch.tensor(0.)} + return { + "S": S, + "I": I, + "R_factor": torch.tensor(1.0), + "rho_shift": torch.tensor(0.0), + } def transition(self, params, state, t): tau, R0, R_drift, rho0, rho_drift = params @@ -1147,25 +1224,31 @@ def transition(self, params, state, t): pop_coupled = self.population @ self.coupling # Sample region-global time-heterogeneous variables. 
- R_factor = pyro.sample("R_factor_{}".format(t), - dist.LogNormal(state["R_factor"].log(), R_drift)) + R_factor = pyro.sample( + "R_factor_{}".format(t), dist.LogNormal(state["R_factor"].log(), R_drift) + ) Rt = pyro.deterministic("Rt_{}".format(t), R0 * R_factor) with self.region_plate: # Sample region-local time-heterogeneous variables. - rho_shift = pyro.sample("rho_shift_{}".format(t), - dist.Normal(state["rho_shift"], rho_drift)) - rho = pyro.deterministic("rho_{}".format(t), - (rho0.log() - (-rho0).log1p() + rho_shift).sigmoid()) + rho_shift = pyro.sample( + "rho_shift_{}".format(t), dist.Normal(state["rho_shift"], rho_drift) + ) + rho = pyro.deterministic( + "rho_{}".format(t), (rho0.log() - (-rho0).log1p() + rho_shift).sigmoid() + ) # Sample flows between compartments. - S2I = pyro.sample("S2I_{}".format(t), - infection_dist(individual_rate=Rt / tau, - num_susceptible=state["S"], - num_infectious=I_coupled, - population=pop_coupled)) - I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau)) + S2I = pyro.sample( + "S2I_{}".format(t), + infection_dist( + individual_rate=Rt / tau, + num_susceptible=state["S"], + num_infectious=I_coupled, + population=pop_coupled, + ), + ) + I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau)) # Update compartments and heterogeneous variables. state["S"] = state["S"] - S2I @@ -1175,9 +1258,11 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration - pyro.sample("obs_{}".format(t), - binomial_dist(S2I, rho), - obs=self.data[t] if t_is_observed else None) + pyro.sample( + "obs_{}".format(t), + binomial_dist(S2I, rho), + obs=self.data[t] if t_is_observed else None, + ) # Create sphinx documentation. @@ -1186,12 +1271,18 @@ def transition(self, params, state, t): if isinstance(_Model, type) and issubclass(_Model, CompartmentalModel): if _Model is not CompartmentalModel: __all__.append(_name) -__all__.sort(key=lambda name, vals=locals(): vals[name].__init__.__code__.co_firstlineno) -__doc__ = "\n\n".join([ - """ +__all__.sort( + key=lambda name, vals=locals(): vals[name].__init__.__code__.co_firstlineno +) +__doc__ = "\n\n".join( + [ + """ {} ---------------------------------------------------------------- .. 
autoclass:: pyro.contrib.epidemiology.models.{} - """.format(re.sub("([A-Z][a-z]+)", r"\1 ", _name[:-5]), _name) - for _name in __all__ -]) + """.format( + re.sub("([A-Z][a-z]+)", r"\1 ", _name[:-5]), _name + ) + for _name in __all__ + ] +) diff --git a/pyro/contrib/epidemiology/util.py b/pyro/contrib/epidemiology/util.py index 14af951ff8..0d137e328d 100644 --- a/pyro/contrib/epidemiology/util.py +++ b/pyro/contrib/epidemiology/util.py @@ -80,28 +80,93 @@ def align_samples(samples, model, particle_dim): # this 8 x 10 tensor encodes the coefficients of 8 10-dimensional polynomials # that are used to construct the num_quant_bins=16 quantization strategy -W16 = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1562511562511555e-07], - [1.1562511562511557e-07, 1.04062604062604e-06, 4.16250416250416e-06, - 9.712509712509707e-06, 1.456876456876456e-05, 1.4568764568764562e-05, - 9.712509712509707e-06, 4.16250416250416e-06, 1.04062604062604e-06, -6.937506937506934e-07], - [5.839068339068337e-05, 0.0002591158841158841, 0.0005036630036630038, - 0.0005536130536130536, 0.00036421911421911425, 0.00013111888111888106, - 9.712509712509736e-06, -1.2487512487512482e-05, -5.2031302031302014e-06, 1.6187516187516182e-06], - [0.0018637612387612374, 0.004983558108558107, 0.005457042957042955, - 0.0029234654234654212, 0.000568181818181818, -0.0001602564102564102, - -8.741258741258739e-05, 4.162504162504162e-06, 9.365634365634364e-06, -1.7536475869809201e-06], - [0.015560115039281694, 0.025703289765789755, 0.015009296259296255, - 0.0023682336182336166, -0.000963966588966589, -0.00029380341880341857, - 5.6656306656306665e-05, 1.5956265956265953e-05, -6.417193917193917e-06, 7.515632515632516e-07], - [0.057450111616778265, 0.05790875790875791, 0.014424464424464418, - -0.0030303030303030303, -0.0013791763791763793, 0.00011655011655011669, - 5.180005180005181e-05, -8.325008325008328e-06, 3.4687534687534703e-07, 0.0], - [0.12553422657589322, 0.072988122988123, -0.0011641136641136712, - -0.006617456617456618, -0.00028651903651903725, 0.00027195027195027195, - 3.2375032375032334e-06, -5.550005550005552e-06, 3.4687534687534703e-07, 0.0], - [0.21761806865973532, 1.7482707128494565e-17, -0.028320290820290833, - 0.0, 0.0014617327117327117, 0.0, - -3.561253561253564e-05, 0.0, 3.4687534687534714e-07, 0.0]] +W16 = [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1562511562511555e-07], + [ + 1.1562511562511557e-07, + 1.04062604062604e-06, + 4.16250416250416e-06, + 9.712509712509707e-06, + 1.456876456876456e-05, + 1.4568764568764562e-05, + 9.712509712509707e-06, + 4.16250416250416e-06, + 1.04062604062604e-06, + -6.937506937506934e-07, + ], + [ + 5.839068339068337e-05, + 0.0002591158841158841, + 0.0005036630036630038, + 0.0005536130536130536, + 0.00036421911421911425, + 0.00013111888111888106, + 9.712509712509736e-06, + -1.2487512487512482e-05, + -5.2031302031302014e-06, + 1.6187516187516182e-06, + ], + [ + 0.0018637612387612374, + 0.004983558108558107, + 0.005457042957042955, + 0.0029234654234654212, + 0.000568181818181818, + -0.0001602564102564102, + -8.741258741258739e-05, + 4.162504162504162e-06, + 9.365634365634364e-06, + -1.7536475869809201e-06, + ], + [ + 0.015560115039281694, + 0.025703289765789755, + 0.015009296259296255, + 0.0023682336182336166, + -0.000963966588966589, + -0.00029380341880341857, + 5.6656306656306665e-05, + 1.5956265956265953e-05, + -6.417193917193917e-06, + 7.515632515632516e-07, + ], + [ + 0.057450111616778265, + 0.05790875790875791, + 0.014424464424464418, + -0.0030303030303030303, + 
-0.0013791763791763793, + 0.00011655011655011669, + 5.180005180005181e-05, + -8.325008325008328e-06, + 3.4687534687534703e-07, + 0.0, + ], + [ + 0.12553422657589322, + 0.072988122988123, + -0.0011641136641136712, + -0.006617456617456618, + -0.00028651903651903725, + 0.00027195027195027195, + 3.2375032375032334e-06, + -5.550005550005552e-06, + 3.4687534687534703e-07, + 0.0, + ], + [ + 0.21761806865973532, + 1.7482707128494565e-17, + -0.028320290820290833, + 0.0, + 0.0014617327117327117, + 0.0, + -3.561253561253564e-05, + 0.0, + 3.4687534687534714e-07, + 0.0, + ], +] W16 = numpy.array(W16) @@ -125,12 +190,18 @@ def compute_bin_probs(s, num_quant_bins): if num_quant_bins == 4: # This cubic spline interpolates over the nearest four integers, ensuring # piecewise quadratic gradients. - probs = torch.stack([ - t * tt, - 4 + ss * (3 * s - 6), - 4 + tt * (3 * t - 6), - s * ss, - ], dim=-1) * (1/6) + probs = ( + torch.stack( + [ + t * tt, + 4 + ss * (3 * s - 6), + 4 + tt * (3 * t - 6), + s * ss, + ], + dim=-1, + ) + * (1 / 6) + ) return probs if num_quant_bins == 8: @@ -144,16 +215,22 @@ def compute_bin_probs(s, num_quant_bins): t4 = tt * tt t5 = t3 * tt - probs = torch.stack([ - 2 * t5, - 2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5, - 55 + 115 * t + 70 * tt - 9 * t3 - 25 * t4 + 7 * t5, - 302 - 100 * ss + 10 * s4, - 302 - 100 * tt + 10 * t4, - 55 + 115 * s + 70 * ss - 9 * s3 - 25 * s4 + 7 * s5, - 2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5, - 2 * s5 - ], dim=-1) * (1/840) + probs = ( + torch.stack( + [ + 2 * t5, + 2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5, + 55 + 115 * t + 70 * tt - 9 * t3 - 25 * t4 + 7 * t5, + 302 - 100 * ss + 10 * s4, + 302 - 100 * tt + 10 * t4, + 55 + 115 * s + 70 * ss - 9 * s3 - 25 * s4 + 7 * s5, + 2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5, + 2 * s5, + ], + dim=-1, + ) + * (1 / 840) + ) return probs if num_quant_bins == 12: @@ -170,27 +247,87 @@ def compute_bin_probs(s, num_quant_bins): t6 = t3 * t3 t7 = t4 * t3 - probs = torch.stack([ - 693 * t7, - 693 + 4851 * t + 14553 * tt + 24255 * t3 + 24255 * t4 + 14553 * t5 + 4851 * t6 - 3267 * t7, - 84744 + 282744 * t + 382536 * tt + 249480 * t3 + 55440 * t4 - 24948 * t5 - 18018 * t6 + 5445 * t7, - 1017423 + 1823283 * t + 1058211 * tt + 51975 * t3 - 148995 * t4 - 18711 * t5 + 20097 * t6 - 3267 * t7, - 3800016 + 3503808 * t + 365904 * tt - 443520 * t3 - 55440 * t4 + 33264 * t5 - 2772 * t6, - 8723088 - 1629936 * ss + 110880.0 * s4 - 2772 * s6, - 8723088 - 1629936 * tt + 110880.0 * t4 - 2772 * t6, - 3800016 + 3503808 * s + 365904 * ss - 443520 * s3 - 55440 * s4 + 33264 * s5 - 2772 * s6, - 1017423 + 1823283 * s + 1058211 * ss + 51975 * s3 - 148995 * s4 - 18711 * s5 + 20097 * s6 - 3267 * s7, - 84744 + 282744 * s + 382536 * ss + 249480 * s3 + 55440 * s4 - 24948 * s5 - 18018 * s6 + 5445 * s7, - 693 + 4851 * s + 14553 * ss + 24255 * s3 + 24255 * s4 + 14553 * s5 + 4851 * s6 - 3267 * s7, - 693 * s7, - ], dim=-1) * (1/32931360) + probs = ( + torch.stack( + [ + 693 * t7, + 693 + + 4851 * t + + 14553 * tt + + 24255 * t3 + + 24255 * t4 + + 14553 * t5 + + 4851 * t6 + - 3267 * t7, + 84744 + + 282744 * t + + 382536 * tt + + 249480 * t3 + + 55440 * t4 + - 24948 * t5 + - 18018 * t6 + + 5445 * t7, + 1017423 + + 1823283 * t + + 1058211 * tt + + 51975 * t3 + - 148995 * t4 + - 18711 * t5 + + 20097 * t6 + - 3267 * t7, + 3800016 + + 3503808 * t + + 365904 * tt + - 443520 * t3 + - 55440 * t4 + + 33264 * t5 + - 2772 * t6, + 8723088 - 1629936 * ss + 110880.0 * s4 - 2772 * s6, + 8723088 - 1629936 * tt + 110880.0 * t4 - 2772 * t6, + 
3800016 + + 3503808 * s + + 365904 * ss + - 443520 * s3 + - 55440 * s4 + + 33264 * s5 + - 2772 * s6, + 1017423 + + 1823283 * s + + 1058211 * ss + + 51975 * s3 + - 148995 * s4 + - 18711 * s5 + + 20097 * s6 + - 3267 * s7, + 84744 + + 282744 * s + + 382536 * ss + + 249480 * s3 + + 55440 * s4 + - 24948 * s5 + - 18018 * s6 + + 5445 * s7, + 693 + + 4851 * s + + 14553 * ss + + 24255 * s3 + + 24255 * s4 + + 14553 * s5 + + 4851 * s6 + - 3267 * s7, + 693 * s7, + ], + dim=-1, + ) + * (1 / 32931360) + ) return probs if num_quant_bins == 16: # This nonic spline interpolates over the nearest 16 integers w16 = torch.from_numpy(W16).to(s.device).type_as(s) - s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.)) - t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.)) + s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.0)) + t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.0)) splines_t = (w16 * t_powers).sum(-1) splines_s = (w16 * s_powers).sum(-1) index = [0, 1, 2, 3, 4, 5, 6, 15, 7, 14, 13, 12, 11, 10, 9, 8] @@ -220,8 +357,9 @@ def quantize(name, x_real, min, max, num_quant_bins=4): probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins) - q = pyro.sample("Q_" + name, dist.Categorical(probs), - infer={"enumerate": "parallel"}) + q = pyro.sample( + "Q_" + name, dist.Categorical(probs), infer={"enumerate": "parallel"} + ) q = q.type_as(x_real) - (num_quant_bins // 2 - 1) x = lb + q diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py index 0398ad137d..f2ea719566 100644 --- a/pyro/contrib/examples/bart.py +++ b/pyro/contrib/examples/bart.py @@ -122,16 +122,19 @@ def load_bart_od(): if os.path.exists(pkl_file): return torch.load(pkl_file) - filenames = multiprocessing.Pool(len(SOURCE_FILES)).map(_load_hourly_od, SOURCE_FILES) + filenames = multiprocessing.Pool(len(SOURCE_FILES)).map( + _load_hourly_od, SOURCE_FILES + ) datasets = list(map(torch.load, filenames)) stations = sorted(set().union(*(d["stations"].keys() for d in datasets))) min_time = min(int(d["rows"][:, 0].min()) for d in datasets) max_time = max(int(d["rows"][:, 0].max()) for d in datasets) num_rows = max_time - min_time + 1 - start_date = datasets[0]["start_date"] + datetime.timedelta(hours=min_time), - logging.info("Loaded data from {} stations, {} hours" - .format(len(stations), num_rows)) + start_date = (datasets[0]["start_date"] + datetime.timedelta(hours=min_time),) + logging.info( + "Loaded data from {} stations, {} hours".format(len(stations), num_rows) + ) result = torch.zeros(num_rows, len(stations), len(stations)) for dataset in datasets: @@ -143,8 +146,9 @@ def load_bart_od(): count = dataset["rows"][:, 3].float() result[time, origin, destin] = count dataset.clear() - logging.info("Loaded {} shaped data of mean {:0.3g}" - .format(result.shape, result.mean())) + logging.info( + "Loaded {} shaped data of mean {:0.3g}".format(result.shape, result.mean()) + ) dataset = { "stations": stations, @@ -174,6 +178,8 @@ def load_fake_od(): parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() - logging.basicConfig(format='%(relativeCreated) 9d %(message)s', - level=logging.DEBUG if args.verbose else logging.INFO) + logging.basicConfig( + format="%(relativeCreated) 9d %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) load_bart_od() diff --git a/pyro/contrib/examples/multi_mnist.py b/pyro/contrib/examples/multi_mnist.py index 3859c4fe53..6c73c5e7ad 100644 --- a/pyro/contrib/examples/multi_mnist.py +++ 
b/pyro/contrib/examples/multi_mnist.py @@ -23,9 +23,9 @@ def imresize(arr, size): def sample_one(canvas_size, mnist): - i = np.random.randint(mnist['digits'].shape[0]) - digit = mnist['digits'][i] - label = mnist['labels'][i].item() + i = np.random.randint(mnist["digits"].shape[0]) + digit = mnist["digits"][i] + label = mnist["labels"][i].item() scale = 0.1 * np.random.randn() + 1.3 new_size = tuple(int(s / scale) for s in digit.shape) resized = imresize(digit, new_size) @@ -35,7 +35,7 @@ def sample_one(canvas_size, mnist): pad_l = np.random.randint(0, padding) pad_r = np.random.randint(0, padding) pad_width = ((pad_l, padding - pad_l), (pad_r, padding - pad_r)) - positioned = np.pad(resized, pad_width, 'constant', constant_values=0) + positioned = np.pad(resized, pad_width, "constant", constant_values=0) return positioned, label @@ -65,30 +65,30 @@ def mk_dataset(n, mnist, max_digits, canvas_size): def load_mnist(root_path): - loader = get_data_loader('MNIST', root_path) + loader = get_data_loader("MNIST", root_path) return { - 'digits': loader.dataset.data.cpu().numpy(), - 'labels': loader.dataset.targets + "digits": loader.dataset.data.cpu().numpy(), + "labels": loader.dataset.targets, } def load(root_path): - file_path = os.path.join(root_path, 'multi_mnist_uint8.npz') + file_path = os.path.join(root_path, "multi_mnist_uint8.npz") if os.path.exists(file_path): data = np.load(file_path, allow_pickle=True) - return data['x'], data['y'] + return data["x"], data["y"] else: # Set RNG to known state. rng_state = np.random.get_state() np.random.seed(681307) mnist = load_mnist(root_path) - print('Generating multi-MNIST dataset...') + print("Generating multi-MNIST dataset...") x, y = mk_dataset(60000, mnist, 2, 50) # Revert RNG state. np.random.set_state(rng_state) # Crude checksum. # assert x.sum() == 883114919, 'Did not generate the expected data.' 
- with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: np.savez_compressed(f, x=x, y=y) - print('Done!') + print("Done!") return x, y diff --git a/pyro/contrib/examples/polyphonic_data_loader.py b/pyro/contrib/examples/polyphonic_data_loader.py index 491ae0517f..7791c0162b 100644 --- a/pyro/contrib/examples/polyphonic_data_loader.py +++ b/pyro/contrib/examples/polyphonic_data_loader.py @@ -29,21 +29,29 @@ dset = namedtuple("dset", ["name", "url", "filename"]) -JSB_CHORALES = dset("jsb_chorales", - "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle", - "jsb_chorales.pkl") - -PIANO_MIDI = dset("piano_midi", - "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle", - "piano_midi.pkl") - -MUSE_DATA = dset("muse_data", - "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle", - "muse_data.pkl") - -NOTTINGHAM = dset("nottingham", - "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle", - "nottingham.pkl") +JSB_CHORALES = dset( + "jsb_chorales", + "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle", + "jsb_chorales.pkl", +) + +PIANO_MIDI = dset( + "piano_midi", + "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle", + "piano_midi.pkl", +) + +MUSE_DATA = dset( + "muse_data", + "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle", + "muse_data.pkl", +) + +NOTTINGHAM = dset( + "nottingham", + "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle", + "nottingham.pkl", +) # this function processes the raw data; in particular it unsparsifies it @@ -64,18 +72,20 @@ def process_data(base_path, dataset, min_note=21, note_range=88): for split, data_split in data.items(): processed_dataset[split] = {} n_seqs = len(data_split) - processed_dataset[split]['sequence_lengths'] = torch.zeros(n_seqs, dtype=torch.long) - processed_dataset[split]['sequences'] = [] + processed_dataset[split]["sequence_lengths"] = torch.zeros( + n_seqs, dtype=torch.long + ) + processed_dataset[split]["sequences"] = [] for seq in range(n_seqs): seq_length = len(data_split[seq]) - processed_dataset[split]['sequence_lengths'][seq] = seq_length + processed_dataset[split]["sequence_lengths"][seq] = seq_length processed_sequence = torch.zeros((seq_length, note_range)) for t in range(seq_length): note_slice = torch.tensor(list(data_split[seq][t])) - min_note slice_length = len(note_slice) if slice_length > 0: processed_sequence[t, note_slice] = torch.ones(slice_length) - processed_dataset[split]['sequences'].append(processed_sequence) + processed_dataset[split]["sequences"].append(processed_sequence) pickle.dump(processed_dataset, open(output, "wb"), pickle.HIGHEST_PROTOCOL) print("dumped processed data to %s" % output) @@ -95,8 +105,12 @@ def load_data(dataset): dset = pickle.load(f) for k, v in dset.items(): sequences = v["sequences"] - dset[k]["sequences"] = pad_sequence(sequences, batch_first=True).type(torch.Tensor) - dset[k]["sequence_lengths"] = v["sequence_lengths"].to(device=torch.Tensor().device) + dset[k]["sequences"] = pad_sequence(sequences, batch_first=True).type( + torch.Tensor + ) + dset[k]["sequence_lengths"] = v["sequence_lengths"].to( + device=torch.Tensor().device + ) return dset @@ -125,7 +139,7 @@ def pad_and_reverse(rnn_output, seq_lengths): def get_mini_batch_mask(mini_batch, seq_lengths): mask = torch.zeros(mini_batch.shape[0:2]) for b in range(mini_batch.shape[0]): - mask[b, 0:seq_lengths[b]] = 
torch.ones(seq_lengths[b]) + mask[b, 0 : seq_lengths[b]] = torch.ones(seq_lengths[b]) return mask @@ -159,8 +173,8 @@ def get_mini_batch(mini_batch_indices, sequences, seq_lengths, cuda=False): mini_batch_reversed = mini_batch_reversed.cuda() # do sequence packing - mini_batch_reversed = nn.utils.rnn.pack_padded_sequence(mini_batch_reversed, - sorted_seq_lengths, - batch_first=True) + mini_batch_reversed = nn.utils.rnn.pack_padded_sequence( + mini_batch_reversed, sorted_seq_lengths, batch_first=True + ) return mini_batch, mini_batch_reversed, mini_batch_mask, sorted_seq_lengths diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index 3caad9c414..2d86901c5d 100644 --- a/pyro/contrib/examples/util.py +++ b/pyro/contrib/examples/util.py @@ -15,12 +15,14 @@ class MNIST(datasets.MNIST): ] + datasets.MNIST.mirrors -def get_data_loader(dataset_name, - data_dir, - batch_size=1, - dataset_transforms=None, - is_training_set=True, - shuffle=True): +def get_data_loader( + dataset_name, + data_dir, + batch_size=1, + dataset_transforms=None, + is_training_set=True, + shuffle=True, +): if not dataset_transforms: dataset_transforms = [] trans = transforms.Compose([transforms.ToTensor()] + dataset_transforms) @@ -29,16 +31,9 @@ def get_data_loader(dataset_name, else: dataset = getattr(datasets, dataset_name) print("downloading data") - dset = dataset(root=data_dir, - train=is_training_set, - transform=trans, - download=True) + dset = dataset(root=data_dir, train=is_training_set, transform=trans, download=True) print("download complete.") - return DataLoader( - dset, - batch_size=batch_size, - shuffle=shuffle - ) + return DataLoader(dset, batch_size=batch_size, shuffle=shuffle) def print_and_log(logger, msg): @@ -51,10 +46,9 @@ def print_and_log(logger, msg): def get_data_directory(filepath=None): - if 'CI' in os.environ: - return os.path.expanduser('~/.data') - return os.path.abspath(os.path.join(os.path.dirname(filepath), - '.data')) + if "CI" in os.environ: + return os.path.expanduser("~/.data") + return os.path.abspath(os.path.join(os.path.dirname(filepath), ".data")) def _mkdir_p(dirname): diff --git a/pyro/contrib/forecast/evaluate.py b/pyro/contrib/forecast/evaluate.py index 9a9d68d639..cf81aa2ea5 100644 --- a/pyro/contrib/forecast/evaluate.py +++ b/pyro/contrib/forecast/evaluate.py @@ -68,19 +68,24 @@ def eval_crps(pred, truth): } -def backtest(data, covariates, model_fn, *, - forecaster_fn=Forecaster, - metrics=None, - transform=None, - train_window=None, - min_train_window=1, - test_window=None, - min_test_window=1, - stride=1, - seed=1234567890, - num_samples=100, - batch_size=None, - forecaster_options={}): +def backtest( + data, + covariates, + model_fn, + *, + forecaster_fn=Forecaster, + metrics=None, + transform=None, + train_window=None, + min_train_window=1, + test_window=None, + min_test_window=1, + stride=1, + seed=1234567890, + num_samples=100, + batch_size=None, + forecaster_options={}, +): """ Backtest a forecasting model on a moving window of (train,test) data. 
@@ -141,11 +146,15 @@ def backtest(data, covariates, model_fn, *, if callable(forecaster_options): forecaster_options_fn = forecaster_options else: + def forecaster_options_fn(*args, **kwargs): return forecaster_options + if train_window is not None and forecaster_options_fn().get("warm_start"): - raise ValueError("Cannot warm start with moving training window; " - "either set warm_start=False or train_window=None") + raise ValueError( + "Cannot warm start with moving training window; " + "either set warm_start=False or train_window=None" + ) duration = data.size(-2) if test_window is None: @@ -163,8 +172,11 @@ def forecaster_options_fn(*args, **kwargs): t0 = 0 if train_window is None else t1 - train_window t2 = duration if test_window is None else t1 + test_window assert 0 <= t0 < t1 < t2 <= duration - logger.info("Training on window [{t0}:{t1}], testing on window [{t1}:{t2}]" - .format(t0=t0, t1=t1, t2=t2)) + logger.info( + "Training on window [{t0}:{t1}], testing on window [{t1}:{t2}]".format( + t0=t0, t1=t1, t2=t2 + ) + ) # Train a forecaster on the training window. pyro.set_rng_seed(seed) @@ -175,8 +187,9 @@ def forecaster_options_fn(*args, **kwargs): train_covariates = covariates[..., t0:t1, :] start_time = default_timer() model = model_fn() - forecaster = forecaster_fn(model, train_data, train_covariates, - **forecaster_options) + forecaster = forecaster_fn( + model, train_data, train_covariates, **forecaster_options + ) train_walltime = default_timer() - start_time # Forecast forward to testing window. @@ -185,14 +198,20 @@ def forecaster_options_fn(*args, **kwargs): # Gradually reduce batch_size to avoid OOM errors. while True: try: - pred = forecaster(train_data, test_covariates, num_samples=num_samples, - batch_size=batch_size) + pred = forecaster( + train_data, + test_covariates, + num_samples=num_samples, + batch_size=batch_size, + ) break except RuntimeError as e: if "out of memory" in str(e) and batch_size > 1: batch_size = (batch_size + 1) // 2 - warnings.warn("out of memory, decreasing batch_size to {}" - .format(batch_size), RuntimeWarning) + warnings.warn( + "out of memory, decreasing batch_size to {}".format(batch_size), + RuntimeWarning, + ) else: raise test_walltime = default_timer() - start_time diff --git a/pyro/contrib/forecast/forecaster.py b/pyro/contrib/forecast/forecaster.py index 96ec482c3c..505660bb08 100644 --- a/pyro/contrib/forecast/forecaster.py +++ b/pyro/contrib/forecast/forecaster.py @@ -40,6 +40,7 @@ class ForecastingModel(PyroModule, metaclass=_ForecastingModelMeta): Derived classes must implement the :meth:`model` method. 
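Since the docstring above only says that derived classes must implement :meth:`model`, here is a minimal sketch of such a subclass in the spirit of the Pyro forecasting tutorials; the name Model1, the priors, and the shapes are illustrative assumptions.

    import pyro
    import pyro.distributions as dist

    from pyro.contrib.forecast import ForecastingModel

    class Model1(ForecastingModel):
        """Assumed toy model: a global level plus homoskedastic observation noise."""

        def model(self, zero_data, covariates):
            data_dim = zero_data.size(-1)

            # One global latent level per observed dimension.
            level = pyro.sample(
                "level", dist.Normal(0.0, 10.0).expand([data_dim]).to_event(1)
            )
            prediction = zero_data + level  # broadcast the level over the time axis

            # A scalar noise_dist (event_dim == 0) is expanded to
            # event_shape == (duration, obs_dim) inside self.predict(),
            # as the predict() hunk below shows.
            noise_scale = pyro.sample("noise_scale", dist.LogNormal(-2.0, 1.0))
            noise_dist = dist.Normal(0.0, noise_scale)
            self.predict(noise_dist, prediction)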
""" + def __init__(self): super().__init__() self._prefix_condition_data = {} @@ -108,12 +109,14 @@ def predict(self, noise_dist, prediction): if noise_dist.event_dim == 0: if noise_dist.batch_shape[-2:] != prediction.shape[-2:]: noise_dist = noise_dist.expand( - noise_dist.batch_shape[:-2] + prediction.shape[-2:]) + noise_dist.batch_shape[:-2] + prediction.shape[-2:] + ) noise_dist = noise_dist.to_event(2) elif noise_dist.event_dim == 1: if noise_dist.batch_shape[-1:] != prediction.shape[-2:-1]: noise_dist = noise_dist.expand( - noise_dist.batch_shape[:-1] + prediction.shape[-2:-1]) + noise_dist.batch_shape[:-1] + prediction.shape[-2:-1] + ) noise_dist = noise_dist.to_event(1) assert noise_dist.event_dim == 2 assert noise_dist.event_shape == prediction.shape[-2:] @@ -154,7 +157,7 @@ def predict(self, noise_dist, prediction): noise = pyro.sample("residual", noise_dist) del self._prefix_condition_data["residual"] - assert noise.shape[-data.dim():] == right_pred.shape[-data.dim():] + assert noise.shape[-data.dim() :] == right_pred.shape[-data.dim() :] self._forecast = right_pred + noise # Move the "time" batch dim back to its original place. @@ -175,7 +178,8 @@ def forward(self, data, covariates): zero_data = data.new_zeros(()).expand(data.shape) else: # forecasting zero_data = data.new_zeros(()).expand( - data.shape[:-2] + covariates.shape[-2:-1] + data.shape[-1:]) + data.shape[:-2] + covariates.shape[-2:-1] + data.shape[-1:] + ) self._forecast = None self.model(zero_data, covariates) @@ -252,24 +256,31 @@ class Forecaster(nn.Module): when publishing metrics. :param int log_every: Number of training steps between logging messages. """ - def __init__(self, model, data, covariates, *, - guide=None, - init_loc_fn=init_to_sample, - init_scale=0.1, - create_plates=None, - optim=None, - learning_rate=0.01, - betas=(0.9, 0.99), - learning_rate_decay=0.1, - clip_norm=10.0, - time_reparam=None, - dct_gradients=False, - subsample_aware=False, - num_steps=1001, - num_particles=1, - vectorize_particles=True, - warm_start=False, - log_every=100): + + def __init__( + self, + model, + data, + covariates, + *, + guide=None, + init_loc_fn=init_to_sample, + init_scale=0.1, + create_plates=None, + optim=None, + learning_rate=0.01, + betas=(0.9, 0.99), + learning_rate_decay=0.1, + clip_norm=10.0, + time_reparam=None, + dct_gradients=False, + subsample_aware=False, + num_steps=1001, + num_particles=1, + vectorize_particles=True, + warm_start=False, + log_every=100 + ): assert data.size(-2) == covariates.size(-2) super().__init__() self.model = model @@ -280,8 +291,12 @@ def __init__(self, model, data, covariates, *, elif time_reparam is not None: raise ValueError("unknown time_reparam: {}".format(time_reparam)) if guide is None: - guide = AutoNormal(model, init_loc_fn=init_loc_fn, init_scale=init_scale, - create_plates=create_plates) + guide = AutoNormal( + model, + init_loc_fn=init_loc_fn, + init_scale=init_scale, + create_plates=create_plates, + ) self.guide = guide # Initialize. 
@@ -291,18 +306,24 @@ def __init__(self, model, data, covariates, *, if dct_gradients: model = MarkDCTParamMessenger("time")(model) guide = MarkDCTParamMessenger("time")(guide) - elbo = Trace_ELBO(num_particles=num_particles, - vectorize_particles=vectorize_particles) + elbo = Trace_ELBO( + num_particles=num_particles, vectorize_particles=vectorize_particles + ) elbo._guess_max_plate_nesting(model, guide, (data, covariates), {}) elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1) # force a time plate losses = [] if num_steps: if optim is None: - optim = DCTAdam({"lr": learning_rate, "betas": betas, - "lrd": learning_rate_decay ** (1 / num_steps), - "clip_norm": clip_norm, - "subsample_aware": subsample_aware}) + optim = DCTAdam( + { + "lr": learning_rate, + "betas": betas, + "lrd": learning_rate_decay ** (1 / num_steps), + "clip_norm": clip_norm, + "subsample_aware": subsample_aware, + } + ) svi = SVI(model, guide, optim, elbo) for step in range(num_steps): loss = svi.step(data, covariates) / data.numel() @@ -361,7 +382,8 @@ def forward(self, data, covariates, num_samples, batch_size=None): if data.size(-2) < covariates.size(-2): stack.enter_context(PrefixReplayMessenger(tr.trace)) stack.enter_context( - PrefixConditionMessenger(self.model._prefix_condition_data)) + PrefixConditionMessenger(self.model._prefix_condition_data) + ) else: stack.enter_context(poutine.replay(trace=tr.trace)) with pyro.plate("particles", num_samples, dim=dim): @@ -399,9 +421,21 @@ class HMCForecaster(nn.Module): doubling scheme of the :class:`~pyro.infer.mcmc.nuts.NUTS` sampler. Defaults to 10. """ - def __init__(self, model, data, covariates=None, *, - num_warmup=1000, num_samples=1000, num_chains=1, time_reparam=None, - dense_mass=False, jit_compile=False, max_tree_depth=10): + + def __init__( + self, + model, + data, + covariates=None, + *, + num_warmup=1000, + num_samples=1000, + num_chains=1, + time_reparam=None, + dense_mass=False, + jit_compile=False, + max_tree_depth=10 + ): assert data.size(-2) == covariates.size(-2) super().__init__() if time_reparam == "haar": @@ -414,12 +448,25 @@ def __init__(self, model, data, covariates=None, *, max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {}) self.max_plate_nesting = max(max_plate_nesting, 1) # force a time plate - kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile, ignore_jit_warnings=True, - max_tree_depth=max_tree_depth, max_plate_nesting=max_plate_nesting) - mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples, num_chains=num_chains) + kernel = NUTS( + model, + full_mass=dense_mass, + jit_compile=jit_compile, + ignore_jit_warnings=True, + max_tree_depth=max_tree_depth, + max_plate_nesting=max_plate_nesting, + ) + mcmc = MCMC( + kernel, + warmup_steps=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + ) mcmc.run(data, covariates) # conditions to compute rhat - if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2): + if (num_chains == 1 and num_samples >= 4) or ( + num_chains > 1 and num_samples >= 2 + ): mcmc.summary() # inspect the model with particles plate = 1, so that we can reshape samples to @@ -476,16 +523,22 @@ def forward(self, data, covariates, num_samples, batch_size=None): with torch.no_grad(): weights = torch.ones(self._num_samples, device=data.device) - indices = torch.multinomial(weights, num_samples, replacement=num_samples > self._num_samples) + indices = torch.multinomial( + weights, num_samples, replacement=num_samples > 
self._num_samples + ) for name, node in list(self._trace.nodes.items()): sample = self._samples[name].index_select(0, indices) - node['value'] = sample.reshape( - (num_samples,) + (1,) * (node['value'].dim() - sample.dim()) + sample.shape[1:]) + node["value"] = sample.reshape( + (num_samples,) + + (1,) * (node["value"].dim() - sample.dim()) + + sample.shape[1:] + ) with ExitStack() as stack: if data.size(-2) < covariates.size(-2): stack.enter_context(PrefixReplayMessenger(self._trace)) stack.enter_context( - PrefixConditionMessenger(self.model._prefix_condition_data)) + PrefixConditionMessenger(self.model._prefix_condition_data) + ) with pyro.plate("particles", num_samples, dim=dim): return self.model(data, covariates) diff --git a/pyro/contrib/forecast/util.py b/pyro/contrib/forecast/util.py index 1f0a4e5ed2..0736c0895b 100644 --- a/pyro/contrib/forecast/util.py +++ b/pyro/contrib/forecast/util.py @@ -48,6 +48,7 @@ class MarkDCTParamMessenger(Messenger): :param str name: The name of the plate along which to apply discrete cosine transforms on gradients. """ + def __init__(self, name): super().__init__() self.name = name @@ -72,6 +73,7 @@ class PrefixWarmStartMessenger(Messenger): defined on a short time window, re-initialize by splicing old params with new initial params defined on a longer time window. """ + def _pyro_param(self, msg): store = get_param_store() name = msg["name"] @@ -101,7 +103,7 @@ def _pyro_param(self, msg): if new.size(dim) != old.size(dim): break assert new.size(dim) > old.size(dim) - assert new.shape[dim + 1:] == old.shape[dim + 1:] + assert new.shape[dim + 1 :] == old.shape[dim + 1 :] split = old.size(dim) index = (slice(None),) * dim + (slice(split, None),) new = torch.cat([old, new[index]], dim=dim) @@ -118,6 +120,7 @@ class PrefixReplayMessenger(Messenger): :param trace: a guide trace. :type trace: ~pyro.poutine.trace_struct.Trace """ + def __init__(self, trace): super().__init__() self.trace = trace @@ -142,7 +145,7 @@ def _pyro_post_sample(self, msg): if model_value.size(dim) != guide_value.size(dim): break assert model_value.size(dim) > guide_value.size(dim) - assert model_value.shape[dim + 1:] == guide_value.shape[dim + 1:] + assert model_value.shape[dim + 1 :] == guide_value.shape[dim + 1 :] split = guide_value.size(dim) index = (slice(None),) * dim + (slice(split, None),) msg["value"] = torch.cat([guide_value, model_value[index]], dim=dim) @@ -155,6 +158,7 @@ class PrefixConditionMessenger(Messenger): :param dict data: A dict mapping site name to tensors of observations. 
""" + def __init__(self, data): super().__init__() self.data = data @@ -214,7 +218,9 @@ def prefix_condition(d, data): try: return d.prefix_condition(data) except AttributeError as e: - raise NotImplementedError("prefix_condition() does not suport {}".format(type(d))) from e + raise NotImplementedError( + "prefix_condition() does not suport {}".format(type(d)) + ) from e @prefix_condition.register(dist.MaskedDistribution) @@ -253,8 +259,7 @@ def _(d, data): def _prefix_condition_univariate(d, data): t = data.size(-2) - params = {name: getattr(d, name)[..., t:, :] - for name in UNIVARIATE_DISTS[type(d)]} + params = {name: getattr(d, name)[..., t:, :] for name in UNIVARIATE_DISTS[type(d)]} return type(d)(**params) @@ -300,7 +305,7 @@ def _(d, batch_shape): @reshape_batch.register(dist.Independent) def _(d, batch_shape): - base_shape = batch_shape + d.event_shape[:d.reinterpreted_batch_ndims] + base_shape = batch_shape + d.event_shape[: d.reinterpreted_batch_ndims] base_dist = reshape_batch(d.base_dist, base_shape) return base_dist.to_event(d.reinterpreted_batch_ndims) @@ -323,15 +328,18 @@ def _(d, batch_shape): base_dist = reshape_batch(d.base_dist, batch_shape) old_shape = d.base_dist.shape() new_shape = base_dist.shape() - transforms = [reshape_transform_batch(t, old_shape, new_shape) - for t in d.transforms] + transforms = [ + reshape_transform_batch(t, old_shape, new_shape) for t in d.transforms + ] return dist.TransformedDistribution(base_dist, transforms) def _reshape_batch_univariate(d, batch_shape): batch_shape = batch_shape + (-1,) * d.event_dim - params = {name: getattr(d, name).reshape(batch_shape) - for name in UNIVARIATE_DISTS[type(d)]} + params = { + name: getattr(d, name).reshape(batch_shape) + for name in UNIVARIATE_DISTS[type(d)] + } return type(d)(**params) @@ -370,15 +378,18 @@ def _(d, batch_shape): new._init = init new._trans = trans new._obs = obs - super(dist.GaussianHMM, new).__init__(d.duration, batch_shape, d.event_shape, - validate_args=d._validate_args) + super(dist.GaussianHMM, new).__init__( + d.duration, batch_shape, d.event_shape, validate_args=d._validate_args + ) return new @reshape_batch.register(dist.LinearHMM) def _(d, batch_shape): init_dist = reshape_batch(d.initial_dist, batch_shape) - trans_mat = d.transition_matrix.reshape(batch_shape + (-1, d.hidden_dim, d.hidden_dim)) + trans_mat = d.transition_matrix.reshape( + batch_shape + (-1, d.hidden_dim, d.hidden_dim) + ) trans_dist = reshape_batch(d.transition_dist, batch_shape + (-1,)) obs_mat = d.observation_matrix.reshape(batch_shape + (-1, d.hidden_dim, d.obs_dim)) obs_dist = reshape_batch(d.observation_dist, batch_shape + (-1,)) @@ -393,15 +404,17 @@ def _(d, batch_shape): new.observation_dist = obs_dist transforms = [] for transform in d.transforms: - assert type(transform) in UNIVARIATE_TRANSFORMS, \ - "Currently, reshape_batch only supports AbsTransform, " + \ - "ExpTransform, SigmoidTransform transform" + assert type(transform) in UNIVARIATE_TRANSFORMS, ( + "Currently, reshape_batch only supports AbsTransform, " + + "ExpTransform, SigmoidTransform transform" + ) old_shape = d.observation_dist.shape() new_shape = obs_dist.shape() transforms.append(reshape_transform_batch(transform, old_shape, new_shape)) new.transforms = transforms - super(dist.LinearHMM, new).__init__(d.duration, batch_shape, d.event_shape, - validate_args=d._validate_args) + super(dist.LinearHMM, new).__init__( + d.duration, batch_shape, d.event_shape, validate_args=d._validate_args + ) return new @@ -430,7 +443,9 @@ def 
reshape_transform_batch(t, old_shape, new_shape): :returns: A transform with the same type but given new batch shape. :rtype: ~torch.distributions.transforms.Transform """ - raise NotImplementedError("reshape_transform_batch() does not suport {}".format(type(t))) + raise NotImplementedError( + "reshape_transform_batch() does not suport {}".format(type(t)) + ) def _reshape_batch_univariate_transform(t, old_shape, new_shape): @@ -451,10 +466,9 @@ def _(t, old_shape, new_shape): @reshape_transform_batch.register(dist.transforms.ComposeTransform) def _(t, old_shape, new_shape): - return dist.transforms.ComposeTransform([ - reshape_transform_batch(part, old_shape, new_shape) - for part in t.parts - ]) + return dist.transforms.ComposeTransform( + [reshape_transform_batch(part, old_shape, new_shape) for part in t.parts] + ) for _type in UNIVARIATE_TRANSFORMS: diff --git a/pyro/contrib/funsor/__init__.py b/pyro/contrib/funsor/__init__.py index aa2d3eeb01..d8a4d3eea9 100644 --- a/pyro/contrib/funsor/__init__.py +++ b/pyro/contrib/funsor/__init__.py @@ -26,14 +26,17 @@ def plate(*args, **kwargs): return _plate(None, *args, **kwargs) -pyroapi.register_backend('contrib.funsor', { - 'distributions': 'pyro.distributions', - 'handlers': 'pyro.contrib.funsor.handlers', - 'infer': 'pyro.contrib.funsor.infer', - 'ops': 'torch', - 'optim': 'pyro.optim', - 'pyro': 'pyro.contrib.funsor', -}) +pyroapi.register_backend( + "contrib.funsor", + { + "distributions": "pyro.distributions", + "handlers": "pyro.contrib.funsor.handlers", + "infer": "pyro.contrib.funsor.infer", + "ops": "torch", + "optim": "pyro.optim", + "pyro": "pyro.contrib.funsor", + }, +) __all__ = [ "clear_param_store", diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index befbeb2014..98e9a7ee33 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -26,13 +26,18 @@ @functools.singledispatch def _get_support_value(funsor_dist, name, **kwargs): - raise ValueError("Could not extract point from {} at name {}".format(funsor_dist, name)) + raise ValueError( + "Could not extract point from {} at name {}".format(funsor_dist, name) + ) @_get_support_value.register(funsor.cnf.Contraction) def _get_support_value_contraction(funsor_dist, name, **kwargs): - delta_terms = [v for v in funsor_dist.terms - if isinstance(v, funsor.delta.Delta) and name in v.fresh] + delta_terms = [ + v + for v in funsor_dist.terms + if isinstance(v, funsor.delta.Delta) and name in v.fresh + ] assert len(delta_terms) == 1 return _get_support_value(delta_terms[0], name, **kwargs) @@ -49,7 +54,7 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs): return funsor.Tensor( funsor.ops.new_arange(funsor_dist.data, funsor_dist.inputs[name].size), OrderedDict([(name, funsor_dist.inputs[name])]), - funsor_dist.inputs[name].size + funsor_dist.inputs[name].size, ) @@ -60,22 +65,31 @@ def _get_support_value_distribution(funsor_dist, name, expand=False): def _enum_strategy_default(dist, msg): - sample_inputs = OrderedDict((f.name, funsor.Bint[f.size]) for f in msg["cond_indep_stack"] - if f.vectorized and f.name not in dist.inputs) + sample_inputs = OrderedDict( + (f.name, funsor.Bint[f.size]) + for f in msg["cond_indep_stack"] + if f.vectorized and f.name not in dist.inputs + ) sampled_dist = dist.sample(msg["name"], sample_inputs) return sampled_dist def _enum_strategy_diagonal(dist, msg): sample_dim_name = "{}__PARTICLES".format(msg["name"]) - sample_inputs = 
OrderedDict({sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]}) + sample_inputs = OrderedDict( + {sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]} + ) plate_names = frozenset(f.name for f in msg["cond_indep_stack"] if f.vectorized) - ancestor_names = frozenset(k for k, v in dist.inputs.items() if v.dtype != 'real' - and k != msg["name"] and k not in plate_names) + ancestor_names = frozenset( + k + for k, v in dist.inputs.items() + if v.dtype != "real" and k != msg["name"] and k not in plate_names + ) # TODO should the ancestor_indices be pyro.observed? ancestor_indices = {name: sample_dim_name for name in ancestor_names} sampled_dist = dist(**ancestor_indices).sample( - msg["name"], sample_inputs if not ancestor_indices else None) + msg["name"], sample_inputs if not ancestor_indices else None + ) if ancestor_indices: # XXX is there a better way to account for this in funsor? sampled_dist = sampled_dist - math.log(msg["infer"]["num_samples"]) return sampled_dist @@ -83,26 +97,38 @@ def _enum_strategy_diagonal(dist, msg): def _enum_strategy_mixture(dist, msg): sample_dim_name = "{}__PARTICLES".format(msg["name"]) - sample_inputs = OrderedDict({sample_dim_name: funsor.Bint[msg['infer']['num_samples']]}) + sample_inputs = OrderedDict( + {sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]} + ) plate_names = frozenset(f.name for f in msg["cond_indep_stack"] if f.vectorized) - ancestor_names = frozenset(k for k, v in dist.inputs.items() if v.dtype != 'real' - and k != msg["name"] and k not in plate_names) + ancestor_names = frozenset( + k + for k, v in dist.inputs.items() + if v.dtype != "real" and k != msg["name"] and k not in plate_names + ) plate_inputs = OrderedDict((k, dist.inputs[k]) for k in plate_names) # TODO should the ancestor_indices be pyro.sampled? ancestor_indices = { # TODO make this comprehension less gross - name: _get_support_value(funsor.torch.distributions.CategoricalLogits( - # sample different ancestors for each plate slice - logits=funsor.Tensor( - # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros - torch.zeros((1,)).expand(tuple(v.dtype for v in plate_inputs.values()) + (dist.inputs[name].dtype,)), - plate_inputs - ), - )(value=name).sample(name, sample_inputs), name) + name: _get_support_value( + funsor.torch.distributions.CategoricalLogits( + # sample different ancestors for each plate slice + logits=funsor.Tensor( + # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros + torch.zeros((1,)).expand( + tuple(v.dtype for v in plate_inputs.values()) + + (dist.inputs[name].dtype,) + ), + plate_inputs, + ), + )(value=name).sample(name, sample_inputs), + name, + ) for name in ancestor_names } sampled_dist = dist(**ancestor_indices).sample( - msg["name"], sample_inputs if not ancestor_indices else None) + msg["name"], sample_inputs if not ancestor_indices else None + ) if ancestor_indices: # XXX is there a better way to account for this in funsor? 
sampled_dist = sampled_dist - math.log(msg["infer"]["num_samples"]) return sampled_dist @@ -110,7 +136,9 @@ def _enum_strategy_mixture(dist, msg): def _enum_strategy_full(dist, msg): sample_dim_name = "{}__PARTICLES".format(msg["name"]) - sample_inputs = OrderedDict({sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]}) + sample_inputs = OrderedDict( + {sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]} + ) sampled_dist = dist.sample(msg["name"], sample_inputs) return sampled_dist @@ -127,10 +155,14 @@ def enumerate_site(dist, msg): return _enum_strategy_default(dist, msg) elif msg["infer"].get("num_samples", None) is None: return _enum_strategy_exact(dist, msg) - elif msg["infer"]["num_samples"] > 1 and \ - (msg["infer"].get("expand", False) or msg["infer"].get("tmc") == "full"): + elif msg["infer"]["num_samples"] > 1 and ( + msg["infer"].get("expand", False) or msg["infer"].get("tmc") == "full" + ): return _enum_strategy_full(dist, msg) - elif msg["infer"]["num_samples"] > 1 and msg["infer"].get("tmc", "diagonal") == "diagonal": + elif ( + msg["infer"]["num_samples"] > 1 + and msg["infer"].get("tmc", "diagonal") == "diagonal" + ): return _enum_strategy_diagonal(dist, msg) elif msg["infer"]["num_samples"] > 1 and msg["infer"]["tmc"] == "mixture": return _enum_strategy_mixture(dist, msg) @@ -142,27 +174,40 @@ class EnumMessenger(NamedMessenger): This version of :class:`~EnumMessenger` uses :func:`~pyro.contrib.funsor.to_data` to allocate a fresh enumeration dim for each discrete sample site. """ + def _pyro_sample(self, msg): - if msg["done"] or msg["is_observed"] or \ - msg["infer"].get("enumerate") not in {"flat", "parallel"} or \ - isinstance(msg["fn"], _Subsample): + if ( + msg["done"] + or msg["is_observed"] + or msg["infer"].get("enumerate") not in {"flat", "parallel"} + or isinstance(msg["fn"], _Subsample) + ): return if "funsor" not in msg: msg["funsor"] = {} - unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(value=msg["name"]) + unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)( + value=msg["name"] + ) msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg) msg["funsor"]["value"] = _get_support_value( - msg["funsor"]["log_measure"], msg["name"], expand=msg["infer"].get("expand", False)) + msg["funsor"]["log_measure"], + msg["name"], + expand=msg["infer"].get("expand", False), + ) msg["value"] = to_data(msg["funsor"]["value"]) msg["done"] = True -def queue(fn=None, queue=None, - max_tries=int(1e6), num_samples=-1, - extend_fn=pyro.poutine.util.enum_extend, - escape_fn=pyro.poutine.util.discrete_escape): +def queue( + fn=None, + queue=None, + max_tries=int(1e6), + num_samples=-1, + extend_fn=pyro.poutine.util.enum_extend, + escape_fn=pyro.poutine.util.discrete_escape, +): """ Used in sequential enumeration over discrete variables (copied from poutine.queue). 
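For context, a sketch of the sequential-enumeration loop this helper is meant to drive, mirroring the usual pyro.poutine.queue pattern; the two-Bernoulli toy model is an illustrative assumption.

    from queue import Queue

    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine

    from pyro.contrib.funsor.handlers.enum_messenger import queue as funsor_queue

    def model():
        # Assumed toy model with two dependent Bernoulli latents.
        x = pyro.sample("x", dist.Bernoulli(0.5))
        pyro.sample("y", dist.Bernoulli(0.25 + 0.5 * x))

    q = Queue()
    q.put(poutine.Trace())  # seed the queue with an empty trace
    enum_model = funsor_queue(model, queue=q)

    traces = []
    while not q.empty():  # one complete trace per enumerated assignment
        traces.append(poutine.trace(enum_model).get_trace())
    assert len(traces) == 4  # x in {0, 1} times y in {0, 1}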
@@ -184,22 +229,27 @@ def wrapper(wrapped): def _fn(*args, **kwargs): for i in range(max_tries): - assert not queue.empty(), \ - "trying to get() from an empty queue will deadlock" + assert ( + not queue.empty() + ), "trying to get() from an empty queue will deadlock" next_trace = queue.get() try: ftr = TraceMessenger()( - EscapeMessenger(escape_fn=functools.partial(escape_fn, next_trace))( - ReplayMessenger(trace=next_trace)(wrapped))) + EscapeMessenger( + escape_fn=functools.partial(escape_fn, next_trace) + )(ReplayMessenger(trace=next_trace)(wrapped)) + ) return ftr(*args, **kwargs) except pyro.poutine.runtime.NonlocalExit as site_container: site_container.reset_stack() # TODO implement missing ._reset()s - for tr in extend_fn(ftr.trace.copy(), site_container.site, - num_samples=num_samples): + for tr in extend_fn( + ftr.trace.copy(), site_container.site, num_samples=num_samples + ): queue.put(tr) raise ValueError("max tries ({}) exceeded".format(str(max_tries))) + return _fn return wrapper(fn) if fn is not None else wrapper diff --git a/pyro/contrib/funsor/handlers/named_messenger.py b/pyro/contrib/funsor/handlers/named_messenger.py index 3412440645..ff65a18223 100644 --- a/pyro/contrib/funsor/handlers/named_messenger.py +++ b/pyro/contrib/funsor/handlers/named_messenger.py @@ -23,8 +23,11 @@ class NamedMessenger(ReentrantMessenger): This design ensures that the global name-dim mapping is reset upon handler exit rather than potentially persisting until the entire program terminates. """ + def __init__(self, first_available_dim=None): - assert first_available_dim is None or first_available_dim < 0, first_available_dim + assert ( + first_available_dim is None or first_available_dim < 0 + ), first_available_dim self.first_available_dim = first_available_dim self._saved_dims = set() return super().__init__() @@ -32,7 +35,9 @@ def __init__(self, first_available_dim=None): def __enter__(self): if self._ref_count == 0: if self.first_available_dim is not None: - self._prev_first_dim = _DIM_STACK.set_first_available_dim(self.first_available_dim) + self._prev_first_dim = _DIM_STACK.set_first_available_dim( + self.first_available_dim + ) if _DIM_STACK.outermost is None: _DIM_STACK.outermost = self for name, dim in self._saved_dims: @@ -55,7 +60,7 @@ def __exit__(self, *args, **kwargs): @staticmethod # only depends on the global _DIM_STACK state, not self def _pyro_to_data(msg): - funsor_value, = msg["args"] + (funsor_value,) = msg["args"] name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict()) dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL) @@ -65,7 +70,9 @@ def _pyro_to_data(msg): name_to_dim_request = name_to_dim.copy() for name in batch_names: dim = name_to_dim.get(name, None) - name_to_dim_request[name] = dim if isinstance(dim, DimRequest) else DimRequest(dim, dim_type) + name_to_dim_request[name] = ( + dim if isinstance(dim, DimRequest) else DimRequest(dim, dim_type) + ) # request and update name_to_dim in-place # name_to_dim.update(_DIM_STACK.allocate_name_to_dim(name_to_dim_request)) @@ -89,15 +96,19 @@ def _pyro_to_funsor(msg): batch_shape = raw_value.batch_shape # TODO make make this more robust except AttributeError: full_shape = getattr(raw_value, "shape", ()) - batch_shape = full_shape[:len(full_shape) - event_dim] + batch_shape = full_shape[: len(full_shape) - event_dim] - batch_dims = tuple(dim for dim in range(-len(batch_shape), 0) if batch_shape[dim] > 1) + batch_dims = tuple( + dim for dim in range(-len(batch_shape), 0) if batch_shape[dim] > 1 + ) # 
interpret all names/dims as requests since we only run this function once dim_to_name_request = dim_to_name.copy() for dim in batch_dims: name = dim_to_name.get(dim, None) - dim_to_name_request[dim] = name if isinstance(name, DimRequest) else DimRequest(name, dim_type) + dim_to_name_request[dim] = ( + name if isinstance(name, DimRequest) else DimRequest(name, dim_type) + ) # request and update dim_to_name in-place dim_to_name.update(_DIM_STACK.allocate(dim_to_name_request)) @@ -117,6 +128,7 @@ class MarkovMessenger(NamedMessenger): level can depend on each other; if ``keep=False``, neighboring branches are independent (conditioned on their shared ancestors). """ + def __init__(self, history=1, keep=False): self.history = history self.keep = keep @@ -144,8 +156,10 @@ def __enter__(self): frame = self._saved_frames.pop() else: frame = StackFrame( - name_to_dim=OrderedDict(), dim_to_name=OrderedDict(), - history=self.history, keep=self.keep, + name_to_dim=OrderedDict(), + dim_to_name=OrderedDict(), + history=self.history, + keep=self.keep, ) _DIM_STACK.push_local(frame) @@ -169,13 +183,17 @@ class GlobalNamedMessenger(NamedMessenger): global dimensions will be considered active until the innermost :class:`~GlobalNamedMessenger` under which they were initially allocated exits. """ + def __init__(self, first_available_dim=None): self._saved_frames = [] super().__init__(first_available_dim=first_available_dim) def __enter__(self): - frame = self._saved_frames.pop() if self._saved_frames else StackFrame( - name_to_dim=OrderedDict(), dim_to_name=OrderedDict()) + frame = ( + self._saved_frames.pop() + if self._saved_frames + else StackFrame(name_to_dim=OrderedDict(), dim_to_name=OrderedDict()) + ) _DIM_STACK.push_global(frame) return super().__enter__() diff --git a/pyro/contrib/funsor/handlers/plate_messenger.py b/pyro/contrib/funsor/handlers/plate_messenger.py index 702993a27d..2c76e85e9f 100644 --- a/pyro/contrib/funsor/handlers/plate_messenger.py +++ b/pyro/contrib/funsor/handlers/plate_messenger.py @@ -31,18 +31,23 @@ class IndepMessenger(GlobalNamedMessenger): Vectorized plate implementation using :func:`~pyro.contrib.funsor.to_data` instead of :class:`~pyro.poutine.runtime._DimAllocator`. 
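For orientation, a sketch of how these funsor-backed handlers are reached through the "contrib.funsor" pyroapi backend registered earlier in this patch; the toy model, guide, and data are illustrative assumptions.

    import torch

    import pyro.contrib.funsor  # noqa: F401  (registers the "contrib.funsor" backend)
    from pyroapi import distributions as dist
    from pyroapi import infer, pyro, pyro_backend

    data = torch.randn(10)

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        # Under the contrib.funsor backend this plate is the PlateMessenger above.
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        q_loc = pyro.param("q_loc", torch.tensor(0.0))
        pyro.sample("loc", dist.Normal(q_loc, 1.0))

    with pyro_backend("contrib.funsor"):
        loss = infer.Trace_ELBO().differentiable_loss(model, guide, data)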
""" + def __init__(self, name=None, size=None, dim=None, indices=None): assert dim is None or dim < 0 super().__init__() # without a name or dim, treat as a "vectorize" effect and allocate a non-visible dim - self.dim_type = DimType.GLOBAL if name is None and dim is None else DimType.VISIBLE + self.dim_type = ( + DimType.GLOBAL if name is None and dim is None else DimType.VISIBLE + ) self.name = name if name is not None else funsor.interpreter.gensym("PLATE") self.size = size self.dim = dim if not hasattr(self, "_full_size"): self._full_size = size if indices is None: - indices = funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size) + indices = funsor.ops.new_arange( + funsor.tensor.get_default_prototype(), self.size + ) assert len(indices) == size self._indices = funsor.Tensor( @@ -68,11 +73,19 @@ def _pyro_param(self, msg): @copy_docs_from(OrigSubsampleMessenger) class SubsampleMessenger(IndepMessenger): - - def __init__(self, name=None, size=None, subsample_size=None, subsample=None, dim=None, - use_cuda=None, device=None): + def __init__( + self, + name=None, + size=None, + subsample_size=None, + subsample=None, + dim=None, + use_cuda=None, + device=None, + ): size, subsample_size, indices = OrigSubsampleMessenger._subsample( - name, size, subsample_size, subsample, use_cuda, device) + name, size, subsample_size, subsample, use_cuda, device + ) self.subsample_size = subsample_size self._full_size = size self._scale = float(size) / subsample_size @@ -88,8 +101,12 @@ def _pyro_param(self, msg): msg["scale"] = msg["scale"] * self._scale def _subsample_site_value(self, value, event_dim=None): - if self.dim is not None and event_dim is not None and self.subsample_size < self._full_size: - event_shape = value.shape[len(value.shape) - event_dim:] + if ( + self.dim is not None + and event_dim is not None + and self.subsample_size < self._full_size + ): + event_shape = value.shape[len(value.shape) - event_dim :] funsor_value = to_funsor(value, output=funsor.Reals[event_shape]) if self.name in funsor_value.inputs: return to_data(funsor_value(**{self.name: self._indices})) @@ -122,6 +139,7 @@ class PlateMessenger(SubsampleMessenger): :class:`pyro.poutine.BroadcastMessenger`. Should eventually be a drop-in replacement for :class:`pyro.plate`. """ + def __enter__(self): super().__enter__() return self.indices # match pyro.plate behavior @@ -131,13 +149,18 @@ def _pyro_sample(self, msg): BroadcastMessenger._pyro_sample(msg) def __iter__(self): - return iter(_SequentialPlateMessenger(self.name, self.size, self._indices.data.squeeze(), self._scale)) + return iter( + _SequentialPlateMessenger( + self.name, self.size, self._indices.data.squeeze(), self._scale + ) + ) class _SequentialPlateMessenger(Messenger): """ Implementation of sequential plate. Should not be used directly. """ + def __init__(self, name, size, indices, scale): self.name = name self.size = size @@ -269,6 +292,7 @@ def model(data, vectorized=True): :return: Returns both :class:`int` and 1-dimensional :class:`torch.Tensor` indices: ``(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))``. 
""" + def __init__(self, name=None, size=None, dim=None, history=1): self.name = name self.size = size @@ -290,8 +314,12 @@ def _markov_chain(name=None, markov_vars=set(), suffixes=list()): :return: step information :rtype: frozenset """ - chain = frozenset({tuple("{}{}".format(var, suffix) for suffix in suffixes) - for var in markov_vars}) + chain = frozenset( + { + tuple("{}{}".format(var, suffix) for suffix in suffixes) + for var in markov_vars + } + ) return chain def __iter__(self): @@ -302,13 +330,20 @@ def __iter__(self): self._suffixes.append(self._suffix) yield self._suffix with self: - with IndepMessenger(name=self.name, size=self.size-self.history, dim=self.dim) as time: - time_indices = [time.indices+i for i in range(self.history+1)] - time_slices = [slice(i, self.size-self.history+i) for i in range(self.history+1)] + with IndepMessenger( + name=self.name, size=self.size - self.history, dim=self.dim + ) as time: + time_indices = [time.indices + i for i in range(self.history + 1)] + time_slices = [ + slice(i, self.size - self.history + i) + for i in range(self.history + 1) + ] self._suffixes.extend(time_slices) for self._suffix, self._indices in zip(time_slices, time_indices): yield self._indices - self._markov_chain(name=self.name, markov_vars=self._markov_vars, suffixes=self._suffixes) + self._markov_chain( + name=self.name, markov_vars=self._markov_vars, suffixes=self._suffixes + ) def _pyro_sample(self, msg): if type(msg["fn"]).__name__ == "_Subsample": @@ -317,7 +352,7 @@ def _pyro_sample(self, msg): # replace tensor suffix with a nice slice suffix if isinstance(self._suffix, slice): assert msg["name"].endswith(str(self._indices)) - msg["name"] = msg["name"][:-len(str(self._indices))] + str(self._suffix) + msg["name"] = msg["name"][: -len(str(self._indices))] + str(self._suffix) if str(self._suffix) != str(self._suffixes[-1]): # _do_not_score: record these sites when tracing for use with replay, # but do not include them in ELBO computation. @@ -325,7 +360,7 @@ def _pyro_sample(self, msg): # map auxiliary var to markov var name prefix # assuming that site name has a format: "markov_var{}".format(_suffix) # is there a better way? 
- markov_var = msg["name"][:-len(str(self._suffix))] + markov_var = msg["name"][: -len(str(self._suffix))] self._auxiliary_to_markov[msg["name"]] = markov_var def _pyro_post_sample(self, msg): @@ -336,8 +371,11 @@ def _pyro_post_sample(self, msg): return # if last step in the for loop if str(self._suffix) == str(self._suffixes[-1]): - funsor_log_prob = msg["funsor"]["log_prob"] if "log_prob" in msg.get("funsor", {}) else \ - to_funsor(msg["fn"].log_prob(msg["value"]), output=funsor.Real) + funsor_log_prob = ( + msg["funsor"]["log_prob"] + if "log_prob" in msg.get("funsor", {}) + else to_funsor(msg["fn"].log_prob(msg["value"]), output=funsor.Real) + ) # for auxiliary sites in the log_prob for name in set(funsor_log_prob.inputs) & set(self._auxiliary_to_markov): # add markov var name prefix to self._markov_vars diff --git a/pyro/contrib/funsor/handlers/primitives.py b/pyro/contrib/funsor/handlers/primitives.py index 0b7a4c4edb..706f8e187f 100644 --- a/pyro/contrib/funsor/handlers/primitives.py +++ b/pyro/contrib/funsor/handlers/primitives.py @@ -8,16 +8,22 @@ @pyro.poutine.runtime.effectful(type="to_funsor") def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL): import funsor + if pyro.poutine.runtime.am_i_wrapped() and not dim_to_name: dim_to_name = _DIM_STACK.global_frame.dim_to_name.copy() - assert not dim_to_name or not any(isinstance(name, DimRequest) for name in dim_to_name.values()) + assert not dim_to_name or not any( + isinstance(name, DimRequest) for name in dim_to_name.values() + ) return funsor.to_funsor(x, output=output, dim_to_name=dim_to_name) @pyro.poutine.runtime.effectful(type="to_data") def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL): import funsor + if pyro.poutine.runtime.am_i_wrapped() and not name_to_dim: name_to_dim = _DIM_STACK.global_frame.name_to_dim.copy() - assert not name_to_dim or not any(isinstance(dim, DimRequest) for dim in name_to_dim.values()) + assert not name_to_dim or not any( + isinstance(dim, DimRequest) for dim in name_to_dim.values() + ) return funsor.to_data(x, name_to_dim=name_to_dim) diff --git a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py index f561d152e9..2f6a76033a 100644 --- a/pyro/contrib/funsor/handlers/replay_messenger.py +++ b/pyro/contrib/funsor/handlers/replay_messenger.py @@ -11,9 +11,12 @@ class ReplayMessenger(OrigReplayMessenger): except that it calls :func:`~pyro.contrib.funsor.to_data` on the replayed funsor values. This may result in different unpacked shapes, but should produce correct allocations. 
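A small sketch of the to_funsor / to_data round trip defined in the primitives hunk above, used here outside any handler with explicit name/dim mappings; the tensor shape and the name "plate_i" are illustrative assumptions.

    from collections import OrderedDict

    import funsor
    import torch

    from pyro.contrib.funsor import to_data, to_funsor

    funsor.set_backend("torch")

    x = torch.randn(3, 2)
    # Batch dim -1 (size 3) gets a name; the trailing size-2 dim is the event shape.
    fx = to_funsor(x, output=funsor.Reals[2], dim_to_name=OrderedDict({-1: "plate_i"}))
    assert fx.inputs["plate_i"].size == 3

    # Map the name back to a positional batch dim to recover the original tensor.
    x2 = to_data(fx, name_to_dim=OrderedDict({"plate_i": -1}))
    assert x2.shape == (3, 2) and (x2 == x).all()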
""" + def _pyro_sample(self, msg): name = msg["name"] - msg["replay_active"] = True # indicate replaying so importance weights can be scaled + msg[ + "replay_active" + ] = True # indicate replaying so importance weights can be scaled if self.trace is None: return @@ -24,7 +27,9 @@ def _pyro_sample(self, msg): raise RuntimeError("site {} must be sample in trace".format(name)) # TODO make this work with sequential enumeration if guide_msg.get("funsor", {}).get("value", None) is not None: - msg["value"] = to_data(guide_msg["funsor"]["value"]) # only difference is here + msg["value"] = to_data( + guide_msg["funsor"]["value"] + ) # only difference is here else: msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"] diff --git a/pyro/contrib/funsor/handlers/runtime.py b/pyro/contrib/funsor/handlers/runtime.py index 6b320e8942..d1d3a8081a 100644 --- a/pyro/contrib/funsor/handlers/runtime.py +++ b/pyro/contrib/funsor/handlers/runtime.py @@ -10,18 +10,27 @@ class StackFrame: Consistent bidirectional mapping between integer positional dimensions and names. Can be queried like a dictionary (``value = frame[key]``, ``frame[key] = value``). """ + def __init__(self, name_to_dim, dim_to_name, history=1, keep=False): - assert isinstance(name_to_dim, OrderedDict) and \ - all(isinstance(name, str) and isinstance(dim, int) for name, dim in name_to_dim.items()) - assert isinstance(dim_to_name, OrderedDict) and \ - all(isinstance(name, str) and isinstance(dim, int) for dim, name in dim_to_name.items()) + assert isinstance(name_to_dim, OrderedDict) and all( + isinstance(name, str) and isinstance(dim, int) + for name, dim in name_to_dim.items() + ) + assert isinstance(dim_to_name, OrderedDict) and all( + isinstance(name, str) and isinstance(dim, int) + for dim, name in dim_to_name.items() + ) self.name_to_dim = name_to_dim self.dim_to_name = dim_to_name self.history = history self.keep = keep def __setitem__(self, key, value): - assert isinstance(key, (int, str)) and isinstance(value, (int, str)) and type(key) != type(value) + assert ( + isinstance(key, (int, str)) + and isinstance(value, (int, str)) + and type(key) != type(value) + ) name, dim = (value, key) if isinstance(key, int) else (key, value) self.name_to_dim[name], self.dim_to_name[dim] = dim, name @@ -31,8 +40,11 @@ def __getitem__(self, key): def __delitem__(self, key): assert isinstance(key, (int, str)) - k2v, v2k = (self.dim_to_name, self.name_to_dim) if isinstance(key, int) else \ - (self.name_to_dim, self.dim_to_name) + k2v, v2k = ( + (self.dim_to_name, self.name_to_dim) + if isinstance(key, int) + else (self.name_to_dim, self.dim_to_name) + ) del v2k[k2v[key]] del k2v[key] @@ -43,12 +55,13 @@ def __contains__(self, key): class DimType(Enum): """Enumerates the possible types of dimensions to allocate""" + LOCAL = 0 GLOBAL = 1 VISIBLE = 2 -DimRequest = namedtuple('DimRequest', ['value', 'dim_type']) +DimRequest = namedtuple("DimRequest", ["value", "dim_type"]) DimRequest.__new__.__defaults__ = (None, DimType.LOCAL) @@ -60,10 +73,13 @@ class DimStack: the enum :class:`~pyro.poutine.runtime._EnumAllocator`, the ``stack`` in :class:`~MarkovMessenger`, ``_param_dims`` and ``_value_dims`` in :class:`~EnumMessenger`, and ``dim_to_symbol`` in ``msg['infer']`` """ + def __init__(self): global_frame = StackFrame( - name_to_dim=OrderedDict(), dim_to_name=OrderedDict(), - history=0, keep=False, + name_to_dim=OrderedDict(), + dim_to_name=OrderedDict(), + history=0, + keep=False, ) self._local_stack = [global_frame] self._iter_stack = 
[global_frame] @@ -110,8 +126,11 @@ def local_frame(self): @property def current_write_env(self): - return self._local_stack[-1:] if not self.local_frame.keep else \ - self._local_stack[-self.local_frame.history-1:] + return ( + self._local_stack[-1:] + if not self.local_frame.keep + else self._local_stack[-self.local_frame.history - 1 :] + ) @property def current_read_env(self): @@ -119,7 +138,11 @@ def current_read_env(self): Collect all frames necessary to compute the full name <--> dim mapping and interpret Funsor inputs or batch shapes at any point in a computation. """ - return self._global_stack + self._local_stack[-self.local_frame.history-1:] + self._iter_stack + return ( + self._global_stack + + self._local_stack[-self.local_frame.history - 1 :] + + self._iter_stack + ) def _genvalue(self, key, value_request): """ @@ -145,12 +168,17 @@ def _genvalue(self, key, value_request): while any(fresh_dim in p for p in self.current_read_env): fresh_dim -= 1 - if fresh_dim < self.MAX_DIM or \ - (dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim): - raise ValueError("Ran out of free dims during allocation for {}".format(name)) + if fresh_dim < self.MAX_DIM or ( + dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim + ): + raise ValueError( + "Ran out of free dims during allocation for {}".format(name) + ) return name, fresh_dim - raise ValueError("{} and {} not a valid name-dim pair".format(key, value_request)) + raise ValueError( + "{} and {} not a valid name-dim pair".format(key, value_request) + ) def allocate(self, key_to_value_request): @@ -177,11 +205,19 @@ def allocate(self, key_to_value_request): key, fresh_value = self._genvalue(key, value_request) # if this key is already active but inconsistent with the fresh value, # generate a fresh_key for future conversions via _genvalue in reverse - if value_request.dim_type != DimType.VISIBLE or any(key in frame for frame in self.current_read_env): - _, fresh_key = self._genvalue(fresh_value, DimRequest(key, value_request.dim_type)) + if value_request.dim_type != DimType.VISIBLE or any( + key in frame for frame in self.current_read_env + ): + _, fresh_key = self._genvalue( + fresh_value, DimRequest(key, value_request.dim_type) + ) else: fresh_key = key - for frame in ([self.global_frame] if value_request.dim_type != DimType.LOCAL else self.current_write_env): + for frame in ( + [self.global_frame] + if value_request.dim_type != DimType.LOCAL + else self.current_write_env + ): frame[fresh_key] = fresh_value # use the user-provided key rather than fresh_key for satisfying this request only key_to_value[key] = fresh_value @@ -190,10 +226,13 @@ def allocate(self, key_to_value_request): return key_to_value def names_from_batch_shape(self, batch_shape, dim_type=DimType.LOCAL): - return self.allocate_dim_to_name(OrderedDict( - (dim, DimRequest(None, dim_type)) - for dim in range(-len(batch_shape), 0) if batch_shape[dim] > 1 - )) + return self.allocate_dim_to_name( + OrderedDict( + (dim, DimRequest(None, dim_type)) + for dim in range(-len(batch_shape), 0) + if batch_shape[dim] > 1 + ) + ) _DIM_STACK = DimStack() # only one global instance diff --git a/pyro/contrib/funsor/handlers/trace_messenger.py b/pyro/contrib/funsor/handlers/trace_messenger.py index abadd1b8db..f517967079 100644 --- a/pyro/contrib/funsor/handlers/trace_messenger.py +++ b/pyro/contrib/funsor/handlers/trace_messenger.py @@ -26,6 +26,7 @@ class TraceMessenger(OrigTraceMessenger): Each sample site is annotated with a ``dim_to_name`` dictionary, 
which can be passed directly to :func:`~pyro.contrib.funsor.to_funsor`. """ + def __init__(self, graph_type=None, param_only=None, pack_online=True): super().__init__(graph_type=graph_type, param_only=param_only) self.pack_online = True if pack_online is None else pack_online @@ -40,16 +41,24 @@ def _pyro_post_sample(self, msg): if self.pack_online: if "fn" not in msg["funsor"]: fn_masked = _mask_fn(msg["fn"], msg["mask"]) - msg["funsor"]["fn"] = to_funsor(fn_masked, funsor.Real)(value=msg["name"]) + msg["funsor"]["fn"] = to_funsor(fn_masked, funsor.Real)( + value=msg["name"] + ) if "value" not in msg["funsor"]: # value_output = funsor.Reals[getattr(msg["fn"], "event_shape", ())] - msg["funsor"]["value"] = to_funsor(msg["value"], msg["funsor"]["fn"].inputs[msg["name"]]) - if "log_prob" not in msg["funsor"] and \ - not msg["infer"].get("_do_not_trace") and \ - not msg["infer"].get("_do_not_score", False): + msg["funsor"]["value"] = to_funsor( + msg["value"], msg["funsor"]["fn"].inputs[msg["name"]] + ) + if ( + "log_prob" not in msg["funsor"] + and not msg["infer"].get("_do_not_trace") + and not msg["infer"].get("_do_not_score", False) + ): # optimization: don't perform this tensor op unless we have to fn_masked = _mask_fn(msg["fn"], msg["mask"]) - msg["funsor"]["log_prob"] = to_funsor(fn_masked.log_prob(msg["value"]), output=funsor.Real) + msg["funsor"]["log_prob"] = to_funsor( + fn_masked.log_prob(msg["value"]), output=funsor.Real + ) # TODO support this pattern which uses funsor directly - blocked by casting issues # msg["funsor"]["log_prob"] = msg["funsor"]["fn"](**{msg["name"]: msg["funsor"]["value"]}) if msg["scale"] is not None and "scale" not in msg["funsor"]: @@ -57,9 +66,16 @@ def _pyro_post_sample(self, msg): else: # this logic has the same side effect on the _DIM_STACK as the above, # but does not perform any tensor or funsor operations. 
- msg["funsor"]["dim_to_name"] = _DIM_STACK.names_from_batch_shape(msg["fn"].batch_shape) - msg["funsor"]["dim_to_name"].update(_DIM_STACK.names_from_batch_shape( - msg["value"].shape[:len(msg["value"]).shape - len(msg["fn"].event_shape)])) + msg["funsor"]["dim_to_name"] = _DIM_STACK.names_from_batch_shape( + msg["fn"].batch_shape + ) + msg["funsor"]["dim_to_name"].update( + _DIM_STACK.names_from_batch_shape( + msg["value"].shape[ + : len(msg["value"]).shape - len(msg["fn"].event_shape) + ] + ) + ) return super()._pyro_post_sample(msg) def _pyro_post_markov_chain(self, msg): diff --git a/pyro/contrib/funsor/infer/discrete.py b/pyro/contrib/funsor/infer/discrete.py index 3a5ff15d84..3518882c16 100644 --- a/pyro/contrib/funsor/infer/discrete.py +++ b/pyro/contrib/funsor/infer/discrete.py @@ -34,10 +34,11 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs): with funsor.interpretations.lazy: log_prob = funsor.sum_product.sum_product( - sum_op, prod_op, + sum_op, + prod_op, terms["log_factors"] + terms["log_measures"], eliminate=terms["measure_vars"] | terms["plate_vars"], - plates=terms["plate_vars"] + plates=terms["plate_vars"], ) log_prob = funsor.optimizer.apply_optimizer(log_prob) @@ -56,8 +57,12 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs): # TODO this should really be handled entirely under the hood by adjoint node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)} else: - node["funsor"]["log_measure"] = approx_factors[node["funsor"]["log_measure"]] - node["funsor"]["value"] = _get_support_value(node["funsor"]["log_measure"], name) + node["funsor"]["log_measure"] = approx_factors[ + node["funsor"]["log_measure"] + ] + node["funsor"]["value"] = _get_support_value( + node["funsor"]["log_measure"], name + ) sample_subs[name] = node["funsor"]["value"] with replay(trace=sample_tr): diff --git a/pyro/contrib/funsor/infer/elbo.py b/pyro/contrib/funsor/infer/elbo.py index 6d7dbd8763..80e9b7402d 100644 --- a/pyro/contrib/funsor/infer/elbo.py +++ b/pyro/contrib/funsor/infer/elbo.py @@ -7,7 +7,6 @@ class ELBO(_OrigELBO): - def _get_trace(self, *args, **kwargs): raise ValueError("shouldn't be here!") @@ -24,19 +23,19 @@ def loss_and_grads(self, model, guide, *args, **kwargs): class Jit_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): - kwargs['_model_id'] = id(model) - kwargs['_guide_id'] = id(guide) - if getattr(self, '_differentiable_loss', None) is None: + kwargs["_model_id"] = id(model) + kwargs["_guide_id"] = id(guide) + if getattr(self, "_differentiable_loss", None) is None: # build a closure for differentiable_loss superself = super() - @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings, - jit_options=self.jit_options) + @pyro.ops.jit.trace( + ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options + ) def differentiable_loss(*args, **kwargs): - kwargs.pop('_model_id') - kwargs.pop('_guide_id') + kwargs.pop("_model_id") + kwargs.pop("_guide_id") with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): return superself.differentiable_loss(model, guide, *args, **kwargs) diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index 686926b772..e91787732f 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -17,24 +17,30 @@ @copy_docs_from(_OrigTrace_ELBO) class Trace_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): - with enum(), \ - 
plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(): - guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(*args, **kwargs) + with enum(), plate( + size=self.num_particles + ) if self.num_particles > 1 else contextlib.ExitStack(): + guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace( + *args, **kwargs + ) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) log_measures = guide_terms["log_measures"] + model_terms["log_measures"] - log_factors = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]] + log_factors = model_terms["log_factors"] + [ + -f for f in guide_terms["log_factors"] + ] plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"] measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"] - elbo = funsor.Integrate(sum(log_measures, to_funsor(0.)), - sum(log_factors, to_funsor(0.)), - measure_vars) + elbo = funsor.Integrate( + sum(log_measures, to_funsor(0.0)), + sum(log_factors, to_funsor(0.0)), + measure_vars, + ) elbo = elbo.reduce(funsor.ops.add, plate_vars) return -to_data(elbo) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 0529bfd746..7915f7e101 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -28,50 +28,78 @@ def terms_from_trace(tr): """Helper function to extract elbo components from execution traces.""" # data structure containing densities, measures, scales, and identification # of free variables as either product (plate) variables or sum (measure) variables - terms = {"log_factors": [], "log_measures": [], "scale": to_funsor(1.), - "plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()} + terms = { + "log_factors": [], + "log_measures": [], + "scale": to_funsor(1.0), + "plate_vars": frozenset(), + "measure_vars": frozenset(), + "plate_to_step": dict(), + } for name, node in tr.nodes.items(): # add markov dimensions to the plate_to_step dictionary if node["type"] == "markov_chain": terms["plate_to_step"][node["name"]] = node["value"] # ensure previous step variables are added to measure_vars for step in node["value"]: - terms["measure_vars"] |= frozenset({ - var for var in step[1:-1] - if tr.nodes[var]["funsor"].get("log_measure", None) is not None}) - if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample" or \ - node["infer"].get("_do_not_score", False): + terms["measure_vars"] |= frozenset( + { + var + for var in step[1:-1] + if tr.nodes[var]["funsor"].get("log_measure", None) is not None + } + ) + if ( + node["type"] != "sample" + or type(node["fn"]).__name__ == "_Subsample" + or node["infer"].get("_do_not_score", False) + ): continue # grab plate dimensions from the cond_indep_stack - terms["plate_vars"] |= frozenset(f.name for f in node["cond_indep_stack"] if f.vectorized) + terms["plate_vars"] |= frozenset( + f.name for f in node["cond_indep_stack"] if f.vectorized + ) # grab the log-measure, found only at sites that are not replayed or observed if node["funsor"].get("log_measure", None) is not None: terms["log_measures"].append(node["funsor"]["log_measure"]) # sum (measure) variables: the fresh non-plate variables at a site - terms["measure_vars"] |= (frozenset(node["funsor"]["value"].inputs) | {name}) - terms["plate_vars"] + terms["measure_vars"] |= ( + 
frozenset(node["funsor"]["value"].inputs) | {name} + ) - terms["plate_vars"] # grab the scale, assuming a common subsampling scale - if node.get("replay_active", False) and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] and \ - float(to_data(node["funsor"]["scale"])) != 1.: + if ( + node.get("replay_active", False) + and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] + and float(to_data(node["funsor"]["scale"])) != 1.0 + ): # model site that depends on enumerated variable: common scale terms["scale"] = node["funsor"]["scale"] else: # otherwise: default scale behavior - node["funsor"]["log_prob"] = node["funsor"]["log_prob"] * node["funsor"]["scale"] + node["funsor"]["log_prob"] = ( + node["funsor"]["log_prob"] * node["funsor"]["scale"] + ) # grab the log-density, found at all sites except those that are not replayed if node["is_observed"] or not node.get("replay_skipped", False): terms["log_factors"].append(node["funsor"]["log_prob"]) # add plate dimensions to the plate_to_step dictionary - terms["plate_to_step"].update({plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]}) + terms["plate_to_step"].update( + {plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]} + ) return terms @copy_docs_from(_OrigTraceEnum_ELBO) class TraceMarkovEnum_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): # get batched, enumerated, to_funsor-ed traces from the guide and model - with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \ - enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None): + with plate( + size=self.num_particles + ) if self.num_particles > 1 else contextlib.ExitStack(), enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting + else None + ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) @@ -81,7 +109,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): # guide side enumeration is not supported if any(guide_terms["plate_to_step"].values()): - raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration") + raise NotImplementedError( + "TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration" + ) # build up a lazy expression for the elbo with funsor.terms.lazy: @@ -93,14 +123,19 @@ def differentiable_loss(self, model, guide, *args, **kwargs): else: uncontracted_factors.append(f) # incorporate the effects of subsampling and handlers.scale through a common scale factor - markov_dims = frozenset({ - plate for plate, step in model_terms["plate_to_step"].items() if step}) - contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.dynamic_partial_sum_product( - funsor.ops.logaddexp, funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plate_to_step=model_terms["plate_to_step"], - eliminate=model_terms["measure_vars"] | markov_dims - )] + markov_dims = frozenset( + {plate for plate, step in model_terms["plate_to_step"].items() if step} + ) + contracted_costs = [ + model_terms["scale"] * f + for f in funsor.sum_product.dynamic_partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plate_to_step=model_terms["plate_to_step"], + eliminate=model_terms["measure_vars"] | markov_dims, + ) + ] costs = contracted_costs + uncontracted_factors # model costs: logp 
costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq @@ -111,14 +146,20 @@ def differentiable_loss(self, model, guide, *args, **kwargs): for cost in costs: # compute the marginal logq in the guide corresponding to this cost term log_prob = funsor.sum_product.sum_product( - funsor.ops.logaddexp, funsor.ops.add, + funsor.ops.logaddexp, + funsor.ops.add, guide_terms["log_measures"], plates=plate_vars, - eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs) + eliminate=(plate_vars | guide_terms["measure_vars"]) + - frozenset(cost.inputs), ) # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q - elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs)) - elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) + elbo_term = funsor.Integrate( + log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs) + ) + elbo += elbo_term.reduce( + funsor.ops.add, plate_vars & frozenset(cost.inputs) + ) # evaluate the elbo, using memoize to share tensor computation where possible with funsor.interpretations.memoize(): @@ -127,12 +168,16 @@ def differentiable_loss(self, model, guide, *args, **kwargs): @copy_docs_from(_OrigTraceEnum_ELBO) class TraceEnum_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): # get batched, enumerated, to_funsor-ed traces from the guide and model - with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \ - enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None): + with plate( + size=self.num_particles + ) if self.num_particles > 1 else contextlib.ExitStack(), enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting + else None + ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) @@ -150,11 +195,16 @@ def differentiable_loss(self, model, guide, *args, **kwargs): else: uncontracted_factors.append(f) # incorporate the effects of subsampling and handlers.scale through a common scale factor - contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.partial_sum_product( - funsor.ops.logaddexp, funsor.ops.add, - model_terms["log_measures"] + contracted_factors, - plates=model_terms["plate_vars"], eliminate=model_terms["measure_vars"] - )] + contracted_costs = [ + model_terms["scale"] * f + for f in funsor.sum_product.partial_sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plates=model_terms["plate_vars"], + eliminate=model_terms["measure_vars"], + ) + ] # accumulate costs from model (logp) and guide (-logq) costs = contracted_costs + uncontracted_factors # model costs: logp @@ -180,20 +230,31 @@ def differentiable_loss(self, model, guide, *args, **kwargs): ) with AdjointTape() as tape: logzq = funsor.sum_product.sum_product( - funsor.ops.logaddexp, funsor.ops.add, + funsor.ops.logaddexp, + funsor.ops.add, guide_terms["log_measures"] + list(targets.values()), plates=plate_vars, - eliminate=(plate_vars | guide_terms["measure_vars"]) + eliminate=(plate_vars | guide_terms["measure_vars"]), ) - marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())) + marginals = tape.adjoint( + funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values()) + ) # finally, integrate out guide variables in the elbo and all plates elbo = to_funsor(0, 
output=funsor.Real) for cost in costs: target = targets[frozenset(cost.inputs)] - logzq_local = marginals[target].reduce(funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars) + logzq_local = marginals[target].reduce( + funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars + ) log_prob = marginals[target] - logzq_local - elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(log_prob.inputs)) - elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) + elbo_term = funsor.Integrate( + log_prob, + cost, + guide_terms["measure_vars"] & frozenset(log_prob.inputs), + ) + elbo += elbo_term.reduce( + funsor.ops.add, plate_vars & frozenset(cost.inputs) + ) # evaluate the elbo, using memoize to share tensor computation where possible with funsor.interpretations.memoize(): diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py index c9c44f4a25..7cf4ba805e 100644 --- a/pyro/contrib/funsor/infer/tracetmc_elbo.py +++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py @@ -15,10 +15,14 @@ @copy_docs_from(_OrigTraceTMC_ELBO) class TraceTMC_ELBO(ELBO): - def differentiable_loss(self, model, guide, *args, **kwargs): - with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \ - enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None): + with plate( + size=self.num_particles + ) if self.num_particles > 1 else contextlib.ExitStack(), enum( + first_available_dim=(-self.max_plate_nesting - 1) + if self.max_plate_nesting + else None + ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) @@ -26,16 +30,19 @@ def differentiable_loss(self, model, guide, *args, **kwargs): guide_terms = terms_from_trace(guide_tr) log_measures = guide_terms["log_measures"] + model_terms["log_measures"] - log_factors = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]] + log_factors = model_terms["log_factors"] + [ + -f for f in guide_terms["log_factors"] + ] plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"] measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"] with funsor.terms.lazy: elbo = funsor.sum_product.sum_product( - funsor.ops.logaddexp, funsor.ops.add, + funsor.ops.logaddexp, + funsor.ops.add, log_measures + log_factors, eliminate=measure_vars | plate_vars, - plates=plate_vars + plates=plate_vars, ) return -to_data(apply_optimizer(elbo)) diff --git a/pyro/contrib/gp/kernels/__init__.py b/pyro/contrib/gp/kernels/__init__.py index 3c2889cf31..8c5985ac10 100644 --- a/pyro/contrib/gp/kernels/__init__.py +++ b/pyro/contrib/gp/kernels/__init__.py @@ -52,9 +52,9 @@ ] # Create sphinx documentation. -__doc__ = '\n\n'.join([ - - ''' +__doc__ = "\n\n".join( + [ + """ {0} ---------------------------------------------------------------- .. 
autoclass:: pyro.contrib.gp.kernels.{0} @@ -63,6 +63,9 @@ :special-members: __call__ :show-inheritance: :member-order: bysource - '''.format(_name) - for _name in __all__ -]) + """.format( + _name + ) + for _name in __all__ + ] +) diff --git a/pyro/contrib/gp/kernels/brownian.py b/pyro/contrib/gp/kernels/brownian.py index 0c6c8626d3..78e90601c1 100644 --- a/pyro/contrib/gp/kernels/brownian.py +++ b/pyro/contrib/gp/kernels/brownian.py @@ -28,7 +28,7 @@ def __init__(self, input_dim, variance=None, active_dims=None): raise ValueError("Input dimensional for Brownian kernel must be 1.") super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) def forward(self, X, Z=None, diag=False): @@ -43,6 +43,8 @@ def forward(self, X, Z=None, diag=False): raise ValueError("Inputs must have the same number of features.") Zt = Z.t() - return torch.where(X.sign() == Zt.sign(), - self.variance * torch.min(X.abs(), Zt.abs()), - X.data.new_zeros(X.size(0), Z.size(0))) + return torch.where( + X.sign() == Zt.sign(), + self.variance * torch.min(X.abs(), Zt.abs()), + X.data.new_zeros(X.size(0), Z.size(0)), + ) diff --git a/pyro/contrib/gp/kernels/coregionalize.py b/pyro/contrib/gp/kernels/coregionalize.py index bd583baf00..41868433fa 100644 --- a/pyro/contrib/gp/kernels/coregionalize.py +++ b/pyro/contrib/gp/kernels/coregionalize.py @@ -45,7 +45,9 @@ class Coregionalize(Kernel): :param str name: Name of the kernel. """ - def __init__(self, input_dim, rank=None, components=None, diagonal=None, active_dims=None): + def __init__( + self, input_dim, rank=None, components=None, diagonal=None, active_dims=None + ): super().__init__(input_dim, active_dims) # Add a low-rank kernel with expected value torch.eye(input_dim, input_dim) / 2. @@ -55,16 +57,24 @@ def __init__(self, input_dim, rank=None, components=None, diagonal=None, active_ else: rank = components.size(-1) if components.shape != (input_dim, rank): - raise ValueError("Expected components.shape == ({},rank), actual {}" - .format(input_dim, components.shape)) + raise ValueError( + "Expected components.shape == ({},rank), actual {}".format( + input_dim, components.shape + ) + ) self.components = Parameter(components) # Add a diagonal component initialized to torch.eye(input_dim, input_dim) / 2, # such that the total kernel has expected value the identity matrix. - diagonal = components.new_ones(input_dim) * 0.5 if diagonal is None else diagonal + diagonal = ( + components.new_ones(input_dim) * 0.5 if diagonal is None else diagonal + ) if diagonal.shape != (input_dim,): - raise ValueError("Expected diagonal.shape == ({},), actual {}" - .format(input_dim, diagonal.shape)) + raise ValueError( + "Expected diagonal.shape == ({},), actual {}".format( + input_dim, diagonal.shape + ) + ) self.diagonal = PyroParam(diagonal, constraints.positive) def forward(self, X, Z=None, diag=False): diff --git a/pyro/contrib/gp/kernels/dot_product.py b/pyro/contrib/gp/kernels/dot_product.py index 5f4fdaf001..072e2199c4 100644 --- a/pyro/contrib/gp/kernels/dot_product.py +++ b/pyro/contrib/gp/kernels/dot_product.py @@ -16,7 +16,7 @@ class DotProduct(Kernel): def __init__(self, input_dim, variance=None, active_dims=None): super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) 
if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) def _dot_product(self, X, Z=None, diag=False): @@ -70,12 +70,16 @@ class Polynomial(DotProduct): def __init__(self, input_dim, variance=None, bias=None, degree=1, active_dims=None): super().__init__(input_dim, variance, active_dims) - bias = torch.tensor(1.) if bias is None else bias + bias = torch.tensor(1.0) if bias is None else bias self.bias = PyroParam(bias, constraints.positive) if not isinstance(degree, int) or degree < 1: - raise ValueError("Degree for Polynomial kernel should be a positive integer.") + raise ValueError( + "Degree for Polynomial kernel should be a positive integer." + ) self.degree = degree def forward(self, X, Z=None, diag=False): - return self.variance * ((self.bias + self._dot_product(X, Z, diag)) ** self.degree) + return self.variance * ( + (self.bias + self._dot_product(X, Z, diag)) ** self.degree + ) diff --git a/pyro/contrib/gp/kernels/isotropic.py b/pyro/contrib/gp/kernels/isotropic.py index 8ea21f4f5d..8e0bbe2bdc 100644 --- a/pyro/contrib/gp/kernels/isotropic.py +++ b/pyro/contrib/gp/kernels/isotropic.py @@ -28,13 +28,14 @@ class Isotropy(Kernel): :param torch.Tensor lengthscale: Length-scale parameter of this kernel. """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) - lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale + lengthscale = torch.tensor(1.0) if lengthscale is None else lengthscale self.lengthscale = PyroParam(lengthscale, constraints.positive) def _square_scaled_dist(self, X, Z=None): @@ -77,6 +78,7 @@ class RBF(Isotropy): .. note:: This kernel also has name `Squared Exponential` in literature. """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, variance, lengthscale, active_dims) @@ -98,12 +100,19 @@ class RationalQuadratic(Isotropy): :param torch.Tensor scale_mixture: Scale mixture (:math:`\alpha`) parameter of this kernel. Should have size 1. """ - def __init__(self, input_dim, variance=None, lengthscale=None, scale_mixture=None, - active_dims=None): + + def __init__( + self, + input_dim, + variance=None, + lengthscale=None, + scale_mixture=None, + active_dims=None, + ): super().__init__(input_dim, variance, lengthscale, active_dims) if scale_mixture is None: - scale_mixture = torch.tensor(1.) 
+ scale_mixture = torch.tensor(1.0) self.scale_mixture = PyroParam(scale_mixture, constraints.positive) def forward(self, X, Z=None, diag=False): @@ -111,7 +120,9 @@ def forward(self, X, Z=None, diag=False): return self._diag(X) r2 = self._square_scaled_dist(X, Z) - return self.variance * (1 + (0.5 / self.scale_mixture) * r2).pow(-self.scale_mixture) + return self.variance * (1 + (0.5 / self.scale_mixture) * r2).pow( + -self.scale_mixture + ) class Exponential(Isotropy): @@ -120,6 +131,7 @@ class Exponential(Isotropy): :math:`k(x, z) = \sigma^2\exp\left(-\frac{|x-z|}{l}\right).` """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, variance, lengthscale, active_dims) @@ -138,6 +150,7 @@ class Matern32(Isotropy): :math:`k(x, z) = \sigma^2\left(1 + \sqrt{3} \times \frac{|x-z|}{l}\right) \exp\left(-\sqrt{3} \times \frac{|x-z|}{l}\right).` """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, variance, lengthscale, active_dims) @@ -146,7 +159,7 @@ def forward(self, X, Z=None, diag=False): return self._diag(X) r = self._scaled_dist(X, Z) - sqrt3_r = 3**0.5 * r + sqrt3_r = 3 ** 0.5 * r return self.variance * (1 + sqrt3_r) * torch.exp(-sqrt3_r) @@ -157,6 +170,7 @@ class Matern52(Isotropy): :math:`k(x,z)=\sigma^2\left(1+\sqrt{5}\times\frac{|x-z|}{l}+\frac{5}{3}\times \frac{|x-z|^2}{l^2}\right)\exp\left(-\sqrt{5} \times \frac{|x-z|}{l}\right).` """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, variance, lengthscale, active_dims) @@ -166,5 +180,5 @@ def forward(self, X, Z=None, diag=False): r2 = self._square_scaled_dist(X, Z) r = _torch_sqrt(r2) - sqrt5_r = 5**0.5 * r - return self.variance * (1 + sqrt5_r + (5/3) * r2) * torch.exp(-sqrt5_r) + sqrt5_r = 5 ** 0.5 * r + return self.variance * (1 + sqrt5_r + (5 / 3) * r2) * torch.exp(-sqrt5_r) diff --git a/pyro/contrib/gp/kernels/kernel.py b/pyro/contrib/gp/kernels/kernel.py index bbb29321a8..0271f36e11 100644 --- a/pyro/contrib/gp/kernels/kernel.py +++ b/pyro/contrib/gp/kernels/kernel.py @@ -33,7 +33,9 @@ def __init__(self, input_dim, active_dims=None): if active_dims is None: active_dims = list(range(input_dim)) elif input_dim != len(active_dims): - raise ValueError("Input size and the length of active dimensionals should be equal.") + raise ValueError( + "Input size and the length of active dimensionals should be equal." + ) self.input_dim = input_dim self.active_dims = active_dims @@ -77,13 +79,17 @@ class Combination(Kernel): :param kern1: Second kernel to combine. :type kern1: Kernel or numbers.Number """ + def __init__(self, kern0, kern1): if not isinstance(kern0, Kernel): - raise TypeError("The first component of a combined kernel must be a " - "Kernel instance.") + raise TypeError( + "The first component of a combined kernel must be a " "Kernel instance." + ) if not (isinstance(kern1, Kernel) or isinstance(kern1, numbers.Number)): - raise TypeError("The second component of a combined kernel must be a " - "Kernel instance or a number.") + raise TypeError( + "The second component of a combined kernel must be a " + "Kernel instance or a number." + ) active_dims = set(kern0.active_dims) if isinstance(kern1, Kernel): @@ -101,6 +107,7 @@ class Sum(Combination): Returns a new kernel which acts like a sum/direct sum of two kernels. The second kernel can be a constant. 
""" + def forward(self, X, Z=None, diag=False): if isinstance(self.kern1, Kernel): return self.kern0(X, Z, diag=diag) + self.kern1(X, Z, diag=diag) @@ -113,6 +120,7 @@ class Product(Combination): Returns a new kernel which acts like a product/tensor product of two kernels. The second kernel can be a constant. """ + def forward(self, X, Z=None, diag=False): if isinstance(self.kern1, Kernel): return self.kern0(X, Z, diag=diag) * self.kern1(X, Z, diag=diag) @@ -127,6 +135,7 @@ class Transforming(Kernel): :param Kernel kern: The original kernel. """ + def __init__(self, kern): super().__init__(kern.input_dim, kern.active_dims) @@ -139,6 +148,7 @@ class Exponent(Transforming): :math:`k_{new}(x, z) = \exp(k(x, z)).` """ + def forward(self, X, Z=None, diag=False): return self.kern(X, Z, diag=diag).exp() @@ -153,6 +163,7 @@ class VerticalScaling(Transforming): :param callable vscaling_fn: A vertical scaling function :math:`f`. """ + def __init__(self, kern, vscaling_fn): super().__init__(kern) @@ -160,13 +171,18 @@ def __init__(self, kern, vscaling_fn): def forward(self, X, Z=None, diag=False): if diag: - return self.vscaling_fn(X) * self.kern(X, Z, diag=diag) * self.vscaling_fn(X) + return ( + self.vscaling_fn(X) * self.kern(X, Z, diag=diag) * self.vscaling_fn(X) + ) elif Z is None: vscaled_X = self.vscaling_fn(X).unsqueeze(1) return vscaled_X * self.kern(X, Z, diag=diag) * vscaled_X.t() else: - return (self.vscaling_fn(X).unsqueeze(1) * self.kern(X, Z, diag=diag) * - self.vscaling_fn(Z).unsqueeze(0)) + return ( + self.vscaling_fn(X).unsqueeze(1) + * self.kern(X, Z, diag=diag) + * self.vscaling_fn(Z).unsqueeze(0) + ) def _Horner_evaluate(x, coef): @@ -176,7 +192,7 @@ def _Horner_evaluate(x, coef): # https://en.wikipedia.org/wiki/Horner%27s_method n = len(coef) - 1 b = coef[n] - for i in range(n-1, -1, -1): + for i in range(n - 1, -1, -1): b = coef[i] + b * x return b @@ -209,6 +225,7 @@ class Warping(Transforming): :param list owarping_coef: A list of coefficients of the output warping polynomial. These coefficients must be non-negative. """ + def __init__(self, kern, iwarping_fn=None, owarping_coef=None): super().__init__(kern) @@ -217,11 +234,15 @@ def __init__(self, kern, iwarping_fn=None, owarping_coef=None): if owarping_coef is not None: for coef in owarping_coef: if not isinstance(coef, int) and coef < 0: - raise ValueError("Coefficients of the polynomial must be a " - "non-negative integer.") + raise ValueError( + "Coefficients of the polynomial must be a " + "non-negative integer." + ) if len(owarping_coef) < 2 and sum(owarping_coef) == 0: - raise ValueError("The ouput warping polynomial should have a degree " - "of at least 1.") + raise ValueError( + "The ouput warping polynomial should have a degree " + "of at least 1." + ) self.owarping_coef = owarping_coef def forward(self, X, Z=None, diag=False): diff --git a/pyro/contrib/gp/kernels/periodic.py b/pyro/contrib/gp/kernels/periodic.py index f6519112d9..afad7ccad1 100644 --- a/pyro/contrib/gp/kernels/periodic.py +++ b/pyro/contrib/gp/kernels/periodic.py @@ -19,6 +19,7 @@ class Cosine(Isotropy): :param torch.Tensor lengthscale: Length-scale parameter of this kernel. """ + def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None): super().__init__(input_dim, variance, lengthscale, active_dims) @@ -46,16 +47,19 @@ class Periodic(Kernel): :param torch.Tensor lengthscale: Length scale parameter of this kernel. :param torch.Tensor period: Period parameter of this kernel. 
""" - def __init__(self, input_dim, variance=None, lengthscale=None, period=None, active_dims=None): + + def __init__( + self, input_dim, variance=None, lengthscale=None, period=None, active_dims=None + ): super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) - lengthscale = torch.tensor(1.) if lengthscale is None else lengthscale + lengthscale = torch.tensor(1.0) if lengthscale is None else lengthscale self.lengthscale = PyroParam(lengthscale, constraints.positive) - period = torch.tensor(1.) if period is None else period + period = torch.tensor(1.0) if period is None else period self.period = PyroParam(period, constraints.positive) def forward(self, X, Z=None, diag=False): diff --git a/pyro/contrib/gp/kernels/static.py b/pyro/contrib/gp/kernels/static.py index 11ebc03ff2..081ea80e94 100644 --- a/pyro/contrib/gp/kernels/static.py +++ b/pyro/contrib/gp/kernels/static.py @@ -14,10 +14,11 @@ class Constant(Kernel): :math:`k(x, z) = \sigma^2.` """ + def __init__(self, input_dim, variance=None, active_dims=None): super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) def forward(self, X, Z=None, diag=False): @@ -37,10 +38,11 @@ class WhiteNoise(Kernel): where :math:`\delta` is a Dirac delta function. """ + def __init__(self, input_dim, variance=None, active_dims=None): super().__init__(input_dim, active_dims) - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) def forward(self, X, Z=None, diag=False): diff --git a/pyro/contrib/gp/likelihoods/__init__.py b/pyro/contrib/gp/likelihoods/__init__.py index 993d876dbc..30749be3fb 100644 --- a/pyro/contrib/gp/likelihoods/__init__.py +++ b/pyro/contrib/gp/likelihoods/__init__.py @@ -17,9 +17,9 @@ # Create sphinx documentation. -__doc__ = '\n\n'.join([ - - ''' +__doc__ = "\n\n".join( + [ + """ {0} ---------------------------------------------------------------- .. autoclass:: pyro.contrib.gp.likelihoods.{0} @@ -28,6 +28,9 @@ :special-members: __call__ :show-inheritance: :member-order: bysource - '''.format(_name) - for _name in __all__ -]) + """.format( + _name + ) + for _name in __all__ + ] +) diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py index ef417f9e22..9597c3dc31 100644 --- a/pyro/contrib/gp/likelihoods/binary.py +++ b/pyro/contrib/gp/likelihoods/binary.py @@ -20,9 +20,12 @@ class Binary(Likelihood): :param callable response_function: A mapping to correct domain for Binary likelihood. 
""" + def __init__(self, response_function=None): super().__init__() - self.response_function = torch.sigmoid if response_function is None else response_function + self.response_function = ( + torch.sigmoid if response_function is None else response_function + ) def forward(self, f_loc, f_var, y=None): r""" @@ -48,5 +51,5 @@ def forward(self, f_loc, f_var, y=None): f_res = self.response_function(f) y_dist = dist.Bernoulli(f_res) if y is not None: - y_dist = y_dist.expand_by(y.shape[:-f.dim()]).to_event(y.dim()) + y_dist = y_dist.expand_by(y.shape[: -f.dim()]).to_event(y.dim()) return pyro.sample(self._pyro_get_fullname("y"), y_dist, obs=y) diff --git a/pyro/contrib/gp/likelihoods/gaussian.py b/pyro/contrib/gp/likelihoods/gaussian.py index b1b65ff95c..314a375a45 100644 --- a/pyro/contrib/gp/likelihoods/gaussian.py +++ b/pyro/contrib/gp/likelihoods/gaussian.py @@ -19,10 +19,11 @@ class Gaussian(Likelihood): :param torch.Tensor variance: A variance parameter, which plays the role of ``noise`` in regression problems. """ + def __init__(self, variance=None): super().__init__() - variance = torch.tensor(1.) if variance is None else variance + variance = torch.tensor(1.0) if variance is None else variance self.variance = PyroParam(variance, constraints.positive) def forward(self, f_loc, f_var, y=None): @@ -43,5 +44,5 @@ def forward(self, f_loc, f_var, y=None): y_dist = dist.Normal(f_loc, y_var.sqrt()) if y is not None: - y_dist = y_dist.expand_by(y.shape[:-f_loc.dim()]).to_event(y.dim()) + y_dist = y_dist.expand_by(y.shape[: -f_loc.dim()]).to_event(y.dim()) return pyro.sample(self._pyro_get_fullname("y"), y_dist, obs=y) diff --git a/pyro/contrib/gp/likelihoods/likelihood.py b/pyro/contrib/gp/likelihoods/likelihood.py index 2e22cdde9e..48146c9e9a 100644 --- a/pyro/contrib/gp/likelihoods/likelihood.py +++ b/pyro/contrib/gp/likelihoods/likelihood.py @@ -11,6 +11,7 @@ class Likelihood(Parameterized): Every inherited class should implement a forward pass which takes an input :math:`f` and returns a sample :math:`y`. """ + def __init__(self): super().__init__() diff --git a/pyro/contrib/gp/likelihoods/multi_class.py b/pyro/contrib/gp/likelihoods/multi_class.py index ed8463f8bd..25253ddcb0 100644 --- a/pyro/contrib/gp/likelihoods/multi_class.py +++ b/pyro/contrib/gp/likelihoods/multi_class.py @@ -25,10 +25,13 @@ class MultiClass(Likelihood): :param callable response_function: A mapping to correct domain for MultiClass likelihood. """ + def __init__(self, num_classes, response_function=None): super().__init__() self.num_classes = num_classes - self.response_function = _softmax if response_function is None else response_function + self.response_function = ( + _softmax if response_function is None else response_function + ) def forward(self, f_loc, f_var, y=None): r""" @@ -49,21 +52,26 @@ def forward(self, f_loc, f_var, y=None): # calculates Monte Carlo estimate for E_q(f) [logp(y | f)] f = dist.Normal(f_loc, f_var.sqrt())() if f.dim() < 2: - raise ValueError("Latent function output should have at least 2 " - "dimensions: one for number of classes and one for " - "number of data.") + raise ValueError( + "Latent function output should have at least 2 " + "dimensions: one for number of classes and one for " + "number of data." + ) # swap class dimension and data dimension f_swap = f.transpose(-2, -1) # -> num_data x num_classes if f_swap.size(-1) != self.num_classes: - raise ValueError("Number of Gaussian processes should be equal to the " - "number of classes. Expected {} but got {}." 
- .format(self.num_classes, f_swap.size(-1))) + raise ValueError( + "Number of Gaussian processes should be equal to the " + "number of classes. Expected {} but got {}.".format( + self.num_classes, f_swap.size(-1) + ) + ) if self.response_function is _softmax: y_dist = dist.Categorical(logits=f_swap) else: f_res = self.response_function(f_swap) y_dist = dist.Categorical(f_res) if y is not None: - y_dist = y_dist.expand_by(y.shape[:-f.dim() + 1]).to_event(y.dim()) + y_dist = y_dist.expand_by(y.shape[: -f.dim() + 1]).to_event(y.dim()) return pyro.sample(self._pyro_get_fullname("y"), y_dist, obs=y) diff --git a/pyro/contrib/gp/likelihoods/poisson.py b/pyro/contrib/gp/likelihoods/poisson.py index 8abed6fd2a..fcf23ca37b 100644 --- a/pyro/contrib/gp/likelihoods/poisson.py +++ b/pyro/contrib/gp/likelihoods/poisson.py @@ -19,9 +19,12 @@ class Poisson(Likelihood): :param callable response_function: A mapping to positive real numbers. """ + def __init__(self, response_function=None): super().__init__() - self.response_function = torch.exp if response_function is None else response_function + self.response_function = ( + torch.exp if response_function is None else response_function + ) def forward(self, f_loc, f_var, y=None): r""" @@ -45,5 +48,5 @@ def forward(self, f_loc, f_var, y=None): y_dist = dist.Poisson(f_res) if y is not None: - y_dist = y_dist.expand_by(y.shape[:-f_res.dim()]).to_event(y.dim()) + y_dist = y_dist.expand_by(y.shape[: -f_res.dim()]).to_event(y.dim()) return pyro.sample(self._pyro_get_fullname("y"), y_dist, obs=y) diff --git a/pyro/contrib/gp/models/gplvm.py b/pyro/contrib/gp/models/gplvm.py index fe29e91365..f0d988325c 100644 --- a/pyro/contrib/gp/models/gplvm.py +++ b/pyro/contrib/gp/models/gplvm.py @@ -55,14 +55,19 @@ class GPLVM(Parameterized): model object. Note that ``base_model.X`` will be the initial value for the variational parameter ``X_loc``. """ + def __init__(self, base_model): super().__init__() if base_model.X.dim() != 2: - raise ValueError("GPLVM model only works with 2D latent X, but got " - "X.dim() = {}.".format(base_model.X.dim())) + raise ValueError( + "GPLVM model only works with 2D latent X, but got " + "X.dim() = {}.".format(base_model.X.dim()) + ) self.base_model = base_model - self.X = PyroSample(dist.Normal(base_model.X.new_zeros(base_model.X.shape), 1.).to_event()) + self.X = PyroSample( + dist.Normal(base_model.X.new_zeros(base_model.X.shape), 1.0).to_event() + ) self.autoguide("X", dist.Normal) self.X_loc.data = base_model.X diff --git a/pyro/contrib/gp/models/gpr.py b/pyro/contrib/gp/models/gpr.py index 530d39f64e..a0b7967d9f 100644 --- a/pyro/contrib/gp/models/gpr.py +++ b/pyro/contrib/gp/models/gpr.py @@ -65,10 +65,11 @@ class GPRegression(GPModel): :param float jitter: A small positive term which is added into the diagonal part of a covariance matrix to help stablize its Cholesky decomposition. """ + def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6): super().__init__(X, y, kernel, mean_function, jitter) - noise = self.X.new_tensor(1.) 
if noise is None else noise + noise = self.X.new_tensor(1.0) if noise is None else noise self.noise = PyroParam(noise, constraints.positive) @pyro_method @@ -77,7 +78,7 @@ def model(self): N = self.X.size(0) Kff = self.kernel(self.X) - Kff.view(-1)[::N + 1] += self.jitter + self.noise # add noise to diagonal + Kff.view(-1)[:: N + 1] += self.jitter + self.noise # add noise to diagonal Lff = torch.linalg.cholesky(Kff) zero_loc = self.X.new_zeros(self.X.size(0)) @@ -86,11 +87,13 @@ def model(self): f_var = Lff.pow(2).sum(dim=-1) return f_loc, f_var else: - return pyro.sample(self._pyro_get_fullname("y"), - dist.MultivariateNormal(f_loc, scale_tril=Lff) - .expand_by(self.y.shape[:-1]) - .to_event(self.y.dim() - 1), - obs=self.y) + return pyro.sample( + self._pyro_get_fullname("y"), + dist.MultivariateNormal(f_loc, scale_tril=Lff) + .expand_by(self.y.shape[:-1]) + .to_event(self.y.dim() - 1), + obs=self.y, + ) @pyro_method def guide(self): @@ -122,17 +125,25 @@ def forward(self, Xnew, full_cov=False, noiseless=True): N = self.X.size(0) Kff = self.kernel(self.X).contiguous() - Kff.view(-1)[::N + 1] += self.jitter + self.noise # add noise to the diagonal + Kff.view(-1)[:: N + 1] += self.jitter + self.noise # add noise to the diagonal Lff = torch.linalg.cholesky(Kff) y_residual = self.y - self.mean_function(self.X) - loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff, - full_cov, jitter=self.jitter) + loc, cov = conditional( + Xnew, + self.X, + self.kernel, + y_residual, + None, + Lff, + full_cov, + jitter=self.jitter, + ) if full_cov and not noiseless: M = Xnew.size(0) cov = cov.contiguous() - cov.view(-1, M * M)[:, ::M + 1] += self.noise # add noise to the diagonal + cov.view(-1, M * M)[:, :: M + 1] += self.noise # add noise to the diagonal if not full_cov and not noiseless: cov = cov + self.noise @@ -165,7 +176,7 @@ def iter_sample(self, noiseless=True): y = self.y.clone().detach() N = X.size(0) Kff = self.kernel(X).contiguous() - Kff.view(-1)[::N + 1] += noise # add noise to the diagonal + Kff.view(-1)[:: N + 1] += noise # add noise to the diagonal outside_vars = {"X": X, "y": y, "N": N, "Kff": Kff} @@ -183,16 +194,19 @@ def sample_next(xnew, outside_vars): y_residual = y - self.mean_function(X) # Compute conditional mean and variance - loc, cov = conditional(xnew, X, self.kernel, y_residual, None, Lff, - False, jitter=self.jitter) + loc, cov = conditional( + xnew, X, self.kernel, y_residual, None, Lff, False, jitter=self.jitter + ) if not noiseless: cov = cov + noise - ynew = torchdist.Normal(loc + self.mean_function(xnew), cov.sqrt()).rsample() + ynew = torchdist.Normal( + loc + self.mean_function(xnew), cov.sqrt() + ).rsample() # Update kernel matrix N = outside_vars["N"] - Kffnew = Kff.new_empty(N+1, N+1) + Kffnew = Kff.new_empty(N + 1, N + 1) Kffnew[:N, :N] = Kff cross = self.kernel(X, xnew).squeeze() end = self.kernel(xnew, xnew).squeeze() @@ -201,7 +215,7 @@ def sample_next(xnew, outside_vars): # No noise, just jitter for numerical stability Kffnew[N, N] = end + self.jitter # Heuristic to avoid adding degenerate points - if Kffnew.logdet() > -15.: + if Kffnew.logdet() > -15.0: outside_vars["Kff"] = Kffnew outside_vars["N"] += 1 outside_vars["X"] = torch.cat((X, xnew)) diff --git a/pyro/contrib/gp/models/model.py b/pyro/contrib/gp/models/model.py index dc1c53db3e..51bb4c0877 100644 --- a/pyro/contrib/gp/models/model.py +++ b/pyro/contrib/gp/models/model.py @@ -87,12 +87,14 @@ class GPModel(Parameterized): :param float jitter: A small positive term which is added into 
the diagonal part of a covariance matrix to help stablize its Cholesky decomposition. """ + def __init__(self, X, y, kernel, mean_function=None, jitter=1e-6): super().__init__() self.set_data(X, y) self.kernel = kernel - self.mean_function = (mean_function if mean_function is not None else - _zero_mean_function) + self.mean_function = ( + mean_function if mean_function is not None else _zero_mean_function + ) self.jitter = jitter def model(self): @@ -185,9 +187,12 @@ def set_data(self, X, y=None): number of data points. """ if y is not None and X.size(0) != y.size(-1): - raise ValueError("Expected the number of input data points equal to the " - "number of output data points, but got {} and {}." - .format(X.size(0), y.size(-1))) + raise ValueError( + "Expected the number of input data points equal to the " + "number of output data points, but got {} and {}.".format( + X.size(0), y.size(-1) + ) + ) self.X = X self.y = y @@ -199,10 +204,16 @@ def _check_Xnew_shape(self, Xnew): ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``. """ if Xnew.dim() != self.X.dim(): - raise ValueError("Train data and test data should have the same " - "number of dimensions, but got {} and {}." - .format(self.X.dim(), Xnew.dim())) + raise ValueError( + "Train data and test data should have the same " + "number of dimensions, but got {} and {}.".format( + self.X.dim(), Xnew.dim() + ) + ) if self.X.shape[1:] != Xnew.shape[1:]: - raise ValueError("Train data and test data should have the same " - "shape of features, but got {} and {}." - .format(self.X.shape[1:], Xnew.shape[1:])) + raise ValueError( + "Train data and test data should have the same " + "shape of features, but got {} and {}.".format( + self.X.shape[1:], Xnew.shape[1:] + ) + ) diff --git a/pyro/contrib/gp/models/sgpr.py b/pyro/contrib/gp/models/sgpr.py index b2b0eedab8..4052138491 100644 --- a/pyro/contrib/gp/models/sgpr.py +++ b/pyro/contrib/gp/models/sgpr.py @@ -94,12 +94,15 @@ class SparseGPRegression(GPModel): a covariance matrix to help stablize its Cholesky decomposition. :param str name: Name of this model. """ - def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None, jitter=1e-6): + + def __init__( + self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None, jitter=1e-6 + ): super().__init__(X, y, kernel, mean_function, jitter) self.Xu = Parameter(Xu) - noise = self.X.new_tensor(1.) if noise is None else noise + noise = self.X.new_tensor(1.0) if noise is None else noise self.noise = PyroParam(noise, constraints.positive) if approx is None: @@ -107,8 +110,10 @@ def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None elif approx in ["DTC", "FITC", "VFE"]: self.approx = approx else: - raise ValueError("The sparse approximation method should be one of " - "'DTC', 'FITC', 'VFE'.") + raise ValueError( + "The sparse approximation method should be one of " + "'DTC', 'FITC', 'VFE'." + ) @pyro_method def model(self): @@ -126,7 +131,7 @@ def model(self): N = self.X.size(0) M = self.Xu.size(0) Kuu = self.kernel(self.Xu).contiguous() - Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal + Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal Luu = torch.linalg.cholesky(Kuu) Kuf = self.kernel(self.Xu, self.X) W = Kuf.triangular_solve(Luu, upper=False)[0].t() @@ -148,13 +153,15 @@ def model(self): return f_loc, f_var else: if self.approx == "VFE": - pyro.factor(self._pyro_get_fullname("trace_term"), -trace_term / 2.) 
+ pyro.factor(self._pyro_get_fullname("trace_term"), -trace_term / 2.0) - return pyro.sample(self._pyro_get_fullname("y"), - dist.LowRankMultivariateNormal(f_loc, W, D) - .expand_by(self.y.shape[:-1]) - .to_event(self.y.dim() - 1), - obs=self.y) + return pyro.sample( + self._pyro_get_fullname("y"), + dist.LowRankMultivariateNormal(f_loc, W, D) + .expand_by(self.y.shape[:-1]) + .to_event(self.y.dim() - 1), + obs=self.y, + ) @pyro_method def guide(self): @@ -203,7 +210,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): # TODO: cache these calculations to get faster inference Kuu = self.kernel(self.Xu).contiguous() - Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal + Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal Luu = torch.linalg.cholesky(Kuu) Kuf = self.kernel(self.Xu, self.X) @@ -217,7 +224,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): W_Dinv = W / D K = W_Dinv.matmul(W.t()).contiguous() - K.view(-1)[::M + 1] += 1 # add identity matrix to K + K.view(-1)[:: M + 1] += 1 # add identity matrix to K L = torch.linalg.cholesky(K) # get y_residual and convert it into 2D tensor for packing @@ -232,8 +239,8 @@ def forward(self, Xnew, full_cov=False, noiseless=True): pack = torch.cat((W_Dinv_y, Ws), dim=1) Linv_pack = pack.triangular_solve(L, upper=False)[0] # unpack - Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]] - Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:] + Linv_W_Dinv_y = Linv_pack[:, : W_Dinv_y.shape[1]] + Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1] :] C = Xnew.size(0) loc_shape = self.y.shape[:-1] + (C,) @@ -242,7 +249,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): if full_cov: Kss = self.kernel(Xnew).contiguous() if not noiseless: - Kss.view(-1)[::C + 1] += self.noise # add noise to the diagonal + Kss.view(-1)[:: C + 1] += self.noise # add noise to the diagonal Qss = Ws.t().matmul(Ws) cov = Kss - Qss + Linv_Ws.t().matmul(Linv_Ws) cov_shape = self.y.shape[:-1] + (C, C) diff --git a/pyro/contrib/gp/models/vgp.py b/pyro/contrib/gp/models/vgp.py index ec7d6c36ff..8e87082cfb 100644 --- a/pyro/contrib/gp/models/vgp.py +++ b/pyro/contrib/gp/models/vgp.py @@ -59,8 +59,18 @@ class VariationalGP(GPModel): :param float jitter: A small positive term which is added into the diagonal part of a covariance matrix to help stablize its Cholesky decomposition. 
""" - def __init__(self, X, y, kernel, likelihood, mean_function=None, - latent_shape=None, whiten=False, jitter=1e-6): + + def __init__( + self, + X, + y, + kernel, + likelihood, + mean_function=None, + latent_shape=None, + whiten=False, + jitter=1e-6, + ): super().__init__(X, y, kernel, mean_function, jitter) self.likelihood = likelihood @@ -85,21 +95,27 @@ def model(self): N = self.X.size(0) Kff = self.kernel(self.X).contiguous() - Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal + Kff.view(-1)[:: N + 1] += self.jitter # add jitter to the diagonal Lff = torch.linalg.cholesky(Kff) zero_loc = self.X.new_zeros(self.f_loc.shape) if self.whiten: identity = eye_like(self.X, N) - pyro.sample(self._pyro_get_fullname("f"), - dist.MultivariateNormal(zero_loc, scale_tril=identity) - .to_event(zero_loc.dim() - 1)) + pyro.sample( + self._pyro_get_fullname("f"), + dist.MultivariateNormal(zero_loc, scale_tril=identity).to_event( + zero_loc.dim() - 1 + ), + ) f_scale_tril = Lff.matmul(self.f_scale_tril) f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1) else: - pyro.sample(self._pyro_get_fullname("f"), - dist.MultivariateNormal(zero_loc, scale_tril=Lff) - .to_event(zero_loc.dim() - 1)) + pyro.sample( + self._pyro_get_fullname("f"), + dist.MultivariateNormal(zero_loc, scale_tril=Lff).to_event( + zero_loc.dim() - 1 + ), + ) f_scale_tril = self.f_scale_tril f_loc = self.f_loc @@ -115,9 +131,12 @@ def guide(self): self.set_mode("guide") self._load_pyro_samples() - pyro.sample(self._pyro_get_fullname("f"), - dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril) - .to_event(self.f_loc.dim()-1)) + pyro.sample( + self._pyro_get_fullname("f"), + dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril).to_event( + self.f_loc.dim() - 1 + ), + ) def forward(self, Xnew, full_cov=False): r""" @@ -141,6 +160,14 @@ def forward(self, Xnew, full_cov=False): self._check_Xnew_shape(Xnew) self.set_mode("guide") - loc, cov = conditional(Xnew, self.X, self.kernel, self.f_loc, self.f_scale_tril, - full_cov=full_cov, whiten=self.whiten, jitter=self.jitter) + loc, cov = conditional( + Xnew, + self.X, + self.kernel, + self.f_loc, + self.f_scale_tril, + full_cov=full_cov, + whiten=self.whiten, + jitter=self.jitter, + ) return loc + self.mean_function(Xnew), cov diff --git a/pyro/contrib/gp/models/vsgp.py b/pyro/contrib/gp/models/vsgp.py index 4acfab82c0..e4b83a0b30 100644 --- a/pyro/contrib/gp/models/vsgp.py +++ b/pyro/contrib/gp/models/vsgp.py @@ -78,8 +78,20 @@ class VariationalSparseGP(GPModel): :param float jitter: A small positive term which is added into the diagonal part of a covariance matrix to help stablize its Cholesky decomposition. 
""" - def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None, - latent_shape=None, num_data=None, whiten=False, jitter=1e-6): + + def __init__( + self, + X, + y, + kernel, + Xu, + likelihood, + mean_function=None, + latent_shape=None, + num_data=None, + whiten=False, + jitter=1e-6, + ): super().__init__(X, y, kernel, mean_function, jitter) self.likelihood = likelihood @@ -106,22 +118,37 @@ def model(self): M = self.Xu.size(0) Kuu = self.kernel(self.Xu).contiguous() - Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal + Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal Luu = torch.linalg.cholesky(Kuu) zero_loc = self.Xu.new_zeros(self.u_loc.shape) if self.whiten: identity = eye_like(self.Xu, M) - pyro.sample(self._pyro_get_fullname("u"), - dist.MultivariateNormal(zero_loc, scale_tril=identity) - .to_event(zero_loc.dim() - 1)) + pyro.sample( + self._pyro_get_fullname("u"), + dist.MultivariateNormal(zero_loc, scale_tril=identity).to_event( + zero_loc.dim() - 1 + ), + ) else: - pyro.sample(self._pyro_get_fullname("u"), - dist.MultivariateNormal(zero_loc, scale_tril=Luu) - .to_event(zero_loc.dim() - 1)) - - f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc, self.u_scale_tril, - Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter) + pyro.sample( + self._pyro_get_fullname("u"), + dist.MultivariateNormal(zero_loc, scale_tril=Luu).to_event( + zero_loc.dim() - 1 + ), + ) + + f_loc, f_var = conditional( + self.X, + self.Xu, + self.kernel, + self.u_loc, + self.u_scale_tril, + Luu, + full_cov=False, + whiten=self.whiten, + jitter=self.jitter, + ) f_loc = f_loc + self.mean_function(self.X) if self.y is None: @@ -137,9 +164,12 @@ def guide(self): self.set_mode("guide") self._load_pyro_samples() - pyro.sample(self._pyro_get_fullname("u"), - dist.MultivariateNormal(self.u_loc, scale_tril=self.u_scale_tril) - .to_event(self.u_loc.dim()-1)) + pyro.sample( + self._pyro_get_fullname("u"), + dist.MultivariateNormal(self.u_loc, scale_tril=self.u_scale_tril).to_event( + self.u_loc.dim() - 1 + ), + ) def forward(self, Xnew, full_cov=False): r""" @@ -163,6 +193,14 @@ def forward(self, Xnew, full_cov=False): self._check_Xnew_shape(Xnew) self.set_mode("guide") - loc, cov = conditional(Xnew, self.Xu, self.kernel, self.u_loc, self.u_scale_tril, - full_cov=full_cov, whiten=self.whiten, jitter=self.jitter) + loc, cov = conditional( + Xnew, + self.Xu, + self.kernel, + self.u_loc, + self.u_scale_tril, + full_cov=full_cov, + whiten=self.whiten, + jitter=self.jitter, + ) return loc + self.mean_function(Xnew), cov diff --git a/pyro/contrib/gp/parameterized.py b/pyro/contrib/gp/parameterized.py index 8c4e9098e1..04ddd7d280 100644 --- a/pyro/contrib/gp/parameterized.py +++ b/pyro/contrib/gp/parameterized.py @@ -43,9 +43,11 @@ def _get_sample_fn(module, name): # otherwise, we do inference in unconstrained space and transform the value # back to original space # TODO: move this logic to infer.autoguide or somewhere else - unconstrained_value = pyro.sample(module._pyro_get_fullname("{}_latent".format(name)), - guide.to_event(), - infer={"is_auxiliary": True}) + unconstrained_value = pyro.sample( + module._pyro_get_fullname("{}_latent".format(name)), + guide.to_event(), + infer={"is_auxiliary": True}, + ) transform = biject_to(support) value = transform(unconstrained_value) log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value) @@ -86,6 +88,7 @@ class Parameterized(PyroModule): :meth:`~torch.nn.Module.cuda`. 
See :class:`torch.nn.Module` for more information. """ + def __init__(self): super().__init__() self._priors = OrderedDict() @@ -100,8 +103,11 @@ def set_prior(self, name, prior): :param ~pyro.distributions.distribution.Distribution prior: A Pyro prior distribution. """ - warnings.warn("The method `self.set_prior({}, prior)` has been deprecated" - " in favor of `self.{} = PyroSample(prior)`.".format(name, name), UserWarning) + warnings.warn( + "The method `self.set_prior({}, prior)` has been deprecated" + " in favor of `self.{} = PyroSample(prior)`.".format(name, name), + UserWarning, + ) setattr(self, name, PyroSample(prior)) def __setattr__(self, name, value): @@ -132,8 +138,9 @@ def autoguide(self, name, dist_constructor): raise ValueError("There is no prior for parameter: {}".format(name)) if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]: - raise NotImplementedError("Unsupported distribution type: {}" - .format(dist_constructor)) + raise NotImplementedError( + "Unsupported distribution type: {}".format(dist_constructor) + ) # delete old guide if name in self._guides: @@ -159,8 +166,9 @@ def autoguide(self, name, dist_constructor): elif dist_constructor is dist.MultivariateNormal: loc = Parameter(biject_to(self._priors[name].support).inv(p).detach()) identity = eye_like(loc, loc.size(-1)) - scale_tril = PyroParam(identity.repeat(loc.shape[:-1] + (1, 1)), - constraints.lower_cholesky) + scale_tril = PyroParam( + identity.repeat(loc.shape[:-1] + (1, 1)), constraints.lower_cholesky + ) setattr(self, "{}_loc".format(name), loc) setattr(self, "{}_scale_tril".format(name), scale_tril) dist_args = ("loc", "scale_tril") diff --git a/pyro/contrib/gp/util.py b/pyro/contrib/gp/util.py index 3fb9846557..4309ec96f2 100644 --- a/pyro/contrib/gp/util.py +++ b/pyro/contrib/gp/util.py @@ -7,8 +7,17 @@ from pyro.infer.util import torch_backward, torch_item -def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=False, - whiten=False, jitter=1e-6): +def conditional( + Xnew, + X, + kernel, + f_loc, + f_scale_tril=None, + Lff=None, + full_cov=False, + whiten=False, + jitter=1e-6, +): r""" Given :math:`X_{new}`, predicts loc and covariance matrix of the conditional multivariate normal distribution @@ -82,7 +91,7 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa if Lff is None: Kff = kernel(X).contiguous() - Kff.view(-1)[::N + 1] += jitter # add jitter to diagonal + Kff.view(-1)[:: N + 1] += jitter # add jitter to diagonal Lff = torch.linalg.cholesky(Kff) Kfs = kernel(X, Xnew) @@ -108,10 +117,10 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa Lffinv_pack = pack.triangular_solve(Lff, upper=False)[0] # unpack - v_2D = Lffinv_pack[:, :f_loc_2D.size(1)] - W = Lffinv_pack[:, f_loc_2D.size(1):f_loc_2D.size(1) + M].t() + v_2D = Lffinv_pack[:, : f_loc_2D.size(1)] + W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t() if f_scale_tril is not None: - S_2D = Lffinv_pack[:, -f_scale_tril_2D.size(1):] + S_2D = Lffinv_pack[:, -f_scale_tril_2D.size(1) :] loc_shape = latent_shape + (M,) loc = W.matmul(v_2D).t().reshape(loc_shape) @@ -164,8 +173,11 @@ def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1 :returns: a list of losses during the training procedure :rtype: list """ - optimizer = (torch.optim.Adam(gpmodule.parameters(), lr=0.01) - if optimizer is None else optimizer) + optimizer = ( + torch.optim.Adam(gpmodule.parameters(), lr=0.01) + if optimizer is None + 
else optimizer + ) # TODO: add support for JIT loss loss_fn = TraceMeanField_ELBO().differentiable_loss if loss_fn is None else loss_fn diff --git a/pyro/contrib/minipyro.py b/pyro/contrib/minipyro.py index 0959c4a683..f0f4452c92 100644 --- a/pyro/contrib/minipyro.py +++ b/pyro/contrib/minipyro.py @@ -76,8 +76,9 @@ def __enter__(self): # trace illustrates why we need postprocess_message in addition to process_message: # We only want to record a value after all other effects have been applied def postprocess_message(self, msg): - assert msg["type"] != "sample" or msg["name"] not in self.trace, \ - "sample sites must have unique names" + assert ( + msg["type"] != "sample" or msg["name"] not in self.trace + ), "sample sites must have unique names" self.trace[msg["name"]] = msg.copy() def get_trace(self, *args, **kwargs): @@ -121,18 +122,21 @@ def __init__(self, fn=None, rng_seed=None): def __enter__(self): self.old_state = { - 'torch': torch.get_rng_state(), 'random': random.getstate(), 'numpy': np.random.get_state() + "torch": torch.get_rng_state(), + "random": random.getstate(), + "numpy": np.random.get_state(), } torch.manual_seed(self.rng_seed) random.seed(self.rng_seed) np.random.seed(self.rng_seed) def __exit__(self, type, value, traceback): - torch.set_rng_state(self.old_state['torch']) - random.setstate(self.old_state['random']) - if 'numpy' in self.old_state: + torch.set_rng_state(self.old_state["torch"]) + random.setstate(self.old_state["random"]) + if "numpy" in self.old_state: import numpy as np - np.random.set_state(self.old_state['numpy']) + + np.random.set_state(self.old_state["numpy"]) # This limited implementation of PlateMessenger only implements broadcasting. @@ -170,7 +174,7 @@ def apply_stack(msg): # A Messenger that sets msg["stop"] == True also prevents application # of postprocess_message by Messengers above it on the stack # via the pointer variable from the process_message loop - for handler in PYRO_STACK[-pointer-1:]: + for handler in PYRO_STACK[-pointer - 1 :]: handler.postprocess_message(msg) return msg @@ -178,7 +182,7 @@ def apply_stack(msg): # sample is an effectful version of Distribution.sample(...) # When any effect handlers are active, it constructs an initial message and calls apply_stack. def sample(name, fn, *args, **kwargs): - obs = kwargs.pop('obs', None) + obs = kwargs.pop("obs", None) # if there are no active Messengers, we just draw a sample and return it as expected: if not PYRO_STACK: @@ -201,7 +205,12 @@ def sample(name, fn, *args, **kwargs): # param is an effectful version of PARAM_STORE.setdefault that also handles constraints. # When any effect handlers are active, it constructs an initial message and calls apply_stack. -def param(name, init_value=None, constraint=torch.distributions.constraints.real, event_dim=None): +def param( + name, + init_value=None, + constraint=torch.distributions.constraints.real, + event_dim=None, +): if event_dim is not None: raise NotImplementedError("minipyro.plate does not support the event_dim arg") @@ -213,12 +222,16 @@ def fn(init_value, constraint): assert init_value is not None with torch.no_grad(): constrained_value = init_value.detach() - unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value) + unconstrained_value = torch.distributions.transform_to(constraint).inv( + constrained_value + ) unconstrained_value.requires_grad_() PARAM_STORE[name] = unconstrained_value, constraint # Transform from unconstrained space to constrained space. 
- constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value) + constrained_value = torch.distributions.transform_to(constraint)( + unconstrained_value + ) constrained_value.unconstrained = weakref.ref(unconstrained_value) return constrained_value @@ -295,8 +308,7 @@ def step(self, *args, **kwargs): # Differentiate the loss. loss.backward() # Grab all the parameters from the trace. - params = [site["value"].unconstrained() - for site in param_capture.values()] + params = [site["value"].unconstrained() for site in param_capture.values()] # Take a step w.r.t. each parameter in params. self.optim(params) # Zero out the gradients so that they don't accumulate. @@ -323,7 +335,7 @@ def elbo(model, guide, *args, **kwargs): # distribution defined by the guide. model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs) # We will accumulate the various terms of the ELBO in `elbo`. - elbo = 0. + elbo = 0.0 # Loop over all the sample sites in the model and add the corresponding # log p(z) term to the ELBO. Note that this will also include any observed # data, i.e. sample sites with the keyword `obs=...`. @@ -362,17 +374,20 @@ def __call__(self, model, guide, *args): self._param_trace = tr # Augment args with reads from the global param store. - unconstrained_params = tuple(param(name).unconstrained() - for name in self._param_trace) + unconstrained_params = tuple( + param(name).unconstrained() for name in self._param_trace + ) params_and_args = unconstrained_params + args # On first call, create a compiled elbo. if self._compiled is None: def compiled(*params_and_args): - unconstrained_params = params_and_args[:len(self._param_trace)] - args = params_and_args[len(self._param_trace):] - for name, unconstrained_param in zip(self._param_trace, unconstrained_params): + unconstrained_params = params_and_args[: len(self._param_trace)] + args = params_and_args[len(self._param_trace) :] + for name, unconstrained_param in zip( + self._param_trace, unconstrained_params + ): constrained_param = param(name) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param self._param_trace[name]["value"] = constrained_param @@ -381,6 +396,8 @@ def compiled(*params_and_args): with validation_enabled(False), warnings.catch_warnings(): if self.ignore_jit_warnings: warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) - self._compiled = torch.jit.trace(compiled, params_and_args, check_trace=False) + self._compiled = torch.jit.trace( + compiled, params_and_args, check_trace=False + ) return self._compiled(*params_and_args) diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py index b6fbbb4489..4df5444739 100644 --- a/pyro/contrib/mue/dataloaders.py +++ b/pyro/contrib/mue/dataloaders.py @@ -5,12 +5,33 @@ import torch from torch.utils.data import Dataset -alphabets = {'amino-acid': np.array( - ['R', 'H', 'K', 'D', 'E', - 'S', 'T', 'N', 'Q', 'C', - 'G', 'P', 'A', 'V', 'I', - 'L', 'M', 'F', 'Y', 'W']), - 'dna': np.array(['A', 'C', 'G', 'T'])} +alphabets = { + "amino-acid": np.array( + [ + "R", + "H", + "K", + "D", + "E", + "S", + "T", + "N", + "Q", + "C", + "G", + "P", + "A", + "V", + "I", + "L", + "M", + "F", + "Y", + "W", + ] + ), + "dna": np.array(["A", "C", "G", "T"]), +} class BiosequenceDataset(Dataset): @@ -32,26 +53,32 @@ class BiosequenceDataset(Dataset): memory. 
""" - def __init__(self, source, source_type='list', alphabet='amino-acid', - max_length=None, include_stop=False, device=None): + def __init__( + self, + source, + source_type="list", + alphabet="amino-acid", + max_length=None, + include_stop=False, + device=None, + ): super().__init__() # Determine device if device is None: - device = torch.tensor(0.).device + device = torch.tensor(0.0).device self.device = device # Get sequences. self.include_stop = include_stop - if source_type == 'list': - seqs = [seq + include_stop*'*' for seq in source] - elif source_type == 'fasta': + if source_type == "list": + seqs = [seq + include_stop * "*" for seq in source] + elif source_type == "fasta": seqs = self._load_fasta(source) # Get lengths. - self.L_data = torch.tensor([float(len(seq)) for seq in seqs], - device=device) + self.L_data = torch.tensor([float(len(seq)) for seq in seqs], device=device) if max_length is None: self.max_length = int(torch.max(self.L_data)) else: @@ -64,42 +91,46 @@ def __init__(self, source, source_type='list', alphabet='amino-acid', else: alphabet = np.array(list(alphabet)) if self.include_stop: - alphabet = np.array(list(alphabet) + ['*']) + alphabet = np.array(list(alphabet) + ["*"]) self.alphabet = alphabet self.alphabet_length = len(alphabet) # Build dataset. - self.seq_data = torch.cat([self._one_hot( - seq, alphabet, self.max_length).unsqueeze(0) for seq in seqs]) + self.seq_data = torch.cat( + [self._one_hot(seq, alphabet, self.max_length).unsqueeze(0) for seq in seqs] + ) def _load_fasta(self, source): """A basic multiline fasta parser.""" seqs = [] - seq = '' - with open(source, 'r') as fr: + seq = "" + with open(source, "r") as fr: for line in fr: - if line[0] == '>': - if seq != '': + if line[0] == ">": + if seq != "": if self.include_stop: - seq += '*' + seq += "*" seqs.append(seq) - seq = '' + seq = "" else: - seq += line.strip('\n') - if seq != '': + seq += line.strip("\n") + if seq != "": if self.include_stop: - seq += '*' + seq += "*" seqs.append(seq) return seqs def _one_hot(self, seq, alphabet, length): """One hot encode and pad with zeros to max length.""" # One hot encode. - oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] - ).astype(np.float64), device=self.device) + oh = torch.tensor( + (np.array(list(seq))[:, None] == alphabet[None, :]).astype(np.float64), + device=self.device, + ) # Pad. - x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)], - device=self.device)]) + x = torch.cat( + [oh, torch.zeros([length - len(seq), len(alphabet)], device=self.device)] + ) return x diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py index eb414bf82c..26084c10e7 100644 --- a/pyro/contrib/mue/missingdatahmm.py +++ b/pyro/contrib/mue/missingdatahmm.py @@ -35,42 +35,51 @@ class MissingDataDiscreteHMM(TorchDistribution): dimension of the categorical output, and be broadcastable to ``(batch_size, state_dim, categorical_size)``. 
""" - arg_constraints = {"initial_logits": constraints.real_vector, - "transition_logits": constraints.independent( - constraints.real, 2), - "observation_logits": constraints.independent( - constraints.real, 2)} + + arg_constraints = { + "initial_logits": constraints.real_vector, + "transition_logits": constraints.independent(constraints.real, 2), + "observation_logits": constraints.independent(constraints.real, 2), + } support = constraints.independent(constraints.nonnegative_integer, 2) - def __init__(self, initial_logits, transition_logits, observation_logits, - validate_args=None): + def __init__( + self, initial_logits, transition_logits, observation_logits, validate_args=None + ): if initial_logits.dim() < 1: raise ValueError( - "expected initial_logits to have at least one dim, " - "actual shape = {}".format(initial_logits.shape)) + "expected initial_logits to have at least one dim, " + "actual shape = {}".format(initial_logits.shape) + ) if transition_logits.dim() < 2: raise ValueError( - "expected transition_logits to have at least two dims, " - "actual shape = {}".format(transition_logits.shape)) + "expected transition_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape) + ) if observation_logits.dim() < 2: raise ValueError( - "expected observation_logits to have at least two dims, " - "actual shape = {}".format(transition_logits.shape)) - shape = broadcast_shape(initial_logits.shape[:-1], - transition_logits.shape[:-2], - observation_logits.shape[:-2]) + "expected observation_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape) + ) + shape = broadcast_shape( + initial_logits.shape[:-1], + transition_logits.shape[:-2], + observation_logits.shape[:-2], + ) if len(shape) == 0: shape = torch.Size([1]) batch_shape = shape event_shape = (1, observation_logits.shape[-1]) - self.initial_logits = (initial_logits - - initial_logits.logsumexp(-1, True)) - self.transition_logits = (transition_logits - - transition_logits.logsumexp(-1, True)) - self.observation_logits = (observation_logits - - observation_logits.logsumexp(-1, True)) + self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True) + self.transition_logits = transition_logits - transition_logits.logsumexp( + -1, True + ) + self.observation_logits = observation_logits - observation_logits.logsumexp( + -1, True + ) super(MissingDataDiscreteHMM, self).__init__( - batch_shape, event_shape, validate_args=validate_args) + batch_shape, event_shape, validate_args=validate_args + ) def log_prob(self, value): """ @@ -88,16 +97,15 @@ def log_prob(self, value): # Combine observation and transition factors. value_logits = torch.matmul( - value, torch.transpose(self.observation_logits, -2, -1)) - result = (self.transition_logits.unsqueeze(-3) + - value_logits[..., 1:, None, :]) + value, torch.transpose(self.observation_logits, -2, -1) + ) + result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :] # Eliminate time dimension. result = _sequential_logmatmulexp(result) # Combine initial factor. - result = (self.initial_logits + value_logits[..., 0, :] - + result.logsumexp(-1)) + result = self.initial_logits + value_logits[..., 0, :] + result.logsumexp(-1) # Marginalize out final state. 
result = result.logsumexp(-1) diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py index fb55a2fa9f..db041ea326 100644 --- a/pyro/contrib/mue/models.py +++ b/pyro/contrib/mue/models.py @@ -43,9 +43,16 @@ class ProfileHMM(nn.Module): :param bool cuda: Transfer data onto the GPU during training. :param bool pin_memory: Pin memory for faster GPU transfer. """ - def __init__(self, latent_seq_length, alphabet_length, - prior_scale=1., indel_prior_bias=10., - cuda=False, pin_memory=False): + + def __init__( + self, + latent_seq_length, + alphabet_length, + prior_scale=1.0, + indel_prior_bias=10.0, + cuda=False, + pin_memory=False, + ): super().__init__() assert isinstance(cuda, bool) self.is_cuda = cuda @@ -58,13 +65,13 @@ def __init__(self, latent_seq_length, alphabet_length, self.alphabet_length = alphabet_length self.precursor_seq_shape = (latent_seq_length, alphabet_length) - self.insert_seq_shape = (latent_seq_length+1, alphabet_length) + self.insert_seq_shape = (latent_seq_length + 1, alphabet_length) self.indel_shape = (latent_seq_length, 3, 2) assert isinstance(prior_scale, float) self.prior_scale = prior_scale assert isinstance(indel_prior_bias, float) - self.indel_prior = torch.tensor([indel_prior_bias, 0.]) + self.indel_prior = torch.tensor([indel_prior_bias, 0.0]) # Initialize state arranger. self.statearrange = Profile(latent_seq_length) @@ -72,74 +79,97 @@ def __init__(self, latent_seq_length, alphabet_length, def model(self, seq_data, local_scale): # Latent sequence. - precursor_seq = pyro.sample("precursor_seq", dist.Normal( + precursor_seq = pyro.sample( + "precursor_seq", + dist.Normal( torch.zeros(self.precursor_seq_shape), - self.prior_scale * - torch.ones(self.precursor_seq_shape)).to_event(2)) + self.prior_scale * torch.ones(self.precursor_seq_shape), + ).to_event(2), + ) precursor_seq_logits = precursor_seq - precursor_seq.logsumexp(-1, True) - insert_seq = pyro.sample("insert_seq", dist.Normal( + insert_seq = pyro.sample( + "insert_seq", + dist.Normal( torch.zeros(self.insert_seq_shape), - self.prior_scale * - torch.ones(self.insert_seq_shape)).to_event(2)) + self.prior_scale * torch.ones(self.insert_seq_shape), + ).to_event(2), + ) insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True) # Indel probabilities. - insert = pyro.sample("insert", dist.Normal( + insert = pyro.sample( + "insert", + dist.Normal( self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + self.prior_scale * torch.ones(self.indel_shape), + ).to_event(3), + ) insert_logits = insert - insert.logsumexp(-1, True) - delete = pyro.sample("delete", dist.Normal( + delete = pyro.sample( + "delete", + dist.Normal( self.indel_prior * torch.ones(self.indel_shape), - self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + self.prior_scale * torch.ones(self.indel_shape), + ).to_event(3), + ) delete_logits = delete - delete.logsumexp(-1, True) # Construct HMM parameters. - initial_logits, transition_logits, observation_logits = ( - self.statearrange(precursor_seq_logits, insert_seq_logits, - insert_logits, delete_logits)) + initial_logits, transition_logits, observation_logits = self.statearrange( + precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits + ) with pyro.plate("batch", seq_data.shape[0]): with poutine.scale(scale=local_scale): # Observations. 
- pyro.sample("obs_seq", - MissingDataDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=seq_data) + pyro.sample( + "obs_seq", + MissingDataDiscreteHMM( + initial_logits, transition_logits, observation_logits + ), + obs=seq_data, + ) def guide(self, seq_data, local_scale): # Sequence. - precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", - torch.zeros(self.precursor_seq_shape)) - precursor_seq_q_sd = pyro.param("precursor_seq_q_sd", - torch.zeros(self.precursor_seq_shape)) - pyro.sample("precursor_seq", dist.Normal( - precursor_seq_q_mn, softplus(precursor_seq_q_sd)).to_event(2)) - insert_seq_q_mn = pyro.param("insert_seq_q_mn", - torch.zeros(self.insert_seq_shape)) - insert_seq_q_sd = pyro.param("insert_seq_q_sd", - torch.zeros(self.insert_seq_shape)) - pyro.sample("insert_seq", dist.Normal( - insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2)) + precursor_seq_q_mn = pyro.param( + "precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape) + ) + precursor_seq_q_sd = pyro.param( + "precursor_seq_q_sd", torch.zeros(self.precursor_seq_shape) + ) + pyro.sample( + "precursor_seq", + dist.Normal(precursor_seq_q_mn, softplus(precursor_seq_q_sd)).to_event(2), + ) + insert_seq_q_mn = pyro.param( + "insert_seq_q_mn", torch.zeros(self.insert_seq_shape) + ) + insert_seq_q_sd = pyro.param( + "insert_seq_q_sd", torch.zeros(self.insert_seq_shape) + ) + pyro.sample( + "insert_seq", + dist.Normal(insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2), + ) # Indels. - insert_q_mn = pyro.param("insert_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - insert_q_sd = pyro.param("insert_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("insert", dist.Normal( - insert_q_mn, softplus(insert_q_sd)).to_event(3)) - delete_q_mn = pyro.param("delete_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - delete_q_sd = pyro.param("delete_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("delete", dist.Normal( - delete_q_mn, softplus(delete_q_sd)).to_event(3)) - - def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, - jit=False): + insert_q_mn = pyro.param( + "insert_q_mn", torch.ones(self.indel_shape) * self.indel_prior + ) + insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape)) + pyro.sample( + "insert", dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3) + ) + delete_q_mn = pyro.param( + "delete_q_mn", torch.ones(self.indel_shape) * self.indel_prior + ) + delete_q_sd = pyro.param("delete_q_sd", torch.zeros(self.indel_shape)) + pyro.sample( + "delete", dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3) + ) + + def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, jit=False): """ Infer approximate posterior with stochastic variational inference. @@ -158,14 +188,19 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, if batch_size is not None: self.batch_size = batch_size if scheduler is None: - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.01}, - 'milestones': [], - 'gamma': 0.5}) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": 0.01}, + "milestones": [], + "gamma": 0.5, + } + ) # Initialize guide. self.guide(None, None) - dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, - pin_memory=self.pin_memory) + dataload = DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=self.pin_memory + ) # Setup stochastic variational inference. 
if jit: elbo = JitTrace_ELBO(ignore_jit_warnings=True) @@ -180,11 +215,12 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, for seq_data, L_data in dataload: if self.is_cuda: seq_data = seq_data.cuda() - loss = svi.step(seq_data, - torch.tensor(len(dataset)/seq_data.shape[0])) + loss = svi.step( + seq_data, torch.tensor(len(dataset) / seq_data.shape[0]) + ) losses.append(loss) scheduler.step() - print(epoch, loss, ' ', datetime.datetime.now() - t0) + print(epoch, loss, " ", datetime.datetime.now() - t0) return losses def evaluate(self, dataset_train, dataset_test=None, jit=False): @@ -198,48 +234,56 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) if dataset_test is not None: - dataload_test = DataLoader(dataset_test, batch_size=1, - shuffle=False) + dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) # Initialize guide. self.guide(None, None) if jit: elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: elbo = Trace_ELBO() - scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + scheduler = MultiStepLR({"optimizer": Adam, "optim_args": {"lr": 0.01}}) # Setup stochastic variational inference. svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( - svi, dataload_train, len(dataset_train)) + svi, dataload_train, len(dataset_train) + ) if dataset_test is not None: test_lp, test_perplex = self._evaluate_local_elbo( - svi, dataload_test, len(dataset_test)) + svi, dataload_test, len(dataset_test) + ) return train_lp, test_lp, train_perplex, test_perplex else: return train_lp, None, train_perplex, None def _local_variables(self, name, site): """Return per datapoint random variables in model.""" - return name in ['obs_L', 'obs_seq'] + return name in ["obs_L", "obs_seq"] def _evaluate_local_elbo(self, svi, dataload, data_size): """Evaluate elbo and average per residue perplexity.""" - lp, perplex = 0., 0. + lp, perplex = 0.0, 0.0 with torch.no_grad(): for seq_data, L_data in dataload: if self.is_cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() - conditioned_model = poutine.condition(self.model, data={ - "obs_seq": seq_data}) - args = (seq_data, torch.tensor(1.)) + conditioned_model = poutine.condition( + self.model, data={"obs_seq": seq_data} + ) + args = (seq_data, torch.tensor(1.0)) guide_tr = poutine.trace(self.guide).get_trace(*args) - model_tr = poutine.trace(poutine.replay( - conditioned_model, trace=guide_tr)).get_trace(*args) - local_elbo = (model_tr.log_prob_sum(self._local_variables) - - guide_tr.log_prob_sum(self._local_variables) - ).cpu().numpy() + model_tr = poutine.trace( + poutine.replay(conditioned_model, trace=guide_tr) + ).get_trace(*args) + local_elbo = ( + ( + model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ) + .cpu() + .numpy() + ) lp += local_elbo perplex += -local_elbo / L_data[0].cpu().numpy() perplex = np.exp(perplex / data_size) @@ -308,23 +352,29 @@ class FactorMuE(nn.Module): :param bool pin_memory: Pin memory for faster GPU transfer. :param float epsilon: A small value for numerical stability. 
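To make the ProfileHMM training API reformatted above concrete, a small end-to-end sketch (not part of the formatting change itself); the sequences, latent length, batch size, and epoch count are illustrative assumptions.

from pyro.contrib.mue.dataloaders import BiosequenceDataset
from pyro.contrib.mue.models import ProfileHMM

dataset = BiosequenceDataset(
    ["MEVKL", "MEVL", "MKVL"], source_type="list",
    alphabet="amino-acid", include_stop=True,
)

# alphabet_length must match the dataset (20 amino acids plus the stop symbol
# added by include_stop=True); latent_seq_length is a modeling choice.
model = ProfileHMM(
    latent_seq_length=4,
    alphabet_length=dataset.alphabet_length,
    prior_scale=1.0,
    indel_prior_bias=10.0,
)
losses = model.fit_svi(dataset, epochs=2, batch_size=2)
train_lp, _, train_perplex, _ = model.evaluate(dataset)

Here evaluate returns an ELBO-based log-probability estimate and an average per-residue perplexity for the training set (and for a test set when one is supplied).
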
""" - def __init__(self, data_length, alphabet_length, z_dim, - batch_size=10, - latent_seq_length=None, - indel_factor_dependence=False, - indel_prior_scale=1., - indel_prior_bias=10., - inverse_temp_prior=100., - weights_prior_scale=1., - offset_prior_scale=1., - z_prior_distribution='Normal', - ARD_prior=False, - substitution_matrix=True, - substitution_prior_scale=10., - latent_alphabet_length=None, - cuda=False, - pin_memory=False, - epsilon=1e-32): + + def __init__( + self, + data_length, + alphabet_length, + z_dim, + batch_size=10, + latent_seq_length=None, + indel_factor_dependence=False, + indel_prior_scale=1.0, + indel_prior_bias=10.0, + inverse_temp_prior=100.0, + weights_prior_scale=1.0, + offset_prior_scale=1.0, + z_prior_distribution="Normal", + ARD_prior=False, + substitution_matrix=True, + substitution_prior_scale=10.0, + latent_alphabet_length=None, + cuda=False, + pin_memory=False, + epsilon=1e-32, + ): super().__init__() assert isinstance(cuda, bool) self.is_cuda = cuda @@ -350,8 +400,9 @@ def __init__(self, data_length, alphabet_length, z_dim, self.latent_alphabet_length = latent_alphabet_length self.indel_shape = (latent_seq_length, 3, 2) self.total_factor_size = ( - (2*latent_seq_length+1)*latent_alphabet_length + - 2*indel_factor_dependence*latent_seq_length*3*2) + (2 * latent_seq_length + 1) * latent_alphabet_length + + 2 * indel_factor_dependence * latent_seq_length * 3 * 2 + ) # Architecture. self.indel_factor_dependence = indel_factor_dependence @@ -362,7 +413,7 @@ def __init__(self, data_length, alphabet_length, z_dim, assert isinstance(indel_prior_scale, float) self.indel_prior_scale = torch.tensor(indel_prior_scale) assert isinstance(indel_prior_bias, float) - self.indel_prior = torch.tensor([indel_prior_bias, 0.]) + self.indel_prior = torch.tensor([indel_prior_bias, 0.0]) assert isinstance(inverse_temp_prior, float) self.inverse_temp_prior = torch.tensor(inverse_temp_prior) assert isinstance(weights_prior_scale, float) @@ -391,28 +442,32 @@ def decoder(self, z, W, B, inverse_temp): out = dict() if self.indel_factor_dependence: # Extract insertion and deletion parameters. - ind0 = (2*self.latent_seq_length+1)*self.latent_alphabet_length - ind1 = ind0 + self.latent_seq_length*3*2 - ind2 = ind1 + self.latent_seq_length*3*2 + ind0 = (2 * self.latent_seq_length + 1) * self.latent_alphabet_length + ind1 = ind0 + self.latent_seq_length * 3 * 2 + ind2 = ind1 + self.latent_seq_length * 3 * 2 insert_v, delete_v = v[:, ind0:ind1], v[:, ind1:ind2] - insert_v = (insert_v.reshape([-1, self.latent_seq_length, 3, 2]) - + self.indel_prior) - out['insert_logits'] = insert_v - insert_v.logsumexp(-1, True) - delete_v = (delete_v.reshape([-1, self.latent_seq_length, 3, 2]) - + self.indel_prior) - out['delete_logits'] = delete_v - delete_v.logsumexp(-1, True) + insert_v = ( + insert_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior + ) + out["insert_logits"] = insert_v - insert_v.logsumexp(-1, True) + delete_v = ( + delete_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior + ) + out["delete_logits"] = delete_v - delete_v.logsumexp(-1, True) # Extract precursor and insertion sequences. 
- ind0 = self.latent_seq_length*self.latent_alphabet_length - ind1 = ind0 + (self.latent_seq_length+1)*self.latent_alphabet_length + ind0 = self.latent_seq_length * self.latent_alphabet_length + ind1 = ind0 + (self.latent_seq_length + 1) * self.latent_alphabet_length precursor_seq_v, insert_seq_v = v[:, :ind0], v[:, ind0:ind1] - precursor_seq_v = (precursor_seq_v*softplus(inverse_temp)).reshape([ - -1, self.latent_seq_length, self.latent_alphabet_length]) - out['precursor_seq_logits'] = ( - precursor_seq_v - precursor_seq_v.logsumexp(-1, True)) - insert_seq_v = (insert_seq_v*softplus(inverse_temp)).reshape([ - -1, self.latent_seq_length+1, self.latent_alphabet_length]) - out['insert_seq_logits'] = ( - insert_seq_v - insert_seq_v.logsumexp(-1, True)) + precursor_seq_v = (precursor_seq_v * softplus(inverse_temp)).reshape( + [-1, self.latent_seq_length, self.latent_alphabet_length] + ) + out["precursor_seq_logits"] = precursor_seq_v - precursor_seq_v.logsumexp( + -1, True + ) + insert_seq_v = (insert_seq_v * softplus(inverse_temp)).reshape( + [-1, self.latent_seq_length + 1, self.latent_alphabet_length] + ) + out["insert_seq_logits"] = insert_seq_v - insert_seq_v.logsumexp(-1, True) return out @@ -421,85 +476,123 @@ def model(self, seq_data, local_scale, local_prior_scale): # ARD prior. if self.ARD_prior: # Relevance factors - alpha = pyro.sample("alpha", dist.Gamma( - torch.ones(self.z_dim), torch.ones(self.z_dim)).to_event(1)) + alpha = pyro.sample( + "alpha", + dist.Gamma(torch.ones(self.z_dim), torch.ones(self.z_dim)).to_event(1), + ) else: alpha = torch.ones(self.z_dim) # Factor and offset. - W = pyro.sample("W", dist.Normal( + W = pyro.sample( + "W", + dist.Normal( torch.zeros([self.z_dim, self.total_factor_size]), - torch.ones([self.z_dim, self.total_factor_size]) * - self.weights_prior_scale / (alpha[:, None] + self.epsilon) - ).to_event(2)) - B = pyro.sample("B", dist.Normal( + torch.ones([self.z_dim, self.total_factor_size]) + * self.weights_prior_scale + / (alpha[:, None] + self.epsilon), + ).to_event(2), + ) + B = pyro.sample( + "B", + dist.Normal( torch.zeros(self.total_factor_size), - torch.ones(self.total_factor_size) * self.offset_prior_scale - ).to_event(1)) + torch.ones(self.total_factor_size) * self.offset_prior_scale, + ).to_event(1), + ) # Indel probabilities. if not self.indel_factor_dependence: - insert = pyro.sample("insert", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.indel_prior_scale * torch.ones(self.indel_shape) - ).to_event(3)) + insert = pyro.sample( + "insert", + dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape), + ).to_event(3), + ) insert_logits = insert - insert.logsumexp(-1, True) - delete = pyro.sample("delete", dist.Normal( - self.indel_prior * torch.ones(self.indel_shape), - self.indel_prior_scale * torch.ones(self.indel_shape) - ).to_event(3)) + delete = pyro.sample( + "delete", + dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape), + ).to_event(3), + ) delete_logits = delete - delete.logsumexp(-1, True) # Inverse temperature. - inverse_temp = pyro.sample("inverse_temp", dist.Normal( - self.inverse_temp_prior, torch.tensor(1.))) + inverse_temp = pyro.sample( + "inverse_temp", dist.Normal(self.inverse_temp_prior, torch.tensor(1.0)) + ) # Substitution matrix. 
if self.substitution_matrix: - substitute = pyro.sample("substitute", dist.Normal( - torch.zeros([ - self.latent_alphabet_length, self.alphabet_length]), - self.substitution_prior_scale * torch.ones([ - self.latent_alphabet_length, self.alphabet_length]) - ).to_event(2)) + substitute = pyro.sample( + "substitute", + dist.Normal( + torch.zeros([self.latent_alphabet_length, self.alphabet_length]), + self.substitution_prior_scale + * torch.ones([self.latent_alphabet_length, self.alphabet_length]), + ).to_event(2), + ) with pyro.plate("batch", seq_data.shape[0]): with poutine.scale(scale=local_scale): with poutine.scale(scale=local_prior_scale): # Sample latent variable from prior. - if self.z_prior_distribution == 'Normal': - z = pyro.sample("latent", dist.Normal( + if self.z_prior_distribution == "Normal": + z = pyro.sample( + "latent", + dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1), + ) + elif self.z_prior_distribution == "Laplace": + z = pyro.sample( + "latent", + dist.Laplace( torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) - elif self.z_prior_distribution == 'Laplace': - z = pyro.sample("latent", dist.Laplace( - torch.zeros(self.z_dim), torch.ones(self.z_dim) - ).to_event(1)) + ).to_event(1), + ) # Decode latent sequence. decoded = self.decoder(z, W, B, inverse_temp) if self.indel_factor_dependence: - insert_logits = decoded['insert_logits'] - delete_logits = decoded['delete_logits'] + insert_logits = decoded["insert_logits"] + delete_logits = decoded["delete_logits"] # Construct HMM parameters. if self.substitution_matrix: - initial_logits, transition_logits, observation_logits = ( - self.statearrange(decoded['precursor_seq_logits'], - decoded['insert_seq_logits'], - insert_logits, delete_logits, - substitute)) + ( + initial_logits, + transition_logits, + observation_logits, + ) = self.statearrange( + decoded["precursor_seq_logits"], + decoded["insert_seq_logits"], + insert_logits, + delete_logits, + substitute, + ) else: - initial_logits, transition_logits, observation_logits = ( - self.statearrange(decoded['precursor_seq_logits'], - decoded['insert_seq_logits'], - insert_logits, delete_logits)) + ( + initial_logits, + transition_logits, + observation_logits, + ) = self.statearrange( + decoded["precursor_seq_logits"], + decoded["insert_seq_logits"], + insert_logits, + delete_logits, + ) # Draw samples. - pyro.sample("obs_seq", - MissingDataDiscreteHMM(initial_logits, - transition_logits, - observation_logits), - obs=seq_data) + pyro.sample( + "obs_seq", + MissingDataDiscreteHMM( + initial_logits, transition_logits, observation_logits + ), + obs=seq_data, + ) def guide(self, seq_data, local_scale, local_prior_scale): # Register encoder with pyro. @@ -509,13 +602,13 @@ def guide(self, seq_data, local_scale, local_prior_scale): if self.ARD_prior: alpha_conc = pyro.param("alpha_conc", torch.randn(self.z_dim)) alpha_rate = pyro.param("alpha_rate", torch.randn(self.z_dim)) - pyro.sample("alpha", dist.Gamma(softplus(alpha_conc), - softplus(alpha_rate)).to_event(1)) + pyro.sample( + "alpha", + dist.Gamma(softplus(alpha_conc), softplus(alpha_rate)).to_event(1), + ) # Factors. 
- W_q_mn = pyro.param("W_q_mn", torch.randn([ - self.z_dim, self.total_factor_size])) - W_q_sd = pyro.param("W_q_sd", torch.ones([ - self.z_dim, self.total_factor_size])) + W_q_mn = pyro.param("W_q_mn", torch.randn([self.z_dim, self.total_factor_size])) + W_q_sd = pyro.param("W_q_sd", torch.ones([self.z_dim, self.total_factor_size])) pyro.sample("W", dist.Normal(W_q_mn, softplus(W_q_sd)).to_event(2)) B_q_mn = pyro.param("B_q_mn", torch.randn(self.total_factor_size)) B_q_sd = pyro.param("B_q_sd", torch.ones(self.total_factor_size)) @@ -523,52 +616,64 @@ def guide(self, seq_data, local_scale, local_prior_scale): # Indel probabilities. if not self.indel_factor_dependence: - insert_q_mn = pyro.param("insert_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - insert_q_sd = pyro.param("insert_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("insert", dist.Normal( - insert_q_mn, softplus(insert_q_sd)).to_event(3)) - delete_q_mn = pyro.param("delete_q_mn", - torch.ones(self.indel_shape) - * self.indel_prior) - delete_q_sd = pyro.param("delete_q_sd", - torch.zeros(self.indel_shape)) - pyro.sample("delete", dist.Normal( - delete_q_mn, softplus(delete_q_sd)).to_event(3)) + insert_q_mn = pyro.param( + "insert_q_mn", torch.ones(self.indel_shape) * self.indel_prior + ) + insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape)) + pyro.sample( + "insert", dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3) + ) + delete_q_mn = pyro.param( + "delete_q_mn", torch.ones(self.indel_shape) * self.indel_prior + ) + delete_q_sd = pyro.param("delete_q_sd", torch.zeros(self.indel_shape)) + pyro.sample( + "delete", dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3) + ) # Inverse temperature. - inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.)) - inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.)) - pyro.sample("inverse_temp", dist.Normal( - inverse_temp_q_mn, softplus(inverse_temp_q_sd))) + inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.0)) + inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.0)) + pyro.sample( + "inverse_temp", dist.Normal(inverse_temp_q_mn, softplus(inverse_temp_q_sd)) + ) # Substitution matrix. if self.substitution_matrix: - substitute_q_mn = pyro.param("substitute_q_mn", torch.zeros( - [self.latent_alphabet_length, self.alphabet_length])) - substitute_q_sd = pyro.param("substitute_q_sd", torch.zeros( - [self.latent_alphabet_length, self.alphabet_length])) - pyro.sample("substitute", dist.Normal( - substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) + substitute_q_mn = pyro.param( + "substitute_q_mn", + torch.zeros([self.latent_alphabet_length, self.alphabet_length]), + ) + substitute_q_sd = pyro.param( + "substitute_q_sd", + torch.zeros([self.latent_alphabet_length, self.alphabet_length]), + ) + pyro.sample( + "substitute", + dist.Normal(substitute_q_mn, softplus(substitute_q_sd)).to_event(2), + ) # Per datapoint local latent variables. with pyro.plate("batch", seq_data.shape[0]): # Encode sequences. z_loc, z_scale = self.encoder(seq_data) # Scale log likelihood to account for mini-batching. - with poutine.scale(scale=local_scale*local_prior_scale): + with poutine.scale(scale=local_scale * local_prior_scale): # Sample. 
- if self.z_prior_distribution == 'Normal': - pyro.sample("latent", - dist.Normal(z_loc, z_scale).to_event(1)) - elif self.z_prior_distribution == 'Laplace': - pyro.sample("latent", - dist.Laplace(z_loc, z_scale).to_event(1)) - - def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, - scheduler=None, jit=False): + if self.z_prior_distribution == "Normal": + pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + elif self.z_prior_distribution == "Laplace": + pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1)) + + def fit_svi( + self, + dataset, + epochs=2, + anneal_length=1.0, + batch_size=None, + scheduler=None, + jit=False, + ): """ Infer approximate posterior with stochastic variational inference. @@ -590,17 +695,22 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, if batch_size is not None: self.batch_size = batch_size if scheduler is None: - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.01}, - 'milestones': [], - 'gamma': 0.5}) - dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, - pin_memory=self.pin_memory) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": 0.01}, + "milestones": [], + "gamma": 0.5, + } + ) + dataload = DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=self.pin_memory + ) # Initialize guide. for seq_data, L_data in dataload: if self.is_cuda: seq_data = seq_data.cuda() - self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) + self.guide(seq_data, torch.tensor(1.0), torch.tensor(1.0)) break # Setup stochastic variational inference. if jit: @@ -618,22 +728,23 @@ def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, if self.is_cuda: seq_data = seq_data.cuda() loss = svi.step( - seq_data, torch.tensor(len(dataset)/seq_data.shape[0]), - self._beta_anneal(step_i, batch_size, len(dataset), - anneal_length)) + seq_data, + torch.tensor(len(dataset) / seq_data.shape[0]), + self._beta_anneal(step_i, batch_size, len(dataset), anneal_length), + ) losses.append(loss) scheduler.step() step_i += 1 - print(epoch, loss, ' ', datetime.datetime.now() - t0) + print(epoch, loss, " ", datetime.datetime.now() - t0) return losses def _beta_anneal(self, step, batch_size, data_size, anneal_length): """Annealing schedule for prior KL term (beta annealing).""" - if np.allclose(anneal_length, 0.): - return torch.tensor(1.) - anneal_frac = step*batch_size/(anneal_length*data_size) - return torch.tensor(min([anneal_frac, 1.])) + if np.allclose(anneal_length, 0.0): + return torch.tensor(1.0) + anneal_frac = step * batch_size / (anneal_length * data_size) + return torch.tensor(min([anneal_frac, 1.0])) def evaluate(self, dataset_train, dataset_test=None, jit=False): """ @@ -647,52 +758,60 @@ def evaluate(self, dataset_train, dataset_test=None, jit=False): """ dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) if dataset_test is not None: - dataload_test = DataLoader(dataset_test, batch_size=1, - shuffle=False) + dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False) # Initialize guide. 
for seq_data, L_data in dataload_train: if self.is_cuda: seq_data = seq_data.cuda() - self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) + self.guide(seq_data, torch.tensor(1.0), torch.tensor(1.0)) break if jit: elbo = JitTrace_ELBO(ignore_jit_warnings=True) else: elbo = Trace_ELBO() - scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + scheduler = MultiStepLR({"optimizer": Adam, "optim_args": {"lr": 0.01}}) # Setup stochastic variational inference. svi = SVI(self.model, self.guide, scheduler, loss=elbo) # Compute elbo and perplexity. train_lp, train_perplex = self._evaluate_local_elbo( - svi, dataload_train, len(dataset_train)) + svi, dataload_train, len(dataset_train) + ) if dataset_test is not None: test_lp, test_perplex = self._evaluate_local_elbo( - svi, dataload_test, len(dataset_test)) + svi, dataload_test, len(dataset_test) + ) return train_lp, test_lp, train_perplex, test_perplex else: return train_lp, None, train_perplex, None def _local_variables(self, name, site): """Return per datapoint random variables in model.""" - return name in ['latent', 'obs_L', 'obs_seq'] + return name in ["latent", "obs_L", "obs_seq"] def _evaluate_local_elbo(self, svi, dataload, data_size): """Evaluate elbo and average per residue perplexity.""" - lp, perplex = 0., 0. + lp, perplex = 0.0, 0.0 with torch.no_grad(): for seq_data, L_data in dataload: if self.is_cuda: seq_data, L_data = seq_data.cuda(), L_data.cuda() - conditioned_model = poutine.condition(self.model, data={ - "obs_seq": seq_data}) - args = (seq_data, torch.tensor(1.), torch.tensor(1.)) + conditioned_model = poutine.condition( + self.model, data={"obs_seq": seq_data} + ) + args = (seq_data, torch.tensor(1.0), torch.tensor(1.0)) guide_tr = poutine.trace(self.guide).get_trace(*args) - model_tr = poutine.trace(poutine.replay( - conditioned_model, trace=guide_tr)).get_trace(*args) - local_elbo = (model_tr.log_prob_sum(self._local_variables) - - guide_tr.log_prob_sum(self._local_variables) - ).cpu().numpy() + model_tr = poutine.trace( + poutine.replay(conditioned_model, trace=guide_tr) + ).get_trace(*args) + local_elbo = ( + ( + model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ) + .cpu() + .numpy() + ) lp += local_elbo perplex += -local_elbo / L_data[0].cpu().numpy() perplex = np.exp(perplex / data_size) @@ -726,6 +845,7 @@ def _reconstruct_regressor_seq(self, data, ind, param): # Encode seq. z_loc = self.encoder(data[ind][0])[0] # Reconstruct - decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), - param("inverse_temp_q_mn")) - return torch.exp(decoded['precursor_seq_logits']) + decoded = self.decoder( + z_loc, param("W_q_mn"), param("B_q_mn"), param("inverse_temp_q_mn") + ) + return torch.exp(decoded["precursor_seq_logits"]) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py index 2c384ec720..992118fa90 100644 --- a/pyro/contrib/mue/statearrangers.py +++ b/pyro/contrib/mue/statearrangers.py @@ -28,10 +28,11 @@ class Profile(nn.Module): :param epsilon: A small value for numerical stability. 
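Similarly, a hedged sketch of fitting the FactorMuE model whose constructor and fit_svi signature appear above; the dataset contents, z_dim, annealing length, and batch size are illustrative assumptions, not values taken from this patch.

from pyro.contrib.mue.dataloaders import BiosequenceDataset
from pyro.contrib.mue.models import FactorMuE

dataset = BiosequenceDataset(
    ["MEVKL", "MEVL", "MKVL"], source_type="list",
    alphabet="amino-acid", include_stop=True,
)

# data_length is the padded observation length; z_dim is the size of the
# per-sequence latent embedding.
model = FactorMuE(
    data_length=dataset.max_length,
    alphabet_length=dataset.alphabet_length,
    z_dim=2,
    batch_size=2,
)
losses = model.fit_svi(dataset, epochs=2, batch_size=2, anneal_length=1.0)
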
:type epsilon: float """ + def __init__(self, M, epsilon=1e-32): super().__init__() self.M = M - self.K = 2*M+1 + self.K = 2 * M + 1 self.epsilon = epsilon self._make_transfer() @@ -50,24 +51,21 @@ def _make_transfer(self): # ...transf_0 -> initial transition vector # ...transf -> transition matrix # We fix r_{M+1,j} = 1 for j in {0, 1, 2} - self.register_buffer('r_transf_0', - torch.zeros((M, 3, 2, K))) - self.register_buffer('u_transf_0', - torch.zeros((M, 3, 2, K))) - self.register_buffer('null_transf_0', - torch.zeros((K,))) + self.register_buffer("r_transf_0", torch.zeros((M, 3, 2, K))) + self.register_buffer("u_transf_0", torch.zeros((M, 3, 2, K))) + self.register_buffer("null_transf_0", torch.zeros((K,))) m, g = -1, 0 for gp in range(2): - for mp in range(M+gp): + for mp in range(M + gp): kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: - self.r_transf_0[m+1-g, g, 0, kp] = 1 - self.u_transf_0[m+1-g, g, 0, kp] = 1 + self.r_transf_0[m + 1 - g, g, 0, kp] = 1 + self.u_transf_0[m + 1 - g, g, 0, kp] = 1 elif m + 1 - g < mp and gp == 0: - self.r_transf_0[m+1-g, g, 0, kp] = 1 - self.u_transf_0[m+1-g, g, 1, kp] = 1 - for mpp in range(m+2-g, mp): + self.r_transf_0[m + 1 - g, g, 0, kp] = 1 + self.u_transf_0[m + 1 - g, g, 1, kp] = 1 + for mpp in range(m + 2 - g, mp): self.r_transf_0[mpp, 2, 0, kp] = 1 self.u_transf_0[mpp, 2, 1, kp] = 1 self.r_transf_0[mp, 2, 0, kp] = 1 @@ -75,12 +73,12 @@ def _make_transfer(self): elif m + 1 - g == mp and gp == 1: if mp < M: - self.r_transf_0[m+1-g, g, 1, kp] = 1 + self.r_transf_0[m + 1 - g, g, 1, kp] = 1 elif m + 1 - g < mp and gp == 1: - self.r_transf_0[m+1-g, g, 0, kp] = 1 - self.u_transf_0[m+1-g, g, 1, kp] = 1 - for mpp in range(m+2-g, mp): + self.r_transf_0[m + 1 - g, g, 0, kp] = 1 + self.u_transf_0[m + 1 - g, g, 1, kp] = 1 + for mpp in range(m + 2 - g, mp): self.r_transf_0[mpp, 2, 0, kp] = 1 self.u_transf_0[mpp, 2, 1, kp] = 1 if mp < M: @@ -89,58 +87,59 @@ def _make_transfer(self): else: self.null_transf_0[kp] = 1 - self.register_buffer('r_transf', - torch.zeros((M, 3, 2, K, K))) - self.register_buffer('u_transf', - torch.zeros((M, 3, 2, K, K))) - self.register_buffer('null_transf', - torch.zeros((K, K))) + self.register_buffer("r_transf", torch.zeros((M, 3, 2, K, K))) + self.register_buffer("u_transf", torch.zeros((M, 3, 2, K, K))) + self.register_buffer("null_transf", torch.zeros((K, K))) for g in range(2): - for m in range(M+g): + for m in range(M + g): for gp in range(2): - for mp in range(M+gp): + for mp in range(M + gp): k, kp = mg2k(m, g, M), mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: - self.r_transf[m+1-g, g, 0, k, kp] = 1 - self.u_transf[m+1-g, g, 0, k, kp] = 1 + self.r_transf[m + 1 - g, g, 0, k, kp] = 1 + self.u_transf[m + 1 - g, g, 0, k, kp] = 1 elif m + 1 - g < mp and gp == 0: - self.r_transf[m+1-g, g, 0, k, kp] = 1 - self.u_transf[m+1-g, g, 1, k, kp] = 1 - self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 - self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 + self.r_transf[m + 1 - g, g, 0, k, kp] = 1 + self.u_transf[m + 1 - g, g, 1, k, kp] = 1 + self.r_transf[(m + 2 - g) : mp, 2, 0, k, kp] = 1 + self.u_transf[(m + 2 - g) : mp, 2, 1, k, kp] = 1 self.r_transf[mp, 2, 0, k, kp] = 1 self.u_transf[mp, 2, 0, k, kp] = 1 elif m + 1 - g == mp and gp == 1: if mp < M: - self.r_transf[m+1-g, g, 1, k, kp] = 1 + self.r_transf[m + 1 - g, g, 1, k, kp] = 1 elif m + 1 - g < mp and gp == 1: - self.r_transf[m+1-g, g, 0, k, kp] = 1 - self.u_transf[m+1-g, g, 1, k, kp] = 1 - self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 - self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 + 
self.r_transf[m + 1 - g, g, 0, k, kp] = 1 + self.u_transf[m + 1 - g, g, 1, k, kp] = 1 + self.r_transf[(m + 2 - g) : mp, 2, 0, k, kp] = 1 + self.u_transf[(m + 2 - g) : mp, 2, 1, k, kp] = 1 if mp < M: self.r_transf[mp, 2, 1, k, kp] = 1 else: self.null_transf[k, kp] = 1 - self.register_buffer('vx_transf', - torch.zeros((M, K))) - self.register_buffer('vc_transf', - torch.zeros((M+1, K))) + self.register_buffer("vx_transf", torch.zeros((M, K))) + self.register_buffer("vc_transf", torch.zeros((M + 1, K))) for g in range(2): - for m in range(M+g): + for m in range(M + g): k = mg2k(m, g, M) if g == 0: self.vx_transf[m, k] = 1 elif g == 1: self.vc_transf[m, k] = 1 - def forward(self, precursor_seq_logits, insert_seq_logits, - insert_logits, delete_logits, substitute_logits=None): + def forward( + self, + precursor_seq_logits, + insert_seq_logits, + insert_logits, + delete_logits, + substitute_logits=None, + ): """ Assemble HMM parameters given profile parameters. @@ -172,32 +171,31 @@ def forward(self, precursor_seq_logits, insert_seq_logits, :rtype: ~torch.Tensor, ~torch.Tensor, ~torch.Tensor """ initial_logits = ( - torch.einsum('...ijk,ijkl->...l', - delete_logits, self.u_transf_0) + - torch.einsum('...ijk,ijkl->...l', - insert_logits, self.r_transf_0) + - (-1/self.epsilon)*self.null_transf_0) + torch.einsum("...ijk,ijkl->...l", delete_logits, self.u_transf_0) + + torch.einsum("...ijk,ijkl->...l", insert_logits, self.r_transf_0) + + (-1 / self.epsilon) * self.null_transf_0 + ) transition_logits = ( - torch.einsum('...ijk,ijklf->...lf', - delete_logits, self.u_transf) + - torch.einsum('...ijk,ijklf->...lf', - insert_logits, self.r_transf) + - (-1/self.epsilon)*self.null_transf) + torch.einsum("...ijk,ijklf->...lf", delete_logits, self.u_transf) + + torch.einsum("...ijk,ijklf->...lf", insert_logits, self.r_transf) + + (-1 / self.epsilon) * self.null_transf + ) # Broadcasting for concatenation. if len(precursor_seq_logits.size()) > len(insert_seq_logits.size()): insert_seq_logits = insert_seq_logits.unsqueeze(0).expand( - [precursor_seq_logits.size()[0], -1, -1]) + [precursor_seq_logits.size()[0], -1, -1] + ) elif len(insert_seq_logits.size()) > len(precursor_seq_logits.size()): precursor_seq_logits = precursor_seq_logits.unsqueeze(0).expand( - [insert_seq_logits.size()[0], -1, -1]) - seq_logits = torch.cat([precursor_seq_logits, insert_seq_logits], - dim=-2) + [insert_seq_logits.size()[0], -1, -1] + ) + seq_logits = torch.cat([precursor_seq_logits, insert_seq_logits], dim=-2) # Option to include the substitution matrix. 
if substitute_logits is not None: observation_logits = torch.logsumexp( - seq_logits.unsqueeze(-1) + substitute_logits.unsqueeze(-3), - dim=-2) + seq_logits.unsqueeze(-1) + substitute_logits.unsqueeze(-3), dim=-2 + ) else: observation_logits = seq_logits @@ -206,4 +204,4 @@ def forward(self, precursor_seq_logits, insert_seq_logits, def mg2k(m, g, M): """Convert from (m, g) indexing to k indexing.""" - return m + M*g + return m + M * g diff --git a/pyro/contrib/oed/__init__.py b/pyro/contrib/oed/__init__.py index 3afd3a440d..a60336c483 100644 --- a/pyro/contrib/oed/__init__.py +++ b/pyro/contrib/oed/__init__.py @@ -69,7 +69,4 @@ def model(design): from pyro.contrib.oed import eig, search -__all__ = [ - "search", - "eig" -] +__all__ = ["search", "eig"] diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py index 8d6c7ae22b..45a0c1d3c1 100644 --- a/pyro/contrib/oed/eig.py +++ b/pyro/contrib/oed/eig.py @@ -22,12 +22,24 @@ "posterior_eig", "marginal_eig", "lfire_eig", - "vnmc_eig" + "vnmc_eig", ] -def laplace_eig(model, design, observation_labels, target_labels, guide, loss, optim, num_steps, - final_num_samples, y_dist=None, eig=True, **prior_entropy_kwargs): +def laplace_eig( + model, + design, + observation_labels, + target_labels, + guide, + loss, + optim, + num_steps, + final_num_samples, + y_dist=None, + eig=True, + **prior_entropy_kwargs +): """ Estimates the expected information gain (EIG) by making repeated Laplace approximations to the posterior. @@ -57,8 +69,18 @@ def laplace_eig(model, design, observation_labels, target_labels, guide, loss, o if target_labels is not None and isinstance(target_labels, str): target_labels = [target_labels] - ape = _laplace_vi_ape(model, design, observation_labels, target_labels, guide, loss, optim, num_steps, - final_num_samples, y_dist=y_dist) + ape = _laplace_vi_ape( + model, + design, + observation_labels, + target_labels, + guide, + loss, + optim, + num_steps, + final_num_samples, + y_dist=y_dist, + ) return _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs) @@ -67,19 +89,34 @@ def _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs): if eig: if mean_field: try: - prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels) + prior_entropy = mean_field_entropy( + model, [design], whitelist=target_labels + ) except NotImplemented: - prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs) + prior_entropy = monte_carlo_entropy( + model, design, target_labels, **prior_entropy_kwargs + ) else: - prior_entropy = monte_carlo_entropy(model, design, target_labels, **prior_entropy_kwargs) + prior_entropy = monte_carlo_entropy( + model, design, target_labels, **prior_entropy_kwargs + ) return prior_entropy - ape else: return ape -def _laplace_vi_ape(model, design, observation_labels, target_labels, guide, loss, optim, num_steps, - final_num_samples, y_dist=None): - +def _laplace_vi_ape( + model, + design, + observation_labels, + target_labels, + guide, + loss, + optim, + num_steps, + final_num_samples, + y_dist=None, +): def posterior_entropy(y_dist, design): # Important that y_dist is sampled *within* the function y = pyro.sample("conditioning_y", y_dist) @@ -99,8 +136,10 @@ def posterior_entropy(y_dist, design): return entropy if y_dist is None: - y_dist = EmpiricalMarginal(Importance(model, num_samples=final_num_samples).run(design), - sites=observation_labels) + y_dist = EmpiricalMarginal( + Importance(model, 
num_samples=final_num_samples).run(design), + sites=observation_labels, + ) # Calculate the expected posterior entropy under this distn of y loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design)) @@ -110,8 +149,17 @@ def posterior_entropy(y_dist, design): # Deprecated -def vi_eig(model, design, observation_labels, target_labels, vi_parameters, is_parameters, y_dist=None, - eig=True, **prior_entropy_kwargs): +def vi_eig( + model, + design, + observation_labels, + target_labels, + vi_parameters, + is_parameters, + y_dist=None, + eig=True, + **prior_entropy_kwargs +): """.. deprecated:: 0.4.1 Use `posterior_eig` instead. @@ -157,19 +205,38 @@ def vi_eig(model, design, observation_labels, target_labels, vi_parameters, is_p """ - warnings.warn("`vi_eig` is deprecated in favour of the amortized version: `posterior_eig`.", DeprecationWarning) + warnings.warn( + "`vi_eig` is deprecated in favour of the amortized version: `posterior_eig`.", + DeprecationWarning, + ) if isinstance(observation_labels, str): observation_labels = [observation_labels] if target_labels is not None and isinstance(target_labels, str): target_labels = [target_labels] - ape = _vi_ape(model, design, observation_labels, target_labels, vi_parameters, is_parameters, y_dist=y_dist) + ape = _vi_ape( + model, + design, + observation_labels, + target_labels, + vi_parameters, + is_parameters, + y_dist=y_dist, + ) return _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs) -def _vi_ape(model, design, observation_labels, target_labels, vi_parameters, is_parameters, y_dist=None): - svi_num_steps = vi_parameters.pop('num_steps') +def _vi_ape( + model, + design, + observation_labels, + target_labels, + vi_parameters, + is_parameters, + y_dist=None, +): + svi_num_steps = vi_parameters.pop("num_steps") def posterior_entropy(y_dist, design): # Important that y_dist is sampled *within* the function @@ -187,8 +254,9 @@ def posterior_entropy(y_dist, design): return entropy if y_dist is None: - y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design), - sites=observation_labels) + y_dist = EmpiricalMarginal( + Importance(model, **is_parameters).run(design), sites=observation_labels + ) # Calculate the expected posterior entropy under this distn of y loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design)) @@ -197,8 +265,16 @@ def posterior_entropy(y_dist, design): return loss -def nmc_eig(model, design, observation_labels, target_labels=None, - N=100, M=10, M_prime=None, independent_priors=False): +def nmc_eig( + model, + design, + observation_labels, + target_labels=None, + N=100, + M=10, + M_prime=None, + independent_priors=False, +): """Nested Monte Carlo estimate of the expected information gain (EIG). 
The estimate is, when there are not any random effects, @@ -251,8 +327,12 @@ def nmc_eig(model, design, observation_labels, target_labels=None, trace.compute_log_prob() if M_prime is not None: - y_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in observation_labels} - theta_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in target_labels} + y_dict = { + l: lexpand(trace.nodes[l]["value"], M_prime) for l in observation_labels + } + theta_dict = { + l: lexpand(trace.nodes[l]["value"], M_prime) for l in target_labels + } theta_dict.update(y_dict) # Resample M values of u and compute conditional probabilities # WARNING: currently the use of condition does not actually sample @@ -267,8 +347,9 @@ def nmc_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M_prime, N) retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - - math.log(M_prime) + conditional_lp = sum( + retrace.nodes[l]["log_prob"] for l in observation_labels + ).logsumexp(0) - math.log(M_prime) else: # This assumes that y are independent conditional on theta # Furthermore assume that there are no other variables besides theta @@ -282,18 +363,29 @@ def nmc_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M, 1) # sample M theta retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - - math.log(M) + marginal_lp = sum( + retrace.nodes[l]["log_prob"] for l in observation_labels + ).logsumexp(0) - math.log(M) terms = conditional_lp - marginal_lp nonnan = (~torch.isnan(terms)).sum(0).type_as(terms) - terms[torch.isnan(terms)] = 0. - return terms.sum(0)/nonnan - - -def donsker_varadhan_eig(model, design, observation_labels, target_labels, - num_samples, num_steps, T, optim, return_history=False, - final_design=None, final_num_samples=None): + terms[torch.isnan(terms)] = 0.0 + return terms.sum(0) / nonnan + + +def donsker_varadhan_eig( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + T, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): """ Donsker-Varadhan estimate of the expected information gain (EIG). 
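As a usage illustration for the reformatted EIG estimators, a toy nested-Monte-Carlo call; the linear model, candidate designs, and sample sizes N and M are invented for this sketch, which assumes the model broadcasts over the extra particle dimensions that nmc_eig prepends to the design.

import torch

import pyro
import pyro.distributions as dist
from pyro.contrib.oed.eig import nmc_eig

def model(design):
    # One scalar coefficient "theta"; one observation "y" per candidate design.
    with pyro.plate_stack("plates", design.shape[:-1]):
        theta = pyro.sample("theta", dist.Normal(0.0, 1.0))
        return pyro.sample("y", dist.Normal(theta * design[..., 0], 1.0))

designs = torch.tensor([[1.0], [2.0]])  # two candidate designs
eig = nmc_eig(model, designs, observation_labels=["y"], target_labels=["theta"], N=200, M=20)
print(eig)  # one EIG estimate per candidate design

Increasing N and M tightens the nested estimate at a cost that grows with N*M model evaluations.
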
@@ -335,13 +427,35 @@ def donsker_varadhan_eig(model, design, observation_labels, target_labels, if isinstance(target_labels, str): target_labels = [target_labels] loss = _donsker_varadhan_loss(model, T, observation_labels, target_labels) - return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history, - final_design, final_num_samples) - - -def posterior_eig(model, design, observation_labels, target_labels, num_samples, num_steps, guide, optim, - return_history=False, final_design=None, final_num_samples=None, eig=True, prior_entropy_kwargs={}, - *args, **kwargs): + return opt_eig_ape_loss( + design, + loss, + num_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) + + +def posterior_eig( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + guide, + optim, + return_history=False, + final_design=None, + final_num_samples=None, + eig=True, + prior_entropy_kwargs={}, + *args, + **kwargs +): """ Posterior estimate of expected information gain (EIG) computed from the average posterior entropy (APE) using :math:`EIG(d) = H[p(\\theta)] - APE(d)`. See [1] for full details. @@ -390,24 +504,68 @@ def posterior_eig(model, design, observation_labels, target_labels, num_samples, if isinstance(target_labels, str): target_labels = [target_labels] - ape = _posterior_ape(model, design, observation_labels, target_labels, num_samples, num_steps, guide, optim, - return_history=return_history, final_design=final_design, final_num_samples=final_num_samples, - *args, **kwargs) + ape = _posterior_ape( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + guide, + optim, + return_history=return_history, + final_design=final_design, + final_num_samples=final_num_samples, + *args, + **kwargs + ) return _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_kwargs) -def _posterior_ape(model, design, observation_labels, target_labels, - num_samples, num_steps, guide, optim, return_history=False, - final_design=None, final_num_samples=None, *args, **kwargs): - - loss = _posterior_loss(model, guide, observation_labels, target_labels, *args, **kwargs) - return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history, - final_design, final_num_samples) - - -def marginal_eig(model, design, observation_labels, target_labels, - num_samples, num_steps, guide, optim, return_history=False, - final_design=None, final_num_samples=None): +def _posterior_ape( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + guide, + optim, + return_history=False, + final_design=None, + final_num_samples=None, + *args, + **kwargs +): + + loss = _posterior_loss( + model, guide, observation_labels, target_labels, *args, **kwargs + ) + return opt_eig_ape_loss( + design, + loss, + num_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) + + +def marginal_eig( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + guide, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): """Estimate EIG by estimating the marginal entropy :math:`p(y|d)`. See [1] for full details. 
The marginal representation of EIG is @@ -448,13 +606,32 @@ def marginal_eig(model, design, observation_labels, target_labels, if isinstance(target_labels, str): target_labels = [target_labels] loss = _marginal_loss(model, guide, observation_labels, target_labels) - return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history, - final_design, final_num_samples) - - -def marginal_likelihood_eig(model, design, observation_labels, target_labels, - num_samples, num_steps, marginal_guide, cond_guide, optim, - return_history=False, final_design=None, final_num_samples=None): + return opt_eig_ape_loss( + design, + loss, + num_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) + + +def marginal_likelihood_eig( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + marginal_guide, + cond_guide, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): """Estimates EIG by estimating the marginal entropy, that of :math:`p(y|d)`, *and* the conditional entropy, of :math:`p(y|\\theta, d)`, both via Gibbs' Inequality. See [1] for full details. @@ -489,14 +666,35 @@ def marginal_likelihood_eig(model, design, observation_labels, target_labels, observation_labels = [observation_labels] if isinstance(target_labels, str): target_labels = [target_labels] - loss = _marginal_likelihood_loss(model, marginal_guide, cond_guide, observation_labels, target_labels) - return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history, - final_design, final_num_samples) - - -def lfire_eig(model, design, observation_labels, target_labels, - num_y_samples, num_theta_samples, num_steps, classifier, optim, return_history=False, - final_design=None, final_num_samples=None): + loss = _marginal_likelihood_loss( + model, marginal_guide, cond_guide, observation_labels, target_labels + ) + return opt_eig_ape_loss( + design, + loss, + num_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) + + +def lfire_eig( + model, + design, + observation_labels, + target_labels, + num_y_samples, + num_theta_samples, + num_steps, + classifier, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): """Estimates the EIG using the method of Likelihood-Free Inference by Ratio Estimation (LFIRE) as in [1]. LFIRE is run separately for several samples of :math:`\\theta`. @@ -540,17 +738,35 @@ def lfire_eig(model, design, observation_labels, target_labels, cond_model = pyro.condition(model, data=theta_dict) loss = _lfire_loss(model, cond_model, classifier, observation_labels, target_labels) - out = opt_eig_ape_loss(expanded_design, loss, num_y_samples, num_steps, optim, return_history, - final_design, final_num_samples) + out = opt_eig_ape_loss( + expanded_design, + loss, + num_y_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) if return_history: return out[0], out[1].sum(0) / num_theta_samples else: return out.sum(0) / num_theta_samples -def vnmc_eig(model, design, observation_labels, target_labels, - num_samples, num_steps, guide, optim, return_history=False, - final_design=None, final_num_samples=None): +def vnmc_eig( + model, + design, + observation_labels, + target_labels, + num_samples, + num_steps, + guide, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): """Estimates the EIG using Variational Nested Monte Carlo (VNMC). The VNMC estimate [1] is .. 
math:: @@ -596,12 +812,28 @@ def vnmc_eig(model, design, observation_labels, target_labels, if isinstance(target_labels, str): target_labels = [target_labels] loss = _vnmc_eig_loss(model, guide, observation_labels, target_labels) - return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history, - final_design, final_num_samples) - - -def opt_eig_ape_loss(design, loss_fn, num_samples, num_steps, optim, return_history=False, - final_design=None, final_num_samples=None): + return opt_eig_ape_loss( + design, + loss, + num_samples, + num_steps, + optim, + return_history, + final_design, + final_num_samples, + ) + + +def opt_eig_ape_loss( + design, + loss_fn, + num_samples, + num_steps, + optim, + return_history=False, + final_design=None, + final_num_samples=None, +): if final_design is None: final_design = design @@ -615,8 +847,9 @@ def opt_eig_ape_loss(design, loss_fn, num_samples, num_steps, optim, return_hist pyro.infer.util.zero_grads(params) with poutine.trace(param_only=True) as param_capture: agg_loss, loss = loss_fn(design, num_samples, evaluation=return_history) - params = set(site["value"].unconstrained() - for site in param_capture.trace.nodes.values()) + params = set( + site["value"].unconstrained() for site in param_capture.trace.nodes.values() + ) if torch.isnan(agg_loss): raise ArithmeticError("Encountered NaN loss in opt_eig_ape_loss") agg_loss.backward(retain_graph=True) @@ -673,10 +906,14 @@ def loss_fn(design, num_particles, **kwargs): conditional_model = pyro.condition(model, data=y_dict) shuffled_trace = poutine.trace(conditional_model).get_trace(expanded_design) - T_joint = T(expanded_design, unshuffled_trace, observation_labels, target_labels) - T_independent = T(expanded_design, shuffled_trace, observation_labels, target_labels) + T_joint = T( + expanded_design, unshuffled_trace, observation_labels, target_labels + ) + T_independent = T( + expanded_design, shuffled_trace, observation_labels, target_labels + ) - joint_expectation = T_joint.sum(0)/num_particles + joint_expectation = T_joint.sum(0) / num_particles A = T_independent - math.log(num_particles) s, _ = torch.max(A, dim=0) @@ -690,7 +927,9 @@ def loss_fn(design, num_particles, **kwargs): return loss_fn -def _posterior_loss(model, guide, observation_labels, target_labels, analytic_entropy=False): +def _posterior_loss( + model, guide, observation_labels, target_labels, analytic_entropy=False +): """Posterior loss: to evaluate directly use `posterior_eig` setting `num_steps=0`, `eig=False`.""" def loss_fn(design, num_particles, evaluation=False, **kwargs): @@ -705,12 +944,18 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): # Run through q(theta | y, d) conditional_guide = pyro.condition(guide, data=theta_dict) cond_trace = poutine.trace(conditional_guide).get_trace( - y_dict, expanded_design, observation_labels, target_labels) + y_dict, expanded_design, observation_labels, target_labels + ) cond_trace.compute_log_prob() if evaluation and analytic_entropy: - loss = mean_field_entropy( - guide, [y_dict, expanded_design, observation_labels, target_labels], - whitelist=target_labels).sum(0) / num_particles + loss = ( + mean_field_entropy( + guide, + [y_dict, expanded_design, observation_labels, target_labels], + whitelist=target_labels, + ).sum(0) + / num_particles + ) agg_loss = loss.sum() else: terms = -sum(cond_trace.nodes[l]["log_prob"] for l in target_labels) @@ -735,7 +980,8 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): # Run through q(y | d) 
conditional_guide = pyro.condition(guide, data=y_dict) cond_trace = poutine.trace(conditional_guide).get_trace( - expanded_design, observation_labels, target_labels) + expanded_design, observation_labels, target_labels + ) cond_trace.compute_log_prob() terms = -sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels) @@ -750,7 +996,9 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): return loss_fn -def _marginal_likelihood_loss(model, marginal_guide, likelihood_guide, observation_labels, target_labels): +def _marginal_likelihood_loss( + model, marginal_guide, likelihood_guide, observation_labels, target_labels +): """Marginal_likelihood loss: to evaluate directly use `marginal_likelihood_eig` setting `num_steps=0`.""" def loss_fn(design, num_particles, evaluation=False, **kwargs): @@ -765,13 +1013,15 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): # Run through q(y | d) qyd = pyro.condition(marginal_guide, data=y_dict) marginal_trace = poutine.trace(qyd).get_trace( - expanded_design, observation_labels, target_labels) + expanded_design, observation_labels, target_labels + ) marginal_trace.compute_log_prob() # Run through q(y | theta, d) qythetad = pyro.condition(likelihood_guide, data=y_dict) cond_trace = poutine.trace(qythetad).get_trace( - theta_dict, expanded_design, observation_labels, target_labels) + theta_dict, expanded_design, observation_labels, target_labels + ) cond_trace.compute_log_prob() terms = -sum(marginal_trace.nodes[l]["log_prob"] for l in observation_labels) @@ -787,7 +1037,9 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): return loss_fn -def _lfire_loss(model_marginal, model_conditional, h, observation_labels, target_labels): +def _lfire_loss( + model_marginal, model_conditional, h, observation_labels, target_labels +): """LFIRE loss: to evaluate directly use `lfire_eig` setting `num_steps=0`.""" def loss_fn(design, num_particles, evaluation=False, **kwargs): @@ -798,19 +1050,37 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): pass expanded_design = lexpand(design, num_particles) - model_conditional_trace = poutine.trace(model_conditional).get_trace(expanded_design) + model_conditional_trace = poutine.trace(model_conditional).get_trace( + expanded_design + ) if not evaluation: - model_marginal_trace = poutine.trace(model_marginal).get_trace(expanded_design) - - h_joint = h(expanded_design, model_conditional_trace, observation_labels, target_labels) - h_independent = h(expanded_design, model_marginal_trace, observation_labels, target_labels) - - terms = torch.nn.functional.softplus(-h_joint) + torch.nn.functional.softplus(h_independent) + model_marginal_trace = poutine.trace(model_marginal).get_trace( + expanded_design + ) + + h_joint = h( + expanded_design, + model_conditional_trace, + observation_labels, + target_labels, + ) + h_independent = h( + expanded_design, model_marginal_trace, observation_labels, target_labels + ) + + terms = torch.nn.functional.softplus( + -h_joint + ) + torch.nn.functional.softplus(h_independent) return _safe_mean_terms(terms) else: - h_joint = h(expanded_design, model_conditional_trace, observation_labels, target_labels) + h_joint = h( + expanded_design, + model_conditional_trace, + observation_labels, + target_labels, + ) return _safe_mean_terms(h_joint) return loss_fn @@ -831,7 +1101,8 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): reexpanded_design = lexpand(expanded_design, M) conditional_guide = pyro.condition(guide, data=y_dict) 
guide_trace = poutine.trace(conditional_guide).get_trace( - y_dict, reexpanded_design, observation_labels, target_labels) + y_dict, reexpanded_design, observation_labels, target_labels + ) theta_y_dict = {l: guide_trace.nodes[l]["value"] for l in target_labels} theta_y_dict.update(y_dict) guide_trace.compute_log_prob() @@ -857,12 +1128,12 @@ def loss_fn(design, num_particles, evaluation=False, **kwargs): def _safe_mean_terms(terms): - mask = torch.isnan(terms) | (terms == float('-inf')) | (terms == float('inf')) + mask = torch.isnan(terms) | (terms == float("-inf")) | (terms == float("inf")) if terms.dtype is torch.float32: nonnan = (~mask).sum(0).float() elif terms.dtype is torch.float64: nonnan = (~mask).sum(0).double() - terms[mask] = 0. + terms[mask] = 0.0 loss = terms.sum(0) / nonnan agg_loss = loss.sum() return agg_loss, loss @@ -876,9 +1147,9 @@ def xexpx(a): :param torch.Tensor a: :return: Equivalent of `a*torch.exp(a)`. """ - mask = (a == float('-inf')) - y = a*torch.exp(a) - y[mask] = 0. + mask = a == float("-inf") + y = a * torch.exp(a) + y[mask] = 0.0 return y @@ -890,7 +1161,7 @@ def forward(ctx, input, ewma): @staticmethod def backward(ctx, grad_output): - ewma, = ctx.saved_tensors + (ewma,) = ctx.saved_tensors return grad_output / ewma, None @@ -915,20 +1186,21 @@ class EwmaLog: def __init__(self, alpha): self.alpha = alpha - self.ewma = 0. + self.ewma = 0.0 self.n = 0 - self.s = 0. + self.s = 0.0 def __call__(self, inputs, s, dim=0, keepdim=False): - """Updates the moving average, and returns :code:`inputs.log()`. - """ + """Updates the moving average, and returns :code:`inputs.log()`.""" self.n += 1 if torch_isnan(self.ewma) or torch_isinf(self.ewma): ewma = inputs else: - ewma = inputs * (1. - self.alpha) / (1 - self.alpha**self.n) \ - + torch.exp(self.s - s) * self.ewma \ - * (self.alpha - self.alpha**self.n) / (1 - self.alpha**self.n) + ewma = inputs * (1.0 - self.alpha) / (1 - self.alpha ** self.n) + torch.exp( + self.s - s + ) * self.ewma * (self.alpha - self.alpha ** self.n) / ( + 1 - self.alpha ** self.n + ) self.ewma = ewma.detach() self.s = s.detach() return _ewma_log_fn(inputs, ewma) diff --git a/pyro/contrib/oed/glmm/glmm.py b/pyro/contrib/oed/glmm/glmm.py index 68507be53f..55c6bbd9ed 100644 --- a/pyro/contrib/oed/glmm/glmm.py +++ b/pyro/contrib/oed/glmm/glmm.py @@ -16,11 +16,12 @@ from pyro.contrib.util import iter_plates_to_shape, rmv # TODO read from torch float spec -epsilon = torch.tensor(2**-24) +epsilon = torch.tensor(2 ** -24) -def known_covariance_linear_model(coef_means, coef_sds, observation_sd, - coef_labels="w", observation_label="y"): +def known_covariance_linear_model( + coef_means, coef_sds, observation_sd, coef_labels="w", observation_label="y" +): if not isinstance(coef_means, list): coef_means = [coef_means] @@ -29,121 +30,197 @@ def known_covariance_linear_model(coef_means, coef_sds, observation_sd, if not isinstance(coef_labels, list): coef_labels = [coef_labels] - model = partial(bayesian_linear_model, - w_means=OrderedDict([(label, mean) for label, mean in zip(coef_labels, coef_means)]), - w_sqrtlambdas=OrderedDict([ - (label, 1./(observation_sd*sd)) for label, sd in zip(coef_labels, coef_sds)]), - obs_sd=observation_sd, - response_label=observation_label) + model = partial( + bayesian_linear_model, + w_means=OrderedDict( + [(label, mean) for label, mean in zip(coef_labels, coef_means)] + ), + w_sqrtlambdas=OrderedDict( + [ + (label, 1.0 / (observation_sd * sd)) + for label, sd in zip(coef_labels, coef_sds) + ] + ), + 
obs_sd=observation_sd, + response_label=observation_label, + ) # For computing the true EIG model.obs_sd = observation_sd model.w_sds = OrderedDict([(label, sd) for label, sd in zip(coef_labels, coef_sds)]) - model.w_sizes = OrderedDict([(label, sd.shape[-1]) for label, sd in zip(coef_labels, coef_sds)]) + model.w_sizes = OrderedDict( + [(label, sd.shape[-1]) for label, sd in zip(coef_labels, coef_sds)] + ) model.observation_label = observation_label model.coef_labels = coef_labels return model def normal_guide(observation_sd, coef_shape, coef_label="w"): - return partial(normal_inv_gamma_family_guide, - obs_sd=observation_sd, - w_sizes={coef_label: coef_shape}) - - -def group_linear_model(coef1_mean, coef1_sd, coef2_mean, coef2_sd, observation_sd, - coef1_label="w1", coef2_label="w2", observation_label="y"): + return partial( + normal_inv_gamma_family_guide, + obs_sd=observation_sd, + w_sizes={coef_label: coef_shape}, + ) + + +def group_linear_model( + coef1_mean, + coef1_sd, + coef2_mean, + coef2_sd, + observation_sd, + coef1_label="w1", + coef2_label="w2", + observation_label="y", +): model = partial( bayesian_linear_model, w_means=OrderedDict([(coef1_label, coef1_mean), (coef2_label, coef2_mean)]), - w_sqrtlambdas=OrderedDict([(coef1_label, 1./(observation_sd*coef1_sd)), - (coef2_label, 1./(observation_sd*coef2_sd))]), + w_sqrtlambdas=OrderedDict( + [ + (coef1_label, 1.0 / (observation_sd * coef1_sd)), + (coef2_label, 1.0 / (observation_sd * coef2_sd)), + ] + ), obs_sd=observation_sd, - response_label=observation_label) + response_label=observation_label, + ) model.obs_sd = observation_sd model.w_sds = OrderedDict([(coef1_label, coef1_sd), (coef2_label, coef2_sd)]) return model -def group_normal_guide(observation_sd, coef1_shape, coef2_shape, - coef1_label="w1", coef2_label="w2"): +def group_normal_guide( + observation_sd, coef1_shape, coef2_shape, coef1_label="w1", coef2_label="w2" +): return partial( - normal_inv_gamma_family_guide, w_sizes=OrderedDict([(coef1_label, coef1_shape), (coef2_label, coef2_shape)]), - obs_sd=observation_sd) + normal_inv_gamma_family_guide, + w_sizes=OrderedDict([(coef1_label, coef1_shape), (coef2_label, coef2_shape)]), + obs_sd=observation_sd, + ) def zero_mean_unit_obs_sd_lm(coef_sd, coef_label="w"): - model = known_covariance_linear_model(torch.tensor(0.), coef_sd, torch.tensor(1.), coef_labels=coef_label) - guide = normal_guide(torch.tensor(1.), coef_sd.shape, coef_label=coef_label) + model = known_covariance_linear_model( + torch.tensor(0.0), coef_sd, torch.tensor(1.0), coef_labels=coef_label + ) + guide = normal_guide(torch.tensor(1.0), coef_sd.shape, coef_label=coef_label) return model, guide -def normal_inverse_gamma_linear_model(coef_mean, coef_sqrtlambda, alpha, - beta, coef_label="w", - observation_label="y"): - return partial(bayesian_linear_model, - w_means={coef_label: coef_mean}, - w_sqrtlambdas={coef_label: coef_sqrtlambda}, - alpha_0=alpha, beta_0=beta, - response_label=observation_label) +def normal_inverse_gamma_linear_model( + coef_mean, coef_sqrtlambda, alpha, beta, coef_label="w", observation_label="y" +): + return partial( + bayesian_linear_model, + w_means={coef_label: coef_mean}, + w_sqrtlambdas={coef_label: coef_sqrtlambda}, + alpha_0=alpha, + beta_0=beta, + response_label=observation_label, + ) def normal_inverse_gamma_guide(coef_shape, coef_label="w", **kwargs): - return partial(normal_inv_gamma_family_guide, obs_sd=None, w_sizes={coef_label: coef_shape}, **kwargs) - - -def logistic_regression_model(coef_mean, coef_sd, 
coef_label="w", observation_label="y"): - return partial(bayesian_linear_model, - w_means={coef_label: coef_mean}, - w_sqrtlambdas={coef_label: 1./coef_sd}, - obs_sd=torch.tensor(1.), - response="bernoulli", - response_label=observation_label) - - -def lmer_model(fixed_effects_sd, n_groups, random_effects_alpha, random_effects_beta, - fixed_effects_label="w", random_effects_label="u", observation_label="y", - response="normal"): - return partial(bayesian_linear_model, - w_means={fixed_effects_label: torch.tensor(0.)}, - w_sqrtlambdas={fixed_effects_label: 1./fixed_effects_sd}, - obs_sd=torch.tensor(1.), - re_group_sizes={random_effects_label: n_groups}, - re_alphas={random_effects_label: random_effects_alpha}, - re_betas={random_effects_label: random_effects_beta}, - response=response, - response_label=observation_label) - + return partial( + normal_inv_gamma_family_guide, + obs_sd=None, + w_sizes={coef_label: coef_shape}, + **kwargs + ) -def sigmoid_model(coef1_mean, coef1_sd, coef2_mean, coef2_sd, observation_sd, - sigmoid_alpha, sigmoid_beta, sigmoid_design, - coef1_label="w1", coef2_label="w2", observation_label="y", - sigmoid_label="k"): +def logistic_regression_model( + coef_mean, coef_sd, coef_label="w", observation_label="y" +): + return partial( + bayesian_linear_model, + w_means={coef_label: coef_mean}, + w_sqrtlambdas={coef_label: 1.0 / coef_sd}, + obs_sd=torch.tensor(1.0), + response="bernoulli", + response_label=observation_label, + ) + + +def lmer_model( + fixed_effects_sd, + n_groups, + random_effects_alpha, + random_effects_beta, + fixed_effects_label="w", + random_effects_label="u", + observation_label="y", + response="normal", +): + return partial( + bayesian_linear_model, + w_means={fixed_effects_label: torch.tensor(0.0)}, + w_sqrtlambdas={fixed_effects_label: 1.0 / fixed_effects_sd}, + obs_sd=torch.tensor(1.0), + re_group_sizes={random_effects_label: n_groups}, + re_alphas={random_effects_label: random_effects_alpha}, + re_betas={random_effects_label: random_effects_beta}, + response=response, + response_label=observation_label, + ) + + +def sigmoid_model( + coef1_mean, + coef1_sd, + coef2_mean, + coef2_sd, + observation_sd, + sigmoid_alpha, + sigmoid_beta, + sigmoid_design, + coef1_label="w1", + coef2_label="w2", + observation_label="y", + sigmoid_label="k", +): def model(design): batch_shape = design.shape[:-2] k_shape = batch_shape + (sigmoid_design.shape[-1],) - k = pyro.sample(sigmoid_label, - dist.Gamma(sigmoid_alpha.expand(k_shape), - sigmoid_beta.expand(k_shape)).to_event(1)) + k = pyro.sample( + sigmoid_label, + dist.Gamma( + sigmoid_alpha.expand(k_shape), sigmoid_beta.expand(k_shape) + ).to_event(1), + ) k_assigned = rmv(sigmoid_design, k) return bayesian_linear_model( design, w_means=OrderedDict([(coef1_label, coef1_mean), (coef2_label, coef2_mean)]), - w_sqrtlambdas={coef1_label: 1./(observation_sd*coef1_sd), coef2_label: 1./(observation_sd*coef2_sd)}, + w_sqrtlambdas={ + coef1_label: 1.0 / (observation_sd * coef1_sd), + coef2_label: 1.0 / (observation_sd * coef2_sd), + }, obs_sd=observation_sd, response="sigmoid", response_label=observation_label, - k=k_assigned - ) + k=k_assigned, + ) return model -def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={}, - re_alphas={}, re_betas={}, obs_sd=None, - alpha_0=None, beta_0=None, response="normal", - response_label="y", k=None): +def bayesian_linear_model( + design, + w_means={}, + w_sqrtlambdas={}, + re_group_sizes={}, + re_alphas={}, + re_betas={}, + obs_sd=None, + alpha_0=None, + 
beta_0=None, + response="normal", + response_label="y", + k=None, +): """ A pyro model for Bayesian linear regression. @@ -206,14 +283,17 @@ def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={ if obs_sd is None: # First, sample tau (observation precision) - tau_prior = dist.Gamma(alpha_0.unsqueeze(-1), - beta_0.unsqueeze(-1)).to_event(1) + tau_prior = dist.Gamma( + alpha_0.unsqueeze(-1), beta_0.unsqueeze(-1) + ).to_event(1) tau = pyro.sample("tau", tau_prior) - obs_sd = 1./torch.sqrt(tau) + obs_sd = 1.0 / torch.sqrt(tau) elif alpha_0 is not None or beta_0 is not None: - warnings.warn("Values of `alpha_0` and `beta_0` unused becased" - "`obs_sd` was specified already.") + warnings.warn( + "Values of `alpha_0` and `beta_0` unused becased" + "`obs_sd` was specified already." + ) obs_sd = obs_sd.expand(batch_shape + (1,)) @@ -231,10 +311,10 @@ def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={ # Sample `G` once for this group alpha, beta = re_alphas[name], re_betas[name] G_prior = dist.Gamma(alpha, beta).to_event(1) - G = 1./torch.sqrt(pyro.sample("G_" + name, G_prior)) + G = 1.0 / torch.sqrt(pyro.sample("G_" + name, G_prior)) # Repeat `G` for each group repeat_shape = tuple(1 for _ in batch_shape) + (group_size,) - u_prior = dist.Normal(torch.tensor(0.), G.repeat(repeat_shape)).to_event(1) + u_prior = dist.Normal(torch.tensor(0.0), G.repeat(repeat_shape)).to_event(1) w.append(pyro.sample(name, u_prior)) # Regression coefficient `w` is batch x p w = broadcast_cat(w) @@ -243,14 +323,21 @@ def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={ prediction_mean = rmv(design, w) if response == "normal": # y is an n-vector: hence use .to_event(1) - return pyro.sample(response_label, dist.Normal(prediction_mean, obs_sd).to_event(1)) + return pyro.sample( + response_label, dist.Normal(prediction_mean, obs_sd).to_event(1) + ) elif response == "bernoulli": - return pyro.sample(response_label, dist.Bernoulli(logits=prediction_mean).to_event(1)) + return pyro.sample( + response_label, dist.Bernoulli(logits=prediction_mean).to_event(1) + ) elif response == "sigmoid": base_dist = dist.Normal(prediction_mean, obs_sd).to_event(1) # You can add loc via the linear model itself k = k.expand(prediction_mean.shape) - transforms = [AffineTransform(loc=torch.tensor(0.), scale=k), SigmoidTransform()] + transforms = [ + AffineTransform(loc=torch.tensor(0.0), scale=k), + SigmoidTransform(), + ] response_dist = dist.TransformedDistribution(base_dist, transforms) return pyro.sample(response_label, response_dist) else: @@ -287,12 +374,16 @@ def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False): if obs_sd is None: # First, sample tau (observation precision) - alpha = softplus(pyro.param("invsoftplus_alpha", 20.*torch.ones(tau_shape))) - beta = softplus(pyro.param("invsoftplus_beta", 20.*torch.ones(tau_shape))) + alpha = softplus( + pyro.param("invsoftplus_alpha", 20.0 * torch.ones(tau_shape)) + ) + beta = softplus( + pyro.param("invsoftplus_beta", 20.0 * torch.ones(tau_shape)) + ) # Global variable tau_prior = dist.Gamma(alpha, beta) tau = pyro.sample("tau", tau_prior) - obs_sd = 1./torch.sqrt(tau) + obs_sd = 1.0 / torch.sqrt(tau) # response will be shape batch x n obs_sd = obs_sd.expand(tau_shape).unsqueeze(-1) @@ -300,17 +391,19 @@ def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False): for name, size in w_sizes.items(): w_shape = tau_shape + size # Set up mu and lambda - mw_param = 
pyro.param("{}_guide_mean".format(name), - torch.zeros(w_shape)) + mw_param = pyro.param("{}_guide_mean".format(name), torch.zeros(w_shape)) scale_tril = pyro.param( "{}_guide_scale_tril".format(name), torch.eye(*size).expand(tau_shape + size + size), - constraint=constraints.lower_cholesky) + constraint=constraints.lower_cholesky, + ) # guide distributions for w if mf: w_dist = dist.MultivariateNormal(mw_param, scale_tril=scale_tril) else: - w_dist = dist.MultivariateNormal(mw_param, scale_tril=obs_sd.unsqueeze(-1) * scale_tril) + w_dist = dist.MultivariateNormal( + mw_param, scale_tril=obs_sd.unsqueeze(-1) * scale_tril + ) pyro.sample(name, w_dist) @@ -330,21 +423,21 @@ def group_assignment_matrix(design): for col, i in enumerate(design): i = int(i) if i > 0: - X[t:t+i, col] = 1. + X[t : t + i, col] = 1.0 t += i if t < n: - X[t:, -1] = 1. + X[t:, -1] = 1.0 return X def rf_group_assignments(n, random_intercept=True): assert n % 2 == 0 - n_designs = n//2 + 1 + n_designs = n // 2 + 1 participant_matrix = torch.eye(n) Xs = [] for i in range(n_designs): - X1 = group_assignment_matrix(torch.tensor([i, n//2 - i])) - X2 = group_assignment_matrix(torch.tensor([n//2 - i, i])) + X1 = group_assignment_matrix(torch.tensor([i, n // 2 - i])) + X2 = group_assignment_matrix(torch.tensor([n // 2 - i, i])) X = torch.cat([X1, X2], dim=-2) Xs.append(X) X = torch.stack(Xs, dim=0) @@ -364,7 +457,8 @@ def analytic_posterior_cov(prior_cov, x, obs_sd): p = prior_cov.shape[-1] SigmaXX = prior_cov.mm(x.t().mm(x)) posterior_cov = prior_cov - torch.inverse( - SigmaXX + (obs_sd**2)*torch.eye(p)).mm(SigmaXX.mm(prior_cov)) + SigmaXX + (obs_sd ** 2) * torch.eye(p) + ).mm(SigmaXX.mm(prior_cov)) return posterior_cov diff --git a/pyro/contrib/oed/glmm/guides.py b/pyro/contrib/oed/glmm/guides.py index 7ae544080e..2643ae2108 100644 --- a/pyro/contrib/oed/glmm/guides.py +++ b/pyro/contrib/oed/glmm/guides.py @@ -21,8 +21,16 @@ class LinearModelPosteriorGuide(nn.Module): - - def __init__(self, d, w_sizes, y_sizes, regressor_init=0., scale_tril_init=3., use_softplus=True, **kwargs): + def __init__( + self, + d, + w_sizes, + y_sizes, + regressor_init=0.0, + scale_tril_init=3.0, + use_softplus=True, + **kwargs + ): """ Guide for linear models. No amortisation happens over designs. Amortisation over data is taken care of by analytic formulae for @@ -42,10 +50,20 @@ def __init__(self, d, w_sizes, y_sizes, regressor_init=0., scale_tril_init=3., u # To avoid this- combine labels if not isinstance(d, (tuple, list, torch.Tensor)): d = (d,) - self.regressor = nn.ParameterDict({l: nn.Parameter( - regressor_init * torch.ones(*(d + (p, sum(y_sizes.values()))))) for l, p in w_sizes.items()}) - self.scale_tril = nn.ParameterDict({l: nn.Parameter( - scale_tril_init * lexpand(torch.eye(p), *d)) for l, p in w_sizes.items()}) + self.regressor = nn.ParameterDict( + { + l: nn.Parameter( + regressor_init * torch.ones(*(d + (p, sum(y_sizes.values())))) + ) + for l, p in w_sizes.items() + } + ) + self.scale_tril = nn.ParameterDict( + { + l: nn.Parameter(scale_tril_init * lexpand(torch.eye(p), *d)) + for l, p in w_sizes.items() + } + ) self.w_sizes = w_sizes self.use_softplus = use_softplus self.softplus = nn.Softplus() @@ -88,6 +106,7 @@ class LinearModelLaplaceGuide(nn.Module): fixed variance. :param float init_value: initial value for the posterior mean parameters. 
""" + def __init__(self, d, w_sizes, tau_label=None, init_value=0.1, **kwargs): super().__init__() # start in train mode @@ -97,17 +116,25 @@ def __init__(self, d, w_sizes, tau_label=None, init_value=0.1, **kwargs): self.means = nn.ParameterDict() if tau_label is not None: w_sizes[tau_label] = 1 - for l, mu_l in tensor_to_dict(w_sizes, init_value*torch.ones(*(d + (sum(w_sizes.values()), )))).items(): + for l, mu_l in tensor_to_dict( + w_sizes, init_value * torch.ones(*(d + (sum(w_sizes.values()),))) + ).items(): self.means[l] = nn.Parameter(mu_l) self.scale_trils = {} self.w_sizes = w_sizes @staticmethod def _hessian_diag(y, x, event_shape): - batch_shape = x.shape[:-len(event_shape)] + batch_shape = x.shape[: -len(event_shape)] assert tuple(x.shape) == tuple(batch_shape) + tuple(event_shape) - dy = torch.autograd.grad(y, [x, ], create_graph=True)[0] + dy = torch.autograd.grad( + y, + [ + x, + ], + create_graph=True, + )[0] H = [] # collapse independent dimensions into a single one, @@ -125,7 +152,14 @@ def _hessian_diag(y, x, event_shape): # loop over dependent part for i in range(flat_dy.shape[-1]): dyi = flat_dy.index_select(-1, torch.tensor([i])) - Hi = torch.autograd.grad([dyi], [x, ], grad_outputs=[torch.ones_like(dyi)], retain_graph=True)[0] + Hi = torch.autograd.grad( + [dyi], + [ + x, + ], + grad_outputs=[torch.ones_like(dyi)], + retain_graph=True, + )[0] H.append(Hi) H = torch.stack(H, -1).reshape(*(x.shape + event_shape)) return H @@ -174,12 +208,13 @@ def forward(self, design, target_labels=None): else: # Laplace approximation via MVN with hessian for l in target_labels: - w_dist = dist.MultivariateNormal(self.means[l], scale_tril=self.scale_trils[l]) + w_dist = dist.MultivariateNormal( + self.means[l], scale_tril=self.scale_trils[l] + ) pyro.sample(l, w_dist) class SigmoidGuide(LinearModelPosteriorGuide): - def __init__(self, d, n, w_sizes, **kwargs): super().__init__(d, w_sizes, **kwargs) self.inverse_sigmoid_scale = nn.Parameter(torch.ones(n)) @@ -191,9 +226,9 @@ def get_params(self, y_dict, design, target_labels): y = torch.cat(list(y_dict.values()), dim=-1) # Approx invert transformation on y in expectation - y, y1m = y.clamp(1e-35, 1), (1.-y).clamp(1e-35, 1) + y, y1m = y.clamp(1e-35, 1), (1.0 - y).clamp(1e-35, 1) logited = y.log() - y1m.log() - y_trans = logited/.1 + y_trans = logited / 0.1 y_trans = y_trans * self.inverse_sigmoid_scale hidden = self.softplus(y_trans) y_trans = y_trans + hidden * self.h1_weight + self.h1_bias @@ -202,12 +237,19 @@ def get_params(self, y_dict, design, target_labels): class NormalInverseGammaGuide(LinearModelPosteriorGuide): - - def __init__(self, d, w_sizes, mf=False, tau_label="tau", alpha_init=100., - b0_init=100., **kwargs): + def __init__( + self, + d, + w_sizes, + mf=False, + tau_label="tau", + alpha_init=100.0, + b0_init=100.0, + **kwargs + ): super().__init__(d, w_sizes, **kwargs) - self.alpha = nn.Parameter(alpha_init*torch.ones(d)) - self.b0 = nn.Parameter(b0_init*torch.ones(d)) + self.alpha = nn.Parameter(alpha_init * torch.ones(d)) + self.b0 = nn.Parameter(b0_init * torch.ones(d)) self.mf = mf self.tau_label = tau_label @@ -215,13 +257,15 @@ def get_params(self, y_dict, design, target_labels): y = torch.cat(list(y_dict.values()), dim=-1) - coefficient_labels = [label for label in target_labels if label != self.tau_label] + coefficient_labels = [ + label for label in target_labels if label != self.tau_label + ] mu, scale_tril = self.linear_model_formula(y, design, coefficient_labels) mu_vec = torch.cat(list(mu.values()), dim=-1) 
yty = rvv(y, y) ytxmu = rvv(y, rmv(design, mu_vec)) - beta = self.b0 + .5*(yty - ytxmu) + beta = self.b0 + 0.5 * (yty - ytxmu) return mu, scale_tril, self.alpha, beta @@ -234,16 +278,18 @@ def forward(self, y_dict, design, observation_labels, target_labels): if self.tau_label in target_labels: tau_dist = dist.Gamma(alpha, beta) tau = pyro.sample(self.tau_label, tau_dist) - obs_sd = 1./tau.sqrt().unsqueeze(-1).unsqueeze(-1) + obs_sd = 1.0 / tau.sqrt().unsqueeze(-1).unsqueeze(-1) for label in target_labels: if label != self.tau_label: if self.mf: - w_dist = dist.MultivariateNormal(mu[label], - scale_tril=scale_tril[label]) + w_dist = dist.MultivariateNormal( + mu[label], scale_tril=scale_tril[label] + ) else: - w_dist = dist.MultivariateNormal(mu[label], - scale_tril=scale_tril[label]*obs_sd) + w_dist = dist.MultivariateNormal( + mu[label], scale_tril=scale_tril[label] * obs_sd + ) pyro.sample(label, w_dist) @@ -251,6 +297,7 @@ class GuideDV(nn.Module): """A Donsker-Varadhan `T` family based on a guide family via the relation `T = log p(theta) - log q(theta | y, d)` """ + def __init__(self, guide): super().__init__() self.guide = guide @@ -264,7 +311,8 @@ def forward(self, design, trace, observation_labels, target_labels): conditional_guide = pyro.condition(self.guide, data=theta_dict) cond_trace = poutine.trace(conditional_guide).get_trace( - y_dict, design, observation_labels, target_labels) + y_dict, design, observation_labels, target_labels + ) cond_trace.compute_log_prob() posterior_lp = sum(cond_trace.nodes[l]["log_prob"] for l in target_labels) diff --git a/pyro/contrib/oed/search.py b/pyro/contrib/oed/search.py index 721f6305c3..358eb87051 100644 --- a/pyro/contrib/oed/search.py +++ b/pyro/contrib/oed/search.py @@ -15,6 +15,7 @@ class Search(TracePosterior): """ Exact inference by enumerating over all possible executions """ + def __init__(self, model, max_tries=int(1e6), **kwargs): self.model = model self.max_tries = max_tries @@ -23,8 +24,7 @@ def __init__(self, model, max_tries=int(1e6), **kwargs): def _traces(self, *args, **kwargs): q = queue.Queue() q.put(poutine.Trace()) - p = poutine.trace( - poutine.queue(self.model, queue=q, max_tries=self.max_tries)) + p = poutine.trace(poutine.queue(self.model, queue=q, max_tries=self.max_tries)) while not q.empty(): tr = p.get_trace(*args, **kwargs) yield tr, tr.log_prob_sum() diff --git a/pyro/contrib/oed/util.py b/pyro/contrib/oed/util.py index 50774ff0bd..4476e82555 100644 --- a/pyro/contrib/oed/util.py +++ b/pyro/contrib/oed/util.py @@ -10,19 +10,26 @@ from pyro.infer.autoguide.utils import mean_field_entropy -def linear_model_ground_truth(model, design, observation_labels, target_labels, eig=True): +def linear_model_ground_truth( + model, design, observation_labels, target_labels, eig=True +): if isinstance(target_labels, str): target_labels = [target_labels] w_sd = torch.cat(list(model.w_sds.values()), dim=-1) - prior_cov = torch.diag(w_sd**2) + prior_cov = torch.diag(w_sd ** 2) design_shape = design.shape - posterior_covs = [analytic_posterior_cov(prior_cov, x, model.obs_sd) for x in - torch.unbind(design.reshape(-1, design_shape[-2], design_shape[-1]))] + posterior_covs = [ + analytic_posterior_cov(prior_cov, x, model.obs_sd) + for x in torch.unbind(design.reshape(-1, design_shape[-2], design_shape[-1])) + ] target_indices = get_indices(target_labels, tensors=model.w_sds) - target_posterior_covs = [S[target_indices, :][:, target_indices] for S in posterior_covs] - output = torch.tensor([0.5 * torch.logdet(2 * math.pi * math.e * C) 
- for C in target_posterior_covs]) + target_posterior_covs = [ + S[target_indices, :][:, target_indices] for S in posterior_covs + ] + output = torch.tensor( + [0.5 * torch.logdet(2 * math.pi * math.e * C) for C in target_posterior_covs] + ) if eig: prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels) output = prior_entropy - output diff --git a/pyro/contrib/randomvariable/random_variable.py b/pyro/contrib/randomvariable/random_variable.py index b2421b7160..05d08238a5 100644 --- a/pyro/contrib/randomvariable/random_variable.py +++ b/pyro/contrib/randomvariable/random_variable.py @@ -19,38 +19,57 @@ class RVMagicOps: - """Mixin class for overloading __magic__ operations on random variables. - """ + """Mixin class for overloading __magic__ operations on random variables.""" def __add__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, 1)) + ) def __radd__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, 1))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, 1)) + ) def __sub__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(-x, 1))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(-x, 1)) + ) def __rsub__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(x, -1))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(x, -1)) + ) def __mul__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, x)) + ) def __rmul__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, x))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, x)) + ) def __truediv__(self, x: Union[float, Tensor]): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, 1/x))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, 1 / x)) + ) def __neg__(self): - return RandomVariable(TransformedDistribution(self.distribution, AffineTransform(0, -1))) + return RandomVariable( + TransformedDistribution(self.distribution, AffineTransform(0, -1)) + ) def __abs__(self): - return RandomVariable(TransformedDistribution(self.distribution, AbsTransform())) + return RandomVariable( + TransformedDistribution(self.distribution, AbsTransform()) + ) def __pow__(self, x): - return RandomVariable(TransformedDistribution(self.distribution, PowerTransform(x))) + return RandomVariable( + TransformedDistribution(self.distribution, PowerTransform(x)) + ) class RVChainOps: diff --git a/pyro/contrib/timeseries/base.py b/pyro/contrib/timeseries/base.py index c1e357c86b..76d9e299a5 100644 --- a/pyro/contrib/timeseries/base.py +++ b/pyro/contrib/timeseries/base.py @@ -8,6 +8,7 @@ class TimeSeriesModel(PyroModule): """ Base class for univariate and multivariate time series models. 
""" + @pyro_method def log_prob(self, targets): """ diff --git a/pyro/contrib/timeseries/gp.py b/pyro/contrib/timeseries/gp.py index 5c5a91583c..0c2eb71061 100644 --- a/pyro/contrib/timeseries/gp.py +++ b/pyro/contrib/timeseries/gp.py @@ -31,9 +31,16 @@ class IndependentMaternGP(TimeSeriesModel): :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale given as a ``obs_dim``-dimensional tensor """ - def __init__(self, nu=1.5, dt=1.0, obs_dim=1, - length_scale_init=None, kernel_scale_init=None, - obs_noise_scale_init=None): + + def __init__( + self, + nu=1.5, + dt=1.0, + obs_dim=1, + length_scale_init=None, + kernel_scale_init=None, + obs_noise_scale_init=None, + ): self.nu = nu self.dt = dt self.obs_dim = obs_dim @@ -44,23 +51,31 @@ def __init__(self, nu=1.5, dt=1.0, obs_dim=1, super().__init__() - self.kernel = MaternKernel(nu=nu, num_gps=obs_dim, - length_scale_init=length_scale_init, - kernel_scale_init=kernel_scale_init) + self.kernel = MaternKernel( + nu=nu, + num_gps=obs_dim, + length_scale_init=length_scale_init, + kernel_scale_init=kernel_scale_init, + ) - self.obs_noise_scale = PyroParam(obs_noise_scale_init, - constraint=constraints.positive) + self.obs_noise_scale = PyroParam( + obs_noise_scale_init, constraint=constraints.positive + ) obs_matrix = [1.0] + [0.0] * (self.kernel.state_dim - 1) self.register_buffer("obs_matrix", torch.tensor(obs_matrix).unsqueeze(-1)) def _get_init_dist(self): - return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, self.kernel.state_dim), - self.kernel.stationary_covariance().squeeze(-3)) + return torch.distributions.MultivariateNormal( + self.obs_matrix.new_zeros(self.obs_dim, self.kernel.state_dim), + self.kernel.stationary_covariance().squeeze(-3), + ) def _get_obs_dist(self): - return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim, 1, 1), - self.obs_noise_scale.unsqueeze(-1).unsqueeze(-1)).to_event(1) + return dist.Normal( + self.obs_matrix.new_zeros(self.obs_dim, 1, 1), + self.obs_noise_scale.unsqueeze(-1).unsqueeze(-1), + ).to_event(1) def get_dist(self, duration=None): """ @@ -71,12 +86,22 @@ def get_dist(self, duration=None): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. 
""" - trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt) - trans_dist = MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim), - process_covar.unsqueeze(-3)) + trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance( + dt=self.dt + ) + trans_dist = MultivariateNormal( + self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim), + process_covar.unsqueeze(-3), + ) trans_matrix = trans_matrix.unsqueeze(-3) - return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist, - self.obs_matrix, self._get_obs_dist(), duration=duration) + return dist.GaussianHMM( + self._get_init_dist(), + trans_matrix, + trans_dist, + self.obs_matrix, + self._get_obs_dist(), + duration=duration, + ) @pyro_method def log_prob(self, targets): @@ -104,15 +129,25 @@ def _forecast(self, dts, filtering_state, include_observation_noise=True): """ assert dts.dim() == 1 dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts) + trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance( + dt=dts + ) trans_matrix = trans_matrix[..., 0:1] - predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_matrix).squeeze(-2)[..., 0] - predicted_function_covar = torch.matmul(trans_matrix.transpose(-1, -2), torch.matmul( - filtering_state.covariance_matrix, trans_matrix))[..., 0, 0] + \ - process_covar[..., 0, 0] + predicted_mean = torch.matmul( + filtering_state.loc.unsqueeze(-2), trans_matrix + ).squeeze(-2)[..., 0] + predicted_function_covar = ( + torch.matmul( + trans_matrix.transpose(-1, -2), + torch.matmul(filtering_state.covariance_matrix, trans_matrix), + )[..., 0, 0] + + process_covar[..., 0, 0] + ) if include_observation_noise: - predicted_function_covar = predicted_function_covar + self.obs_noise_scale.pow(2.0) + predicted_function_covar = ( + predicted_function_covar + self.obs_noise_scale.pow(2.0) + ) return predicted_mean, predicted_function_covar @pyro_method @@ -157,9 +192,17 @@ class LinearlyCoupledMaternGP(TimeSeriesModel): :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale given as a ``obs_dim``-dimensional tensor """ - def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1, - length_scale_init=None, kernel_scale_init=None, - obs_noise_scale_init=None): + + def __init__( + self, + nu=1.5, + dt=1.0, + obs_dim=2, + num_gps=1, + length_scale_init=None, + kernel_scale_init=None, + obs_noise_scale_init=None, + ): self.nu = nu self.dt = dt assert obs_dim > 1, "If obs_dim==1 you should use IndependentMaternGP" @@ -176,19 +219,28 @@ def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1, super().__init__() - self.kernel = MaternKernel(nu=nu, num_gps=num_gps, - length_scale_init=length_scale_init, - kernel_scale_init=kernel_scale_init) + self.kernel = MaternKernel( + nu=nu, + num_gps=num_gps, + length_scale_init=length_scale_init, + kernel_scale_init=kernel_scale_init, + ) self.full_state_dim = num_gps * self.kernel.state_dim - self.obs_noise_scale = PyroParam(obs_noise_scale_init, - constraint=constraints.positive) + self.obs_noise_scale = PyroParam( + obs_noise_scale_init, constraint=constraints.positive + ) self.A = nn.Parameter(0.3 * torch.randn(self.num_gps, self.obs_dim)) def _get_obs_matrix(self): # (num_gps, obs_dim) => (state_dim * num_gps, obs_dim) - return self.A.repeat_interleave(self.kernel.state_dim, dim=0) * \ - self.A.new_tensor([1.0] + [0.0] * 
(self.kernel.state_dim - 1)).repeat(self.num_gps).unsqueeze(-1) + return self.A.repeat_interleave( + self.kernel.state_dim, dim=0 + ) * self.A.new_tensor([1.0] + [0.0] * (self.kernel.state_dim - 1)).repeat( + self.num_gps + ).unsqueeze( + -1 + ) def _stationary_covariance(self): return block_diag_embed(self.kernel.stationary_covariance()) @@ -210,13 +262,21 @@ def get_dist(self, duration=None): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ - trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt) + trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance( + dt=self.dt + ) trans_matrix = block_diag_embed(trans_matrix) process_covar = block_diag_embed(process_covar) loc = self.A.new_zeros(self.full_state_dim) trans_dist = MultivariateNormal(loc, process_covar) - return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist, - self._get_obs_matrix(), self._get_obs_dist(), duration=duration) + return dist.GaussianHMM( + self._get_init_dist(), + trans_matrix, + trans_dist, + self._get_obs_matrix(), + self._get_obs_dist(), + duration=duration, + ) @pyro_method def log_prob(self, targets): @@ -238,7 +298,9 @@ def _filter(self, targets): return self.get_dist().filter(targets) @torch.no_grad() - def _forecast(self, dts, filtering_state, include_observation_noise=True, full_covar=True): + def _forecast( + self, dts, filtering_state, include_observation_noise=True, full_covar=True + ): """ Internal helper for forecasting. """ @@ -246,21 +308,29 @@ def _forecast(self, dts, filtering_state, include_observation_noise=True, full_c dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) trans_mat, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts) trans_mat = block_diag_embed(trans_mat) # S x full_state_dim x full_state_dim - process_covar = block_diag_embed(process_covar) # S x full_state_dim x full_state_dim + process_covar = block_diag_embed( + process_covar + ) # S x full_state_dim x full_state_dim obs_matrix = self._get_obs_matrix() # full_state_dim x obs_dim trans_obs = torch.matmul(trans_mat, obs_matrix) # S x full_state_dim x obs_dim - predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_obs).squeeze(-2) - predicted_function_covar = torch.matmul(trans_obs.transpose(-1, -2), - torch.matmul(filtering_state.covariance_matrix, - trans_obs)) - predicted_function_covar = predicted_function_covar + \ - torch.matmul(obs_matrix.transpose(-1, -2), torch.matmul(process_covar, obs_matrix)) + predicted_mean = torch.matmul( + filtering_state.loc.unsqueeze(-2), trans_obs + ).squeeze(-2) + predicted_function_covar = torch.matmul( + trans_obs.transpose(-1, -2), + torch.matmul(filtering_state.covariance_matrix, trans_obs), + ) + predicted_function_covar = predicted_function_covar + torch.matmul( + obs_matrix.transpose(-1, -2), torch.matmul(process_covar, obs_matrix) + ) if include_observation_noise: obs_noise = self.obs_noise_scale.pow(2.0).diag_embed() predicted_function_covar = predicted_function_covar + obs_noise if not full_covar: - predicted_function_covar = predicted_function_covar.diagonal(dim1=-1, dim2=-2) + predicted_function_covar = predicted_function_covar.diagonal( + dim1=-1, dim2=-2 + ) return predicted_mean, predicted_function_covar @@ -305,8 +375,16 @@ class DependentMaternGP(TimeSeriesModel): References [1] "Dependent Matern Processes for Multivariate Time Series," Alexander Vandenberg-Rodes, Babak Shahbaba. 
""" - def __init__(self, nu=1.5, dt=1.0, obs_dim=1, linearly_coupled=False, - length_scale_init=None, obs_noise_scale_init=None): + + def __init__( + self, + nu=1.5, + dt=1.0, + obs_dim=1, + linearly_coupled=False, + length_scale_init=None, + obs_noise_scale_init=None, + ): if nu != 1.5: raise NotImplementedError("The only supported value of nu is 1.5") @@ -320,8 +398,9 @@ def __init__(self, nu=1.5, dt=1.0, obs_dim=1, linearly_coupled=False, super().__init__() - self.kernel = MaternKernel(nu=nu, num_gps=obs_dim, - length_scale_init=length_scale_init) + self.kernel = MaternKernel( + nu=nu, num_gps=obs_dim, length_scale_init=length_scale_init + ) self.full_state_dim = self.kernel.state_dim * obs_dim # we demote self.kernel.kernel_scale from being a nn.Parameter @@ -329,14 +408,18 @@ def __init__(self, nu=1.5, dt=1.0, obs_dim=1, linearly_coupled=False, del self.kernel.kernel_scale self.kernel.register_buffer("kernel_scale", torch.ones(obs_dim)) - self.obs_noise_scale = PyroParam(obs_noise_scale_init, - constraint=constraints.positive) - self.wiener_noise_tril = PyroParam(torch.eye(obs_dim) + - 0.03 * torch.randn(obs_dim, obs_dim).tril(-1), - constraint=constraints.lower_cholesky) + self.obs_noise_scale = PyroParam( + obs_noise_scale_init, constraint=constraints.positive + ) + self.wiener_noise_tril = PyroParam( + torch.eye(obs_dim) + 0.03 * torch.randn(obs_dim, obs_dim).tril(-1), + constraint=constraints.lower_cholesky, + ) if linearly_coupled: - self.obs_matrix = nn.Parameter(0.3 * torch.randn(self.obs_dim, self.obs_dim)) + self.obs_matrix = nn.Parameter( + 0.3 * torch.randn(self.obs_dim, self.obs_dim) + ) else: obs_matrix = torch.zeros(self.full_state_dim, obs_dim) for i in range(obs_dim): @@ -347,38 +430,49 @@ def _get_obs_matrix(self): if self.obs_matrix.size(0) == self.obs_dim: # (num_gps, obs_dim) => (state_dim * num_gps, obs_dim) selector = [1.0] + [0.0] * (self.kernel.state_dim - 1) - return self.obs_matrix.repeat_interleave(self.kernel.state_dim, dim=0) * \ - self.obs_matrix.new_tensor(selector).repeat(self.obs_dim).unsqueeze(-1) + return self.obs_matrix.repeat_interleave( + self.kernel.state_dim, dim=0 + ) * self.obs_matrix.new_tensor(selector).repeat(self.obs_dim).unsqueeze(-1) else: return self.obs_matrix def _get_init_dist(self, stationary_covariance): - return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.full_state_dim), - stationary_covariance) + return torch.distributions.MultivariateNormal( + self.obs_matrix.new_zeros(self.full_state_dim), stationary_covariance + ) def _get_obs_dist(self): - return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim), - self.obs_noise_scale).to_event(1) + return dist.Normal( + self.obs_matrix.new_zeros(self.obs_dim), self.obs_noise_scale + ).to_event(1) def _get_wiener_cov(self): chol = self.wiener_noise_tril wiener_cov = torch.mm(chol, chol.t()).reshape(self.obs_dim, 1, self.obs_dim, 1) - wiener_cov = wiener_cov * wiener_cov.new_ones(self.kernel.state_dim, 1, self.kernel.state_dim) + wiener_cov = wiener_cov * wiener_cov.new_ones( + self.kernel.state_dim, 1, self.kernel.state_dim + ) return wiener_cov.reshape(self.full_state_dim, self.full_state_dim) def _stationary_covariance(self): rho_j = math.sqrt(3.0) / self.kernel.length_scale.unsqueeze(-1).unsqueeze(-1) rho_i = rho_j.unsqueeze(-1) - block = 2.0 * self.kernel.mask00 + \ - (rho_i - rho_j) * (self.kernel.mask01 - self.kernel.mask10) + \ - (2.0 * rho_i * rho_j) * self.kernel.mask11 + block = ( + 2.0 * self.kernel.mask00 + + (rho_i - rho_j) * (self.kernel.mask01 - 
self.kernel.mask10) + + (2.0 * rho_i * rho_j) * self.kernel.mask11 + ) block = block / (rho_i + rho_j).pow(3.0) - block = block.transpose(-2, -3).reshape(self.full_state_dim, self.full_state_dim) + block = block.transpose(-2, -3).reshape( + self.full_state_dim, self.full_state_dim + ) return self._get_wiener_cov() * block def _get_trans_dist(self, trans_matrix, stationary_covariance): - covar = stationary_covariance - torch.matmul(trans_matrix.transpose(-1, -2), - torch.matmul(stationary_covariance, trans_matrix)) + covar = stationary_covariance - torch.matmul( + trans_matrix.transpose(-1, -2), + torch.matmul(stationary_covariance, trans_matrix), + ) return MultivariateNormal(covar.new_zeros(self.full_state_dim), covar) def _trans_matrix_distribution_stat_covar(self, dts): @@ -396,9 +490,19 @@ def get_dist(self, duration=None): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ - trans_matrix, trans_dist, stat_covar = self._trans_matrix_distribution_stat_covar(self.dt) - return dist.GaussianHMM(self._get_init_dist(stat_covar), trans_matrix, - trans_dist, self._get_obs_matrix(), self._get_obs_dist(), duration=duration) + ( + trans_matrix, + trans_dist, + stat_covar, + ) = self._trans_matrix_distribution_stat_covar(self.dt) + return dist.GaussianHMM( + self._get_init_dist(stat_covar), + trans_matrix, + trans_dist, + self._get_obs_matrix(), + self._get_obs_dist(), + duration=duration, + ) @pyro_method def log_prob(self, targets): @@ -430,13 +534,20 @@ def _forecast(self, dts, filtering_state, include_observation_noise=True): obs_matrix = self._get_obs_matrix() trans_obs = torch.matmul(trans_matrix, obs_matrix) - predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_obs).squeeze(-2) - predicted_function_covar = torch.matmul(trans_obs.transpose(-1, -2), - torch.matmul(filtering_state.covariance_matrix, trans_obs)) + \ - torch.matmul(obs_matrix.t(), torch.matmul(trans_dist.covariance_matrix, obs_matrix)) + predicted_mean = torch.matmul( + filtering_state.loc.unsqueeze(-2), trans_obs + ).squeeze(-2) + predicted_function_covar = torch.matmul( + trans_obs.transpose(-1, -2), + torch.matmul(filtering_state.covariance_matrix, trans_obs), + ) + torch.matmul( + obs_matrix.t(), torch.matmul(trans_dist.covariance_matrix, obs_matrix) + ) if include_observation_noise: - predicted_function_covar = predicted_function_covar + self.obs_noise_scale.pow(2.0) + predicted_function_covar = ( + predicted_function_covar + self.obs_noise_scale.pow(2.0) + ) return predicted_mean, predicted_function_covar diff --git a/pyro/contrib/timeseries/lgssm.py b/pyro/contrib/timeseries/lgssm.py index 5f7c21492a..031bc03b06 100644 --- a/pyro/contrib/timeseries/lgssm.py +++ b/pyro/contrib/timeseries/lgssm.py @@ -22,8 +22,14 @@ class GenericLGSSM(TimeSeriesModel): :param bool learnable_observation_loc: whether the mean of the observation model should be learned or not; defaults to False. 
""" - def __init__(self, obs_dim=1, state_dim=2, obs_noise_scale_init=None, - learnable_observation_loc=False): + + def __init__( + self, + obs_dim=1, + state_dim=2, + obs_noise_scale_init=None, + learnable_observation_loc=False, + ): self.obs_dim = obs_dim self.state_dim = state_dim @@ -33,19 +39,24 @@ def __init__(self, obs_dim=1, state_dim=2, obs_noise_scale_init=None, super().__init__() - self.obs_noise_scale = PyroParam(obs_noise_scale_init, - constraint=constraints.positive) - self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim), - constraint=constraints.positive) - self.trans_matrix = nn.Parameter(torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim)) + self.obs_noise_scale = PyroParam( + obs_noise_scale_init, constraint=constraints.positive + ) + self.trans_noise_scale_sq = PyroParam( + torch.ones(state_dim), constraint=constraints.positive + ) + self.trans_matrix = nn.Parameter( + torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim) + ) self.obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim)) - self.init_noise_scale_sq = PyroParam(torch.ones(state_dim), - constraint=constraints.positive) + self.init_noise_scale_sq = PyroParam( + torch.ones(state_dim), constraint=constraints.positive + ) if learnable_observation_loc: self.obs_loc = nn.Parameter(torch.zeros(obs_dim)) else: - self.register_buffer('obs_loc', torch.zeros(obs_dim)) + self.register_buffer("obs_loc", torch.zeros(obs_dim)) def _get_init_dist(self): loc = self.obs_matrix.new_zeros(self.state_dim) @@ -66,8 +77,14 @@ def get_dist(self, duration=None): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ - return dist.GaussianHMM(self._get_init_dist(), self.trans_matrix, self._get_trans_dist(), - self.obs_matrix, self._get_obs_dist(), duration=duration) + return dist.GaussianHMM( + self._get_init_dist(), + self.trans_matrix, + self._get_trans_dist(), + self.obs_matrix, + self._get_obs_dist(), + duration=duration, + ) @pyro_method def log_prob(self, targets): @@ -98,21 +115,26 @@ def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True predicted_mean = torch.matmul(filtering_state.loc, N_trans_obs) # first compute the contribution from filtering_state.covariance_matrix - predicted_covar1 = torch.matmul(N_trans_obs.transpose(-1, -2), - torch.matmul(filtering_state.covariance_matrix, - N_trans_obs)) # N O O + predicted_covar1 = torch.matmul( + N_trans_obs.transpose(-1, -2), + torch.matmul(filtering_state.covariance_matrix, N_trans_obs), + ) # N O O # next compute the contribution from process noise that is injected at each timestep. 
# (we need to do a cumulative sum to integrate across time) process_covar = self._get_trans_dist().covariance_matrix N_trans_obs_shift = torch.cat([self.obs_matrix.unsqueeze(0), N_trans_obs[:-1]]) - predicted_covar2 = torch.matmul(N_trans_obs_shift.transpose(-1, -2), - torch.matmul(process_covar, N_trans_obs_shift)) # N O O + predicted_covar2 = torch.matmul( + N_trans_obs_shift.transpose(-1, -2), + torch.matmul(process_covar, N_trans_obs_shift), + ) # N O O predicted_covar = predicted_covar1 + torch.cumsum(predicted_covar2, dim=0) if include_observation_noise: - predicted_covar = predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed() + predicted_covar = ( + predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed() + ) return predicted_mean, predicted_covar diff --git a/pyro/contrib/timeseries/lgssmgp.py b/pyro/contrib/timeseries/lgssmgp.py index 00e987c0ff..640a257dbb 100644 --- a/pyro/contrib/timeseries/lgssmgp.py +++ b/pyro/contrib/timeseries/lgssmgp.py @@ -40,9 +40,17 @@ class GenericLGSSMWithGPNoiseModel(TimeSeriesModel): :param bool learnable_observation_loc: whether the mean of the observation model should be learned or not; defaults to False. """ - def __init__(self, obs_dim=1, state_dim=2, nu=1.5, obs_noise_scale_init=None, - length_scale_init=None, kernel_scale_init=None, - learnable_observation_loc=False): + + def __init__( + self, + obs_dim=1, + state_dim=2, + nu=1.5, + obs_noise_scale_init=None, + length_scale_init=None, + kernel_scale_init=None, + learnable_observation_loc=False, + ): self.obs_dim = obs_dim self.state_dim = state_dim self.nu = nu @@ -53,33 +61,43 @@ def __init__(self, obs_dim=1, state_dim=2, nu=1.5, obs_noise_scale_init=None, super().__init__() - self.kernel = MaternKernel(nu=nu, num_gps=obs_dim, - length_scale_init=length_scale_init, - kernel_scale_init=kernel_scale_init) + self.kernel = MaternKernel( + nu=nu, + num_gps=obs_dim, + length_scale_init=length_scale_init, + kernel_scale_init=kernel_scale_init, + ) self.dt = 1.0 self.full_state_dim = self.kernel.state_dim * obs_dim + state_dim self.full_gp_state_dim = self.kernel.state_dim * obs_dim - self.obs_noise_scale = PyroParam(obs_noise_scale_init, - constraint=constraints.positive) - self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim), - constraint=constraints.positive) - self.z_trans_matrix = nn.Parameter(torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim)) + self.obs_noise_scale = PyroParam( + obs_noise_scale_init, constraint=constraints.positive + ) + self.trans_noise_scale_sq = PyroParam( + torch.ones(state_dim), constraint=constraints.positive + ) + self.z_trans_matrix = nn.Parameter( + torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim) + ) self.z_obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim)) - self.init_noise_scale_sq = PyroParam(torch.ones(state_dim), - constraint=constraints.positive) + self.init_noise_scale_sq = PyroParam( + torch.ones(state_dim), constraint=constraints.positive + ) gp_obs_matrix = torch.zeros(self.kernel.state_dim * obs_dim, obs_dim) for i in range(obs_dim): gp_obs_matrix[self.kernel.state_dim * i, i] = 1.0 self.register_buffer("gp_obs_matrix", gp_obs_matrix) - self.obs_selector = torch.tensor([self.kernel.state_dim * d for d in range(obs_dim)], dtype=torch.long) + self.obs_selector = torch.tensor( + [self.kernel.state_dim * d for d in range(obs_dim)], dtype=torch.long + ) if learnable_observation_loc: self.obs_loc = nn.Parameter(torch.zeros(obs_dim)) else: - self.register_buffer('obs_loc', torch.zeros(obs_dim)) + 
self.register_buffer("obs_loc", torch.zeros(obs_dim)) def _get_obs_matrix(self): # (obs_dim + state_dim, obs_dim) => (gp_state_dim * obs_dim + state_dim, obs_dim) @@ -88,8 +106,12 @@ def _get_obs_matrix(self): def _get_init_dist(self): loc = self.z_trans_matrix.new_zeros(self.full_state_dim) covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim) - covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(self.kernel.stationary_covariance()) - covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.init_noise_scale_sq.diag_embed() + covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = block_diag_embed( + self.kernel.stationary_covariance() + ) + covar[ + self.full_gp_state_dim :, self.full_gp_state_dim : + ] = self.init_noise_scale_sq.diag_embed() return MultivariateNormal(loc, covar) def _get_obs_dist(self): @@ -104,19 +126,40 @@ def get_dist(self, duration=None): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ - gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt) - - trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim) - trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_process_covar) - trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.trans_noise_scale_sq.diag_embed() - trans_dist = MultivariateNormal(trans_covar.new_zeros(self.full_state_dim), trans_covar) + ( + gp_trans_matrix, + gp_process_covar, + ) = self.kernel.transition_matrix_and_covariance(dt=self.dt) + + trans_covar = self.z_trans_matrix.new_zeros( + self.full_state_dim, self.full_state_dim + ) + trans_covar[ + : self.full_gp_state_dim, : self.full_gp_state_dim + ] = block_diag_embed(gp_process_covar) + trans_covar[ + self.full_gp_state_dim :, self.full_gp_state_dim : + ] = self.trans_noise_scale_sq.diag_embed() + trans_dist = MultivariateNormal( + trans_covar.new_zeros(self.full_state_dim), trans_covar + ) full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim) - full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_trans_matrix) - full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.z_trans_matrix - - return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist, - self._get_obs_matrix(), self._get_obs_dist(), duration=duration) + full_trans_mat[ + : self.full_gp_state_dim, : self.full_gp_state_dim + ] = block_diag_embed(gp_trans_matrix) + full_trans_mat[ + self.full_gp_state_dim :, self.full_gp_state_dim : + ] = self.z_trans_matrix + + return dist.GaussianHMM( + self._get_init_dist(), + full_trans_mat, + trans_dist, + self._get_obs_matrix(), + self._get_obs_dist(), + duration=duration, + ) @pyro_method def log_prob(self, targets): @@ -142,10 +185,20 @@ def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True """ Internal helper for forecasting. 
""" - dts = torch.arange(N_timesteps, dtype=self.z_trans_matrix.dtype, device=self.z_trans_matrix.device) + 1.0 + dts = ( + torch.arange( + N_timesteps, + dtype=self.z_trans_matrix.dtype, + device=self.z_trans_matrix.device, + ) + + 1.0 + ) dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=dts) + ( + gp_trans_matrix, + gp_process_covar, + ) = self.kernel.transition_matrix_and_covariance(dt=dts) gp_trans_matrix = block_diag_embed(gp_trans_matrix) gp_process_covar = block_diag_embed(gp_process_covar[..., 0:1, 0:1]) @@ -153,33 +206,53 @@ def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True N_trans_obs = torch.matmul(N_trans_matrix, self.z_obs_matrix) # z-state contribution + gp contribution - predicted_mean1 = torch.matmul(filtering_state.loc[-self.state_dim:].unsqueeze(-2), N_trans_obs).squeeze(-2) - predicted_mean2 = torch.matmul(filtering_state.loc[:self.full_gp_state_dim].unsqueeze(-2), - gp_trans_matrix[..., self.obs_selector]).squeeze(-2) + predicted_mean1 = torch.matmul( + filtering_state.loc[-self.state_dim :].unsqueeze(-2), N_trans_obs + ).squeeze(-2) + predicted_mean2 = torch.matmul( + filtering_state.loc[: self.full_gp_state_dim].unsqueeze(-2), + gp_trans_matrix[..., self.obs_selector], + ).squeeze(-2) predicted_mean = predicted_mean1 + predicted_mean2 # first compute the contributions from filtering_state.covariance_matrix: z-space and gp fs_cov = filtering_state.covariance_matrix - predicted_covar1z = torch.matmul(N_trans_obs.transpose(-1, -2), - torch.matmul(fs_cov[self.full_gp_state_dim:, self.full_gp_state_dim:], - N_trans_obs)) # N O O + predicted_covar1z = torch.matmul( + N_trans_obs.transpose(-1, -2), + torch.matmul( + fs_cov[self.full_gp_state_dim :, self.full_gp_state_dim :], N_trans_obs + ), + ) # N O O gp_trans = gp_trans_matrix[..., self.obs_selector] - predicted_covar1gp = torch.matmul(gp_trans.transpose(-1, -2), - torch.matmul(fs_cov[:self.full_gp_state_dim:, :self.full_gp_state_dim], - gp_trans)) + predicted_covar1gp = torch.matmul( + gp_trans.transpose(-1, -2), + torch.matmul( + fs_cov[: self.full_gp_state_dim :, : self.full_gp_state_dim], gp_trans + ), + ) # next compute the contribution from process noise that is injected at each timestep. 
# (we need to do a cumulative sum to integrate across time for the z-state contribution) z_process_covar = self.trans_noise_scale_sq.diag_embed() - N_trans_obs_shift = torch.cat([self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]]) - predicted_covar2z = torch.matmul(N_trans_obs_shift.transpose(-1, -2), - torch.matmul(z_process_covar, N_trans_obs_shift)) # N O O - - predicted_covar = predicted_covar1z + predicted_covar1gp + gp_process_covar + \ - torch.cumsum(predicted_covar2z, dim=0) + N_trans_obs_shift = torch.cat( + [self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]] + ) + predicted_covar2z = torch.matmul( + N_trans_obs_shift.transpose(-1, -2), + torch.matmul(z_process_covar, N_trans_obs_shift), + ) # N O O + + predicted_covar = ( + predicted_covar1z + + predicted_covar1gp + + gp_process_covar + + torch.cumsum(predicted_covar2z, dim=0) + ) if include_observation_noise: - predicted_covar = predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed() + predicted_covar = ( + predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed() + ) return predicted_mean, predicted_covar diff --git a/pyro/contrib/tracking/assignment.py b/pyro/contrib/tracking/assignment.py index d9699f33f8..dec9c4b9ea 100644 --- a/pyro/contrib/tracking/assignment.py +++ b/pyro/contrib/tracking/assignment.py @@ -12,7 +12,7 @@ def _product(factors): - result = 1. + result = 1.0 for factor in factors: result = result * factor return result @@ -52,6 +52,7 @@ class MarginalAssignment: final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ + def __init__(self, exists_logits, assign_logits, bp_iters=None): assert exists_logits.dim() == 1, exists_logits.shape assert assign_logits.dim() == 2, assign_logits.shape @@ -66,7 +67,9 @@ def __init__(self, exists_logits, assign_logits, bp_iters=None): if bp_iters is None: exists, assign = compute_marginals(exists_logits, assign_logits) else: - exists, assign = compute_marginals_bp(exists_logits, assign_logits, bp_iters) + exists, assign = compute_marginals_bp( + exists_logits, assign_logits, bp_iters + ) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. @@ -101,7 +104,10 @@ class MarginalAssignmentSparse: final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ - def __init__(self, num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters): + + def __init__( + self, num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters + ): assert edges.dim() == 2, edges.shape assert edges.shape[0] == 2, edges.shape assert exists_logits.shape == (num_objects,), exists_logits.shape @@ -116,12 +122,17 @@ def __init__(self, num_objects, num_detections, edges, exists_logits, assign_log # This does all the work. exists, assign = compute_marginals_sparse_bp( - num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters) + num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters + ) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. 
- padded_assign = torch.full((num_detections, num_objects + 1), -float('inf'), - dtype=assign.dtype, device=assign.device) + padded_assign = torch.full( + (num_detections, num_objects + 1), + -float("inf"), + dtype=assign.dtype, + device=assign.device, + ) padded_assign[:, -1] = 0 padded_assign[edges[0], edges[1]] = assign self.assign_dist = dist.Categorical(logits=padded_assign) @@ -165,6 +176,7 @@ class MarginalAssignmentPersistent: final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ + def __init__(self, exists_logits, assign_logits, bp_iters=None, bp_momentum=0.5): assert exists_logits.dim() == 1, exists_logits.shape assert assign_logits.dim() == 3, assign_logits.shape @@ -180,7 +192,8 @@ def __init__(self, exists_logits, assign_logits, bp_iters=None, bp_momentum=0.5) exists, assign = compute_marginals_persistent(exists_logits, assign_logits) else: exists, assign = compute_marginals_persistent_bp( - exists_logits, assign_logits, bp_iters, bp_momentum) + exists_logits, assign_logits, bp_iters, bp_momentum + ) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. @@ -203,11 +216,19 @@ def compute_marginals(exists_logits, assign_logits): dtype = exists_logits.dtype device = exists_logits.device - exists_probs = torch.zeros(2, num_objects, dtype=dtype, device=device) # [not exist, exist] - assign_probs = torch.zeros(num_detections, num_objects + 1, dtype=dtype, device=device) + exists_probs = torch.zeros( + 2, num_objects, dtype=dtype, device=device + ) # [not exist, exist] + assign_probs = torch.zeros( + num_detections, num_objects + 1, dtype=dtype, device=device + ) for assign in itertools.product(range(num_objects + 1), repeat=num_detections): - assign_part = sum(assign_logits[j, i] for j, i in enumerate(assign) if i < num_objects) - for exists in itertools.product(*[[1] if i in assign else [0, 1] for i in range(num_objects)]): + assign_part = sum( + assign_logits[j, i] for j, i in enumerate(assign) if i < num_objects + ) + for exists in itertools.product( + *[[1] if i in assign else [0, 1] for i in range(num_objects)] + ): exists_part = sum(exists_logits[i] for i, e in enumerate(exists) if e) prob = _exp(exists_part + assign_part) for i, e in enumerate(exists): @@ -220,8 +241,8 @@ def compute_marginals(exists_logits, assign_logits): assign = assign_probs.log() exists = exists[1] - exists[0] assign = assign[:, :-1] - assign[:, -1:] - warn_if_nan(exists, 'exists') - warn_if_nan(assign, 'assign') + warn_if_nan(exists, "exists") + warn_if_nan(assign, "assign") return exists, assign @@ -240,22 +261,29 @@ def compute_marginals_bp(exists_logits, assign_logits, bp_iters): message_e_to_a = torch.zeros_like(assign_logits) message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): - message_e_to_a = -(message_a_to_e - message_a_to_e.sum(0, True) - exists_logits).exp().log1p() + message_e_to_a = ( + -(message_a_to_e - message_a_to_e.sum(0, True) - exists_logits) + .exp() + .log1p() + ) joint = (assign_logits + message_e_to_a).exp() - message_a_to_e = (assign_logits - torch.log1p(joint.sum(1, True) - joint)).exp().log1p() - warn_if_nan(message_e_to_a, 'message_e_to_a iter {}'.format(i)) - warn_if_nan(message_a_to_e, 'message_a_to_e iter {}'.format(i)) + message_a_to_e = ( + (assign_logits - torch.log1p(joint.sum(1, True) - joint)).exp().log1p() + ) + warn_if_nan(message_e_to_a, "message_e_to_a iter {}".format(i)) + warn_if_nan(message_a_to_e, "message_a_to_e iter {}".format(i)) # 
Convert from probs to logits. exists = exists_logits + message_a_to_e.sum(0) assign = assign_logits + message_e_to_a - warn_if_nan(exists, 'exists') - warn_if_nan(assign, 'assign') + warn_if_nan(exists, "exists") + warn_if_nan(assign, "assign") return exists, assign -def compute_marginals_sparse_bp(num_objects, num_detections, edges, - exists_logits, assign_logits, bp_iters): +def compute_marginals_sparse_bp( + num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters +): """ This implements approximate inference of pairwise marginals via loopy belief propagation, adapting the approach of [1]. @@ -271,8 +299,9 @@ def compute_marginals_sparse_bp(num_objects, num_detections, edges, def sparse_sum(x, dim, keepdim=False): assert dim in (0, 1) - x = (torch.zeros([num_objects, num_detections][dim], dtype=x.dtype, device=x.device) - .scatter_add_(0, edges[1 - dim], x)) + x = torch.zeros( + [num_objects, num_detections][dim], dtype=x.dtype, device=x.device + ).scatter_add_(0, edges[1 - dim], x) if keepdim: x = x[edges[1 - dim]] return x @@ -280,17 +309,25 @@ def sparse_sum(x, dim, keepdim=False): message_e_to_a = torch.zeros_like(assign_logits) message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): - message_e_to_a = -(message_a_to_e - sparse_sum(message_a_to_e, 0, True) - exists_factor).exp().log1p() + message_e_to_a = ( + -(message_a_to_e - sparse_sum(message_a_to_e, 0, True) - exists_factor) + .exp() + .log1p() + ) joint = (assign_logits + message_e_to_a).exp() - message_a_to_e = (assign_logits - torch.log1p(sparse_sum(joint, 1, True) - joint)).exp().log1p() - warn_if_nan(message_e_to_a, 'message_e_to_a iter {}'.format(i)) - warn_if_nan(message_a_to_e, 'message_a_to_e iter {}'.format(i)) + message_a_to_e = ( + (assign_logits - torch.log1p(sparse_sum(joint, 1, True) - joint)) + .exp() + .log1p() + ) + warn_if_nan(message_e_to_a, "message_e_to_a iter {}".format(i)) + warn_if_nan(message_a_to_e, "message_a_to_e iter {}".format(i)) # Convert from probs to logits. 
exists = exists_logits + sparse_sum(message_a_to_e, 0) assign = assign_logits + message_e_to_a - warn_if_nan(exists, 'exists') - warn_if_nan(assign, 'assign') + warn_if_nan(exists, "exists") + warn_if_nan(assign, "assign") return exists, assign @@ -308,7 +345,9 @@ def compute_marginals_persistent(exists_logits, assign_logits): total = 0 exists_probs = torch.zeros(num_objects, dtype=dtype, device=device) - assign_probs = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) + assign_probs = torch.zeros( + num_frames, num_detections, num_objects, dtype=dtype, device=device + ) for exists in itertools.product([0, 1], repeat=num_objects): exists = [i for i, e in enumerate(exists) if e] exists_part = _exp(sum(exists_logits[i] for i in exists)) @@ -322,7 +361,9 @@ def compute_marginals_persistent(exists_logits, assign_logits): for objects in itertools.combinations(exists, n): for detections in itertools.permutations(range(num_detections), n): assign = tuple(zip(objects, detections)) - assign_map[assign] = _exp(sum(assign_logits[t, j, i] for i, j in assign)) + assign_map[assign] = _exp( + sum(assign_logits[t, j, i] for i, j in assign) + ) assign_parts.append(assign_map) assign_sums.append(sum(assign_map.values())) @@ -331,7 +372,7 @@ def compute_marginals_persistent(exists_logits, assign_logits): for i in exists: exists_probs[i] += prob for t in range(num_frames): - other_part = exists_part * _product(assign_sums[:t] + assign_sums[t + 1:]) + other_part = exists_part * _product(assign_sums[:t] + assign_sums[t + 1 :]) for assign, assign_part in assign_parts[t].items(): prob = other_part * assign_part for i, j in assign: @@ -340,12 +381,14 @@ def compute_marginals_persistent(exists_logits, assign_logits): # Convert from probs to logits. exists = exists_probs.log() - (total - exists_probs).log() assign = assign_probs.log() - (total - assign_probs.sum(-1, True)).log() - warn_if_nan(exists, 'exists') - warn_if_nan(assign, 'assign') + warn_if_nan(exists, "exists") + warn_if_nan(assign, "assign") return exists, assign -def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters, bp_momentum=0.5): +def compute_marginals_persistent_bp( + exists_logits, assign_logits, bp_iters, bp_momentum=0.5 +): """ This implements approximate inference of pairwise marginals via loopy belief propagation, adapting the approach of [1], [2]. 
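For reference, a minimal usage sketch (not part of this patch) of the marginal-assignment helpers reformatted in this file; the exists_dist/assign_dist attribute names follow the Categorical wrapping shown in the sparse variant above and are otherwise an assumption, and all numbers are placeholders:

import torch
from pyro.contrib.tracking.assignment import MarginalAssignment

num_objects, num_detections = 3, 4
exists_logits = torch.randn(num_objects)                  # one logit per object
assign_logits = torch.randn(num_detections, num_objects)  # detection-to-object logits

# bp_iters=None enumerates exact marginals; an integer runs loopy belief propagation.
exact = MarginalAssignment(exists_logits, assign_logits, bp_iters=None)
approx = MarginalAssignment(exists_logits, assign_logits, bp_iters=10)

print(approx.exists_dist.probs)  # P(object exists), shape (num_objects,)
print(approx.assign_dist.probs)  # last column is the spurious-detection option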
@@ -372,31 +415,44 @@ def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters, bp_m num_frames, num_detections, num_objects = assign_logits.shape dtype = assign_logits.dtype device = assign_logits.device - message_b_to_a = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) - message_a_to_b = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) + message_b_to_a = torch.zeros( + num_frames, num_detections, num_objects, dtype=dtype, device=device + ) + message_a_to_b = torch.zeros( + num_frames, num_detections, num_objects, dtype=dtype, device=device + ) message_b_to_e = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) message_e_to_b = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) for i in range(bp_iters): odds_a = (assign_logits + message_b_to_a).exp() - message_a_to_b = (old * message_a_to_b + - new * (assign_logits - (odds_a.sum(2, True) - odds_a).log1p())) - message_b_to_e = (old * message_b_to_e + - new * message_a_to_b.exp().sum(1).log1p()) - message_e_to_b = (old * message_e_to_b + - new * (exists_logits + message_b_to_e.sum(0) - message_b_to_e)) + message_a_to_b = old * message_a_to_b + new * ( + assign_logits - (odds_a.sum(2, True) - odds_a).log1p() + ) + message_b_to_e = ( + old * message_b_to_e + new * message_a_to_b.exp().sum(1).log1p() + ) + message_e_to_b = old * message_e_to_b + new * ( + exists_logits + message_b_to_e.sum(0) - message_b_to_e + ) odds_b = message_a_to_b.exp() - message_b_to_a = (old * message_b_to_a - - new * ((-message_e_to_b).exp().unsqueeze(1) + (1 + odds_b.sum(1, True) - odds_b)).log()) - - warn_if_nan(message_a_to_b, 'message_a_to_b iter {}'.format(i)) - warn_if_nan(message_b_to_e, 'message_b_to_e iter {}'.format(i)) - warn_if_nan(message_e_to_b, 'message_e_to_b iter {}'.format(i)) - warn_if_nan(message_b_to_a, 'message_b_to_a iter {}'.format(i)) + message_b_to_a = ( + old * message_b_to_a + - new + * ( + (-message_e_to_b).exp().unsqueeze(1) + + (1 + odds_b.sum(1, True) - odds_b) + ).log() + ) + + warn_if_nan(message_a_to_b, "message_a_to_b iter {}".format(i)) + warn_if_nan(message_b_to_e, "message_b_to_e iter {}".format(i)) + warn_if_nan(message_e_to_b, "message_e_to_b iter {}".format(i)) + warn_if_nan(message_b_to_a, "message_b_to_a iter {}".format(i)) # Convert from probs to logits. 
exists = exists_logits + message_b_to_e.sum(0) assign = assign_logits + message_b_to_a - warn_if_nan(exists, 'exists') - warn_if_nan(assign, 'assign') + warn_if_nan(exists, "exists") + warn_if_nan(assign, "assign") return exists, assign diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index fc3e61c6c1..ab2c1a8230 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -27,24 +27,37 @@ class EKFDistribution(TorchDistribution): :param dt: time step :type dt: torch.Tensor """ - arg_constraints = {'measurement_cov': constraints.positive_definite, - 'P0': constraints.positive_definite, - 'x0': constraints.real_vector} + arg_constraints = { + "measurement_cov": constraints.positive_definite, + "P0": constraints.positive_definite, + "x0": constraints.real_vector, + } has_rsample = True - def __init__(self, x0, P0, dynamic_model, measurement_cov, time_steps=1, dt=1., validate_args=None): + def __init__( + self, + x0, + P0, + dynamic_model, + measurement_cov, + time_steps=1, + dt=1.0, + validate_args=None, + ): self.x0 = x0 self.P0 = P0 self.dynamic_model = dynamic_model self.measurement_cov = measurement_cov self.dt = dt - assert not x0.shape[-1] % 2, 'position and velocity vectors must be the same dimension' + assert ( + not x0.shape[-1] % 2 + ), "position and velocity vectors must be the same dimension" batch_shape = x0.shape[:-1] event_shape = (time_steps, x0.shape[-1] // 2) super().__init__(batch_shape, event_shape, validate_args=validate_args) def rsample(self, sample_shape=torch.Size()): - raise NotImplementedError('TODO: implement forward filter backward sample') + raise NotImplementedError("TODO: implement forward filter backward sample") def filter_states(self, value): """ @@ -54,13 +67,14 @@ def filter_states(self, value): :type value: torch.Tensor """ states = [] - state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.) + state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.0) assert value.shape[-1] == self.event_shape[-1] for i, measurement_mean in enumerate(value): if i: state = state.predict(self.dt) - measurement = PositionMeasurement(measurement_mean, self.measurement_cov, - time=state.time) + measurement = PositionMeasurement( + measurement_mean, self.measurement_cov, time=state.time + ) state, (dz, S) = state.update(measurement) states.append(state) return states @@ -72,15 +86,16 @@ def log_prob(self, value): :param value: measurement means of shape `(time_steps, event_shape)` :type value: torch.Tensor """ - state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.) - result = 0. 
+ state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.0) + result = 0.0 assert value.shape == self.event_shape zero = torch.zeros(self.event_shape[-1], dtype=value.dtype, device=value.device) for i, measurement_mean in enumerate(value): if i: state = state.predict(self.dt) - measurement = PositionMeasurement(measurement_mean, self.measurement_cov, - time=state.time) + measurement = PositionMeasurement( + measurement_mean, self.measurement_cov, time=state.time + ) state, (dz, S) = state.update(measurement) result = result + dist.MultivariateNormal(dz, S).log_prob(zero) return result diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py index 7ea41ad3a7..3b212d8a3a 100644 --- a/pyro/contrib/tracking/dynamic_models.py +++ b/pyro/contrib/tracking/dynamic_models.py @@ -12,7 +12,7 @@ class DynamicModel(nn.Module, metaclass=ABCMeta): - ''' + """ Dynamic model interface. :param dimension: native state dimension. @@ -20,7 +20,8 @@ class DynamicModel(nn.Module, metaclass=ABCMeta): :param num_process_noise_parameters: process noise parameter space dimension. This for UKF applications. Can be left as ``None`` for EKF and most other filters. - ''' + """ + def __init__(self, dimension, dimension_pv, num_process_noise_parameters=None): self._dimension = dimension self._dimension_pv = dimension_pv @@ -29,28 +30,28 @@ def __init__(self, dimension, dimension_pv, num_process_noise_parameters=None): @property def dimension(self): - ''' + """ Native state dimension access. - ''' + """ return self._dimension @property def dimension_pv(self): - ''' + """ PV state dimension access. - ''' + """ return self._dimension_pv @property def num_process_noise_parameters(self): - ''' + """ Process noise parameters space dimension access. - ''' + """ return self._num_process_noise_parameters @abstractmethod def forward(self, x, dt, do_normalization=True): - ''' + """ Integrate native state ``x`` over time interval ``dt``. :param x: current native state. If the DynamicModel is non-differentiable, @@ -60,47 +61,47 @@ def forward(self, x, dt, do_normalization=True): :param do_normalization: whether to perform normalization on output, e.g., mod'ing angles into an interval. :return: Native state x integrated dt into the future. - ''' + """ raise NotImplementedError def geodesic_difference(self, x1, x0): - ''' + """ Compute and return the geodesic difference between 2 native states. This is a generalization of the Euclidean operation ``x1 - x0``. :param x1: native state. :param x0: native state. :return: Geodesic difference between native states ``x1`` and ``x2``. - ''' + """ return x1 - x0 # Default to Euclidean behavior. @abstractmethod def mean2pv(self, x): - ''' + """ Compute and return PV state from native state. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param x: native state estimate mean. :return: PV state estimate mean. - ''' + """ raise NotImplementedError @abstractmethod def cov2pv(self, P): - ''' + """ Compute and return PV covariance from native covariance. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param P: native state estimate covariance. :return: PV state estimate covariance. - ''' + """ raise NotImplementedError @abstractmethod - def process_noise_cov(self, dt=0.): - ''' + def process_noise_cov(self, dt=0.0): + """ Compute and return process noise covariance (Q). :param dt: time interval to integrate over. 
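For reference, a minimal sketch (not part of this patch) of scoring a measurement sequence with the EKFDistribution defined above, using the NCV dynamic model from dynamic_models.py; the concrete parameter values are placeholders:

import torch
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.distributions import EKFDistribution

dim, time_steps = 4, 5                        # 2-D position + 2-D velocity state
model = NcvContinuous(dimension=dim, sa2=2.0)
x0 = torch.zeros(dim)                         # initial state mean
P0 = torch.eye(dim)                           # initial state covariance
measurement_cov = 0.1 * torch.eye(dim // 2)   # position-measurement noise

d = EKFDistribution(x0, P0, model, measurement_cov, time_steps=time_steps, dt=1.0)
z = torch.randn(time_steps, dim // 2)         # event_shape == (time_steps, dim // 2)
print(d.log_prob(z))                          # innovation-based log likelihood
states = d.filter_states(z)                   # list of EKFState, one per measurement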
@@ -108,40 +109,43 @@ def process_noise_cov(self, dt=0.): the covariance of the native state ``x`` resulting from stochastic integration (for use with EKF). Otherwise, it is the covariance directly of the process noise parameters (for use with UKF). - ''' + """ raise NotImplementedError - def process_noise_dist(self, dt=0.): - ''' + def process_noise_dist(self, dt=0.0): + """ Return a distribution object of state displacement from the process noise distribution over a time interval. :param dt: time interval that process noise accumulates over. :return: :class:`~pyro.distributions.torch.MultivariateNormal`. - ''' + """ Q = self.process_noise_cov(dt) - return dist.MultivariateNormal(torch.zeros(Q.shape[-1], dtype=Q.dtype, device=Q.device), Q) + return dist.MultivariateNormal( + torch.zeros(Q.shape[-1], dtype=Q.dtype, device=Q.device), Q + ) class DifferentiableDynamicModel(DynamicModel): - ''' + """ DynamicModel for which state transition Jacobians can be efficiently calculated, usu. analytically or by automatic differentiation. - ''' + """ + @abstractmethod def jacobian(self, dt): - ''' + """ Compute and return native state transition Jacobian (F) over time interval ``dt``. :param dt: time interval to integrate over. :return: Read-only Jacobian (F) of integration map (f). - ''' + """ raise NotImplementedError class Ncp(DifferentiableDynamicModel): - ''' + """ NCP (Nearly-Constant Position) dynamic model. May be subclassed, e.g., with CWNV (Continuous White Noise Velocity) or DWNV (Discrete White Noise Velocity). @@ -150,7 +154,8 @@ class Ncp(DifferentiableDynamicModel): :param sv2: variance of velocity. Usually chosen so that the standard deviation is roughly half of the max velocity one would ever expect to observe. - ''' + """ + def __init__(self, dimension, sv2): dimension_pv = 2 * dimension super().__init__(dimension, dimension_pv, num_process_noise_parameters=1) @@ -161,7 +166,7 @@ def __init__(self, dimension, sv2): self._Q_cache = {} # Process noise cov cache def forward(self, x, dt, do_normalization=True): - ''' + """ Integrate native state ``x`` over time interval ``dt``. :param x: current native state. If the DynamicModel is non-differentiable, @@ -171,62 +176,62 @@ def forward(self, x, dt, do_normalization=True): do_normalization: whether to perform normalization on output, e.g., mod'ing angles into an interval. Has no effect for this subclass. :return: Native state x integrated dt into the future. - ''' + """ return x def mean2pv(self, x): - ''' + """ Compute and return PV state from native state. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param x: native state estimate mean. :return: PV state estimate mean. - ''' + """ with torch.no_grad(): x_pv = torch.zeros(2 * self._dimension, dtype=x.dtype, device=x.device) - x_pv[:self._dimension] = x + x_pv[: self._dimension] = x return x_pv def cov2pv(self, P): - ''' + """ Compute and return PV covariance from native covariance. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param P: native state estimate covariance. :return: PV state estimate covariance. 
- ''' - d = 2*self._dimension + """ + d = 2 * self._dimension with torch.no_grad(): P_pv = torch.zeros(d, d, dtype=P.dtype, device=P.device) - P_pv[:self._dimension, :self._dimension] = P + P_pv[: self._dimension, : self._dimension] = P return P_pv def jacobian(self, dt): - ''' + """ Compute and return cached native state transition Jacobian (F) over time interval ``dt``. :param dt: time interval to integrate over. :return: Read-only Jacobian (F) of integration map (f). - ''' + """ return self._F_cache @abstractmethod - def process_noise_cov(self, dt=0.): - ''' + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. :return: Read-only covariance (Q) of the native state ``x`` resulting from stochastic integration (for use with EKF). - ''' + """ raise NotImplementedError class Ncv(DifferentiableDynamicModel): - ''' + """ NCV (Nearly-Constant Velocity) dynamic model. May be subclassed, e.g., with CWNA (Continuous White Noise Acceleration) or DWNA (Discrete White Noise Acceleration). @@ -235,7 +240,8 @@ class Ncv(DifferentiableDynamicModel): :param sa2: variance of acceleration. Usually chosen so that the standard deviation is roughly half of the max acceleration one would ever expect to observe. - ''' + """ + def __init__(self, dimension, sa2): dimension_pv = dimension super().__init__(dimension, dimension_pv, num_process_noise_parameters=1) @@ -246,7 +252,7 @@ def __init__(self, dimension, sa2): self._Q_cache = {} # Process noise cov cache def forward(self, x, dt, do_normalization=True): - ''' + """ Integrate native state ``x`` over time interval ``dt``. :param x: current native state. If the DynamicModel is non-differentiable, @@ -257,63 +263,63 @@ def forward(self, x, dt, do_normalization=True): mod'ing angles into an interval. Has no effect for this subclass. :return: Native state x integrated dt into the future. - ''' + """ F = self.jacobian(dt) return F.mm(x.unsqueeze(1)).squeeze(1) def mean2pv(self, x): - ''' + """ Compute and return PV state from native state. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param x: native state estimate mean. :return: PV state estimate mean. - ''' + """ return x def cov2pv(self, P): - ''' + """ Compute and return PV covariance from native covariance. Useful for combining state estimates of different types in IMM (Interacting Multiple Model) filtering. :param P: native state estimate covariance. :return: PV state estimate covariance. - ''' + """ return P def jacobian(self, dt): - ''' + """ Compute and return cached native state transition Jacobian (F) over time interval ``dt``. :param dt: time interval to integrate over. :return: Read-only Jacobian (F) of integration map (f). - ''' + """ if dt not in self._F_cache: d = self._dimension with torch.no_grad(): F = eye_like(self.sa2, d) - F[:d//2, d//2:] = dt * eye_like(self.sa2, d//2) + F[: d // 2, d // 2 :] = dt * eye_like(self.sa2, d // 2) self._F_cache[dt] = F return self._F_cache[dt] @abstractmethod - def process_noise_cov(self, dt=0.): - ''' + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. :return: Read-only covariance (Q) of the native state ``x`` resulting from stochastic integration (for use with EKF). 
- ''' + """ raise NotImplementedError class NcpContinuous(Ncp): - ''' + """ NCP (Nearly-Constant Position) dynamic model with CWNV (Continuous White Noise Velocity). @@ -325,15 +331,16 @@ class NcpContinuous(Ncp): :param sv2: variance of velocity. Usually chosen so that the standard deviation is roughly half of the max velocity one would ever expect to observe. - ''' - def process_noise_cov(self, dt=0.): - ''' + """ + + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. :return: Read-only covariance (Q) of the native state ``x`` resulting from stochastic integration (for use with EKF). - ''' + """ if dt not in self._Q_cache: # q: continuous-time process noise intensity with units # length^2/time (m^2/s). Choose ``q`` so that changes in position, @@ -346,7 +353,7 @@ def process_noise_cov(self, dt=0.): class NcvContinuous(Ncv): - ''' + """ NCV (Nearly-Constant Velocity) dynamic model with CWNA (Continuous White Noise Acceleration). @@ -358,16 +365,17 @@ class NcvContinuous(Ncv): :param sa2: variance of acceleration. Usually chosen so that the standard deviation is roughly half of the max acceleration one would ever expect to observe. - ''' - def process_noise_cov(self, dt=0.): - ''' + """ + + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. :return: Read-only covariance (Q) of the native state ``x`` resulting from stochastic integration (for use with EKF). - ''' + """ if dt not in self._Q_cache: with torch.no_grad(): @@ -375,11 +383,11 @@ def process_noise_cov(self, dt=0.): dt2 = dt * dt dt3 = dt2 * dt Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device) - eye = eye_like(self.sa2, d//2) - Q[:d//2, :d//2] = dt3 * eye / 3.0 - Q[:d//2, d//2:] = dt2 * eye / 2.0 - Q[d//2:, :d//2] = dt2 * eye / 2.0 - Q[d//2:, d//2:] = dt * eye + eye = eye_like(self.sa2, d // 2) + Q[: d // 2, : d // 2] = dt3 * eye / 3.0 + Q[: d // 2, d // 2 :] = dt2 * eye / 2.0 + Q[d // 2 :, : d // 2] = dt2 * eye / 2.0 + Q[d // 2 :, d // 2 :] = dt * eye # sa2 * dt is an intensity factor that changes in velocity # over a sampling period ``dt``, ideally should be ~``sqrt(q*dt)``. Q = Q * (self.sa2 * dt) @@ -389,7 +397,7 @@ def process_noise_cov(self, dt=0.): class NcpDiscrete(Ncp): - ''' + """ NCP (Nearly-Constant Position) dynamic model with DWNV (Discrete White Noise Velocity). @@ -401,15 +409,16 @@ class NcpDiscrete(Ncp): References: "Estimation with Applications to Tracking and Navigation" by Y. Bar- Shalom et al, 2001, p.273. - ''' - def process_noise_cov(self, dt=0.): - ''' + """ + + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. :return: Read-only covariance (Q) of the native state `x` resulting from stochastic integration (for use with EKF). - ''' + """ if dt not in self._Q_cache: Q = self.sv2 * dt * dt * eye_like(self.sv2, self._dimension) self._Q_cache[dt] = Q @@ -418,7 +427,7 @@ def process_noise_cov(self, dt=0.): class NcvDiscrete(Ncv): - ''' + """ NCV (Nearly-Constant Velocity) dynamic model with DWNA (Discrete White Noise Acceleration). @@ -430,9 +439,10 @@ class NcvDiscrete(Ncv): References: "Estimation with Applications to Tracking and Navigation" by Y. Bar- Shalom et al, 2001, p.273. 
- ''' - def process_noise_cov(self, dt=0.): - ''' + """ + + def process_noise_cov(self, dt=0.0): + """ Compute and return cached process noise covariance (Q). :param dt: time interval to integrate over. @@ -440,18 +450,18 @@ def process_noise_cov(self, dt=0.): stochastic integration (for use with EKF). (Note that this Q, modulo numerical error, has rank `dimension/2`. So, it is only positive semi-definite.) - ''' + """ if dt not in self._Q_cache: with torch.no_grad(): d = self._dimension - dt2 = dt*dt - dt3 = dt2*dt - dt4 = dt2*dt2 + dt2 = dt * dt + dt3 = dt2 * dt + dt4 = dt2 * dt2 Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device) - Q[:d//2, :d//2] = 0.25 * dt4 * eye_like(self.sa2, d//2) - Q[:d//2, d//2:] = 0.5 * dt3 * eye_like(self.sa2, d//2) - Q[d//2:, :d//2] = 0.5 * dt3 * eye_like(self.sa2, d//2) - Q[d//2:, d//2:] = dt2 * eye_like(self.sa2, d//2) + Q[: d // 2, : d // 2] = 0.25 * dt4 * eye_like(self.sa2, d // 2) + Q[: d // 2, d // 2 :] = 0.5 * dt3 * eye_like(self.sa2, d // 2) + Q[d // 2 :, : d // 2] = 0.5 * dt3 * eye_like(self.sa2, d // 2) + Q[d // 2 :, d // 2 :] = dt2 * eye_like(self.sa2, d // 2) Q = Q * self.sa2 self._Q_cache[dt] = Q diff --git a/pyro/contrib/tracking/extended_kalman_filter.py b/pyro/contrib/tracking/extended_kalman_filter.py index e8d2a431cf..68cbae0b66 100644 --- a/pyro/contrib/tracking/extended_kalman_filter.py +++ b/pyro/contrib/tracking/extended_kalman_filter.py @@ -9,7 +9,7 @@ class EKFState: - ''' + """ State-Centric EKF (Extended Kalman Filter) for use with either an NCP (Nearly-Constant Position) or NCV (Nearly-Constant Velocity) target dynamic model. Stores a target dynamic model, state estimate, and state time. @@ -22,81 +22,82 @@ class EKFState: :param mean: mean of target state estimate. :param cov: covariance of target state estimate. :param time: time of state estimate. - ''' + """ + def __init__(self, dynamic_model, mean, cov, time=None, frame_num=None): self._dynamic_model = dynamic_model self._mean = mean self._cov = cov if time is None and frame_num is None: - raise ValueError('Must provide time or frame_num!') + raise ValueError("Must provide time or frame_num!") self._time = time self._frame_num = frame_num @property def dynamic_model(self): - ''' + """ Dynamic model access. - ''' + """ return self._dynamic_model @property def dimension(self): - ''' + """ Native state dimension access. - ''' + """ return self._dynamic_model.dimension @property def mean(self): - ''' + """ Native state estimate mean access. - ''' + """ return self._mean @property def cov(self): - ''' + """ Native state estimate covariance access. - ''' + """ return self._cov @property def dimension_pv(self): - ''' + """ PV state dimension access. - ''' + """ return self._dynamic_model.dimension_pv @lazy_property def mean_pv(self): - ''' + """ Compute and return cached PV state estimate mean. - ''' + """ return self._dynamic_model.mean2pv(self._mean) @lazy_property def cov_pv(self): - ''' + """ Compute and return cached PV state estimate covariance. - ''' + """ return self._dynamic_model.cov2pv(self._cov) @property def time(self): - ''' + """ Continuous State time access. - ''' + """ return self._time @property def frame_num(self): - ''' + """ Discrete State time access. - ''' + """ return self._frame_num def predict(self, dt=None, destination_time=None, destination_frame_num=None): - ''' + """ Use dynamic model to predict (aka propagate aka integrate) state estimate in-place. 
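For reference, one predict/update cycle with EKFState and PositionMeasurement, mirroring the call pattern used by EKFDistribution.filter_states earlier in this patch (a sketch only; the dynamic model choice and numeric values are placeholders):

import torch
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.extended_kalman_filter import EKFState
from pyro.contrib.tracking.measurements import PositionMeasurement

model = NcvContinuous(dimension=4, sa2=1.0)
state = EKFState(model, mean=torch.zeros(4), cov=torch.eye(4), time=0.0)

state = state.predict(dt=1.0)                       # propagate the estimate forward
z = torch.tensor([0.3, -0.1])                       # observed position
measurement = PositionMeasurement(z, 0.1 * torch.eye(2), time=state.time)
state, (dz, S) = state.update(measurement)          # new state plus innovation (dz, S)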
@@ -110,7 +111,7 @@ def predict(self, dt=None, destination_time=None, destination_frame_num=None): :param destination_frame_num: optional value to set discrete state time to after integration. If this is not provided, then `destination_frame_num` must be. - ''' + """ assert (dt is None) ^ (destination_time is None) if dt is None: dt = destination_time - self._time @@ -123,13 +124,20 @@ def predict(self, dt=None, destination_time=None, destination_frame_num=None): pred_cov = F.mm(self._cov).mm(F.transpose(-1, -2)) + Q if destination_time is None and destination_frame_num is None: - raise ValueError('destination_time or destination_frame_num must be specified!') - - return EKFState(self._dynamic_model, pred_mean, pred_cov, - destination_time, destination_frame_num) + raise ValueError( + "destination_time or destination_frame_num must be specified!" + ) + + return EKFState( + self._dynamic_model, + pred_mean, + pred_cov, + destination_time, + destination_frame_num, + ) def innovation(self, measurement): - ''' + """ Compute and return the innovation that a measurement would induce if it were used for an update, but don't actually perform the update. Assumes state and measurement are time-aligned. Useful for computing @@ -138,13 +146,14 @@ def innovation(self, measurement): :param measurement: measurement :return: Innovation mean and covariance of hypothetical update. :rtype: tuple(``torch.Tensor``, ``torch.Tensor``) - ''' - assert self._time == measurement.time, \ - 'State time and measurement time must be aligned!' + """ + assert ( + self._time == measurement.time + ), "State time and measurement time must be aligned!" # Compute innovation. x_pv = self._dynamic_model.mean2pv(self._mean) - H = measurement.jacobian(x_pv)[:, :self.dimension] + H = measurement.jacobian(x_pv)[:, : self.dimension] R = measurement.cov z = measurement.mean z_predicted = measurement(x_pv) @@ -154,7 +163,7 @@ def innovation(self, measurement): return dz, S def log_likelihood_of_update(self, measurement): - ''' + """ Compute and return the likelihood of a potential update, but don't actually perform the update. Assumes state and measurement are time- aligned. Useful for gating and calculating costs in assignment problems @@ -162,13 +171,14 @@ def log_likelihood_of_update(self, measurement): :param: measurement. :return: Likelihood of hypothetical update. - ''' + """ dz, S = self.innovation(measurement) - return dist.MultivariateNormal(torch.zeros(S.size(-1), dtype=S.dtype, device=S.device), - S).log_prob(dz) + return dist.MultivariateNormal( + torch.zeros(S.size(-1), dtype=S.dtype, device=S.device), S + ).log_prob(dz) def update(self, measurement): - ''' + """ Use measurement to update state estimate in-place and return innovation. The innovation is useful, e.g., for evaluating filter consistency or updating model likelihoods when the ``EKFState`` is part @@ -176,18 +186,20 @@ def update(self, measurement): :param: measurement. :returns: EKF State, Innovation mean and covariance. - ''' + """ if self._time is not None: - assert self._time == measurement.time, \ - 'State time and measurement time must be aligned!' + assert ( + self._time == measurement.time + ), "State time and measurement time must be aligned!" if self._frame_num is not None: - assert self._frame_num == measurement.frame_num, \ - 'State time and measurement time must be aligned!' + assert ( + self._frame_num == measurement.frame_num + ), "State time and measurement time must be aligned!" 
x = self._mean x_pv = self._dynamic_model.mean2pv(x) P = self.cov - H = measurement.jacobian(x_pv)[:, :self.dimension] + H = measurement.jacobian(x_pv)[:, : self.dimension] R = measurement.cov z = measurement.mean z_predicted = measurement(x_pv) @@ -202,11 +214,14 @@ def update(self, measurement): ImKH = I - K_prefix.mm(torch.linalg.solve(S, H)) # *Joseph form* of covariance update for numerical stability. S_inv_R = torch.linalg.solve(S, R) - P = ImKH.mm(self.cov).mm(ImKH.transpose(-1, -2)) \ - + K_prefix.mm(torch.linalg.solve(S, K_prefix.mm(S_inv_R).transpose(-1, -2))) + P = ImKH.mm(self.cov).mm(ImKH.transpose(-1, -2)) + K_prefix.mm( + torch.linalg.solve(S, K_prefix.mm(S_inv_R).transpose(-1, -2)) + ) pred_mean = x pred_cov = P - state = EKFState(self._dynamic_model, pred_mean, pred_cov, self._time, self._frame_num) + state = EKFState( + self._dynamic_model, pred_mean, pred_cov, self._time, self._frame_num + ) return state, (dz, S) diff --git a/pyro/contrib/tracking/hashing.py b/pyro/contrib/tracking/hashing.py index 1d650b7fe0..e0b3130112 100644 --- a/pyro/contrib/tracking/hashing.py +++ b/pyro/contrib/tracking/hashing.py @@ -44,9 +44,12 @@ class LSH: :param float radius: Scaling parameter used in hash function. Determines the size of the neighbourhood. """ + def __init__(self, radius): if not (isinstance(radius, Number) and radius > 0): - raise ValueError("radius must be float greater than 0, given: {}".format(radius)) + raise ValueError( + "radius must be float greater than 0, given: {}".format(radius) + ) self._radius = radius self._hash_to_key = defaultdict(set) self._key_to_hash = {} @@ -112,9 +115,12 @@ class ApproxSet: :param float radius: scaling parameter used in hash function. Determines the size of the bin. See :class:`LSH` for details. """ + def __init__(self, radius): if not (isinstance(radius, Number) and radius > 0): - raise ValueError("radius must be float greater than 0, given: {}".format(radius)) + raise ValueError( + "radius must be float greater than 0, given: {}".format(radius) + ) self._radius = radius self._bins = set() @@ -156,10 +162,16 @@ def merge_points(points, radius): :rtype: tuple """ if points.dim() != 2: - raise ValueError('Expected points.shape == (K,D), but got {}'.format(points.shape)) + raise ValueError( + "Expected points.shape == (K,D), but got {}".format(points.shape) + ) if not (isinstance(radius, Number) and radius > 0): - raise ValueError('Expected radius to be a positive number, but got {}'.format(radius)) - radius = 0.99 * radius # avoid merging points exactly radius apart, e.g. grid points + raise ValueError( + "Expected radius to be a positive number, but got {}".format(radius) + ) + radius = ( + 0.99 * radius + ) # avoid merging points exactly radius apart, e.g. grid points threshold = radius ** 2 # setup data structures to cheaply search for nearest pairs diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py index 8f24ea4360..0df914241d 100644 --- a/pyro/contrib/tracking/measurements.py +++ b/pyro/contrib/tracking/measurements.py @@ -9,7 +9,7 @@ class Measurement(object, metaclass=ABCMeta): - ''' + """ Gaussian measurement interface. :param mean: mean of measurement distribution. @@ -18,54 +18,55 @@ class Measurement(object, metaclass=ABCMeta): provided, `frame_num` must be. :param frame_num: discrete time of measurement. If this is not provided, `time` must be. 
- ''' + """ + def __init__(self, mean, cov, time=None, frame_num=None): self._dimension = len(mean) self._mean = mean self._cov = cov if time is None and frame_num is None: - raise ValueError('Must provide time or frame_num!') + raise ValueError("Must provide time or frame_num!") self._time = time self._frame_num = frame_num @property def dimension(self): - ''' + """ Measurement space dimension access. - ''' + """ return self._dimension @property def mean(self): - ''' + """ Measurement mean (``z`` in most Kalman Filtering literature). - ''' + """ return self._mean @property def cov(self): - ''' + """ Noise covariance (``R`` in most Kalman Filtering literature). - ''' + """ return self._cov @property def time(self): - ''' + """ Continuous time of measurement. - ''' + """ return self._time @property def frame_num(self): - ''' + """ Discrete time of measurement. - ''' + """ return self._frame_num @abstractmethod def __call__(self, x, do_normalization=True): - ''' + """ Measurement map (h) for predicting a measurement ``z`` from target state ``x``. @@ -73,55 +74,63 @@ def __call__(self, x, do_normalization=True): :param do_normalization: whether to normalize output, e.g., mod'ing angles into an interval. :return Measurement predicted from state ``x``. - ''' + """ raise NotImplementedError def geodesic_difference(self, z1, z0): - ''' + """ Compute and return the geodesic difference between 2 measurements. This is a generalization of the Euclidean operation ``z1 - z0``. :param z1: measurement. :param z0: measurement. :return: Geodesic difference between ``z1`` and ``z2``. - ''' + """ return z1 - z0 # Default to Euclidean behavior. class DifferentiableMeasurement(Measurement): - ''' + """ Interface for Gaussian measurement for which Jacobians can be efficiently calculated, usu. analytically or by automatic differentiation. - ''' + """ + @abstractmethod def jacobian(self, x=None): - ''' + """ Compute and return Jacobian (H) of measurement map (h) at target PV state ``x`` . :param x: PV state. Use default argument ``None`` when the Jacobian is not state-dependent. :return: Read-only Jacobian (H) of measurement map (h). - ''' + """ raise NotImplementedError class PositionMeasurement(DifferentiableMeasurement): - ''' + """ Full-rank Gaussian position measurement in Euclidean space. :param mean: mean of measurement distribution. :param cov: covariance of measurement distribution. :param time: time of measurement. - ''' + """ + def __init__(self, mean, cov, time=None, frame_num=None): super().__init__(mean, cov, time=time, frame_num=frame_num) - self._jacobian = torch.cat([ - eye_like(mean, self.dimension), - torch.zeros(self.dimension, self.dimension, dtype=mean.dtype, device=mean.device)], dim=1) + self._jacobian = torch.cat( + [ + eye_like(mean, self.dimension), + torch.zeros( + self.dimension, self.dimension, dtype=mean.dtype, device=mean.device + ), + ], + dim=1, + ) def __call__(self, x, do_normalization=True): - ''' + """ Measurement map (h) for predicting a measurement ``z`` from target state ``x``. @@ -129,16 +138,16 @@ def __call__(self, x, do_normalization=True): :param do_normalization: whether to normalize output. Has no effect for this subclass. :return: Measurement predicted from state ``x``. - ''' - return x[:self._dimension] + """ + return x[: self._dimension] def jacobian(self, x=None): - ''' + """ Compute and return Jacobian (H) of measurement map (h) at target PV state ``x`` . :param x: PV state. 
The default argument ``None`` may be used in this subclass since the Jacobian is not state-dependent. :return: Read-only Jacobian (H) of measurement map (h). - ''' + """ return self._jacobian diff --git a/pyro/contrib/util.py b/pyro/contrib/util.py index e250639ca7..df4f08f697 100644 --- a/pyro/contrib/util.py +++ b/pyro/contrib/util.py @@ -14,7 +14,7 @@ def get_indices(labels, sizes=None, tensors=None): if sizes is None: sizes = OrderedDict([(l, t.shape[0]) for l, t in tensors.items()]) for label in sizes: - end = start+sizes[label] + end = start + sizes[label] if label in labels: indices.extend(range(start, end)) start = end @@ -56,19 +56,19 @@ def lexpand(A, *dimensions): def rexpand(A, *dimensions): """Expand tensor, adding new dimensions on right.""" - return A.view(A.shape + (1,)*len(dimensions)).expand(A.shape + tuple(dimensions)) + return A.view(A.shape + (1,) * len(dimensions)).expand(A.shape + tuple(dimensions)) def rdiag(v): """Converts the rightmost dimension to a diagonal matrix.""" - return rexpand(v, v.shape[-1])*torch.eye(v.shape[-1]) + return rexpand(v, v.shape[-1]) * torch.eye(v.shape[-1]) def rtril(M, diagonal=0, upper=False): """Takes the lower-triangular of the rightmost 2 dimensions.""" if upper: return rtril(M, diagonal=diagonal, upper=False).transpose(-1, -2) - return M*torch.tril(torch.ones(M.shape[-2], M.shape[-1]), diagonal=diagonal) + return M * torch.tril(torch.ones(M.shape[-2], M.shape[-1]), diagonal=diagonal) def iter_plates_to_shape(shape): diff --git a/pyro/distributions/asymmetriclaplace.py b/pyro/distributions/asymmetriclaplace.py index f14218b3f1..1f0f405083 100644 --- a/pyro/distributions/asymmetriclaplace.py +++ b/pyro/distributions/asymmetriclaplace.py @@ -24,9 +24,12 @@ class AsymmetricLaplace(TorchDistribution): :param scale: Scale parameter = geometric mean of left and right scales. :param asymmetry: Square of ratio of left to right scales. """ - arg_constraints = {"loc": constraints.real, - "scale": constraints.positive, - "asymmetry": constraints.positive} + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "asymmetry": constraints.positive, + } support = constraints.real has_rsample = True @@ -105,16 +108,22 @@ class SoftAsymmetricLaplace(TorchDistribution): :param asymmetry: Square of ratio of left to right scales. Defaults to 1. :param softness: Scale parameter of the Gaussian smoother. Defaults to 1. 
""" - arg_constraints = {"loc": constraints.real, - "scale": constraints.positive, - "asymmetry": constraints.positive, - "softness": constraints.positive} + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "asymmetry": constraints.positive, + "softness": constraints.positive, + } support = constraints.real has_rsample = True def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validate_args=None): self.loc, self.scale, self.asymmetry, self.softness = broadcast_all( - loc, scale, asymmetry, softness, + loc, + scale, + asymmetry, + softness, ) super().__init__(self.loc.shape, validate_args=validate_args) @@ -160,17 +169,23 @@ def log_prob(self, value): # = 1/2 e^((2 L x + S^2)/(2 L^2)) erfc((L x + S^2)/(sqrt(2) L S)) # right = Integrate[e^(-t/R - ((x-t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}] # = 1/2 e^((S^2 - 2 R x)/(2 R^2)) erfc((S^2 - R x)/(sqrt(2) R S)) - return math.log(0.5) + torch.logaddexp( - (SS / 2 + Lx) / L ** 2 + _logerfc((SS + Lx) / (L * S2)), - (SS / 2 - Rx) / R ** 2 + _logerfc((SS - Rx) / (R * S2)), - ) - (L + R).log() - self.scale.log() + return ( + math.log(0.5) + + torch.logaddexp( + (SS / 2 + Lx) / L ** 2 + _logerfc((SS + Lx) / (L * S2)), + (SS / 2 - Rx) / R ** 2 + _logerfc((SS - Rx) / (R * S2)), + ) + - (L + R).log() + - self.scale.log() + ) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) z = self.loc.new_empty(shape).normal_() u, v = self.loc.new_empty((2,) + shape).exponential_() - return (self.loc + self.soft_scale * z - self.left_scale * u - + self.right_scale * v) + return ( + self.loc + self.soft_scale * z - self.left_scale * u + self.right_scale * v + ) @property def mean(self): @@ -184,8 +199,9 @@ def variance(self): total = left + right p = left / total q = right / total - return (p * left ** 2 + q * right ** 2 + p * q * total ** 2 - + self.soft_scale ** 2) + return ( + p * left ** 2 + q * right ** 2 + p * q * total ** 2 + self.soft_scale ** 2 + ) def _logerfc(x): diff --git a/pyro/distributions/avf_mvn.py b/pyro/distributions/avf_mvn.py index ba57fd784f..d5988f385b 100644 --- a/pyro/distributions/avf_mvn.py +++ b/pyro/distributions/avf_mvn.py @@ -38,21 +38,33 @@ class AVFMultivariateNormal(MultivariateNormal): opt_cv.zero_grad() """ - arg_constraints = {"loc": constraints.real, "scale_tril": constraints.lower_triangular, - "control_var": constraints.real} + + arg_constraints = { + "loc": constraints.real, + "scale_tril": constraints.lower_triangular, + "control_var": constraints.real, + } def __init__(self, loc, scale_tril, control_var): if loc.dim() != 1: raise ValueError("AVFMultivariateNormal loc must be 1-dimensional") if scale_tril.dim() != 2: raise ValueError("AVFMultivariateNormal scale_tril must be 2-dimensional") - if control_var.dim() != 3 or control_var.size(0) != 2 or control_var.size(2) != loc.size(0): - raise ValueError("control_var should be of size 2 x L x D, where D is the dimension of the location parameter loc") # noqa: E501 + if ( + control_var.dim() != 3 + or control_var.size(0) != 2 + or control_var.size(2) != loc.size(0) + ): + raise ValueError( + "control_var should be of size 2 x L x D, where D is the dimension of the location parameter loc" + ) # noqa: E501 self.control_var = control_var super().__init__(loc, scale_tril=scale_tril) def rsample(self, sample_shape=torch.Size()): - return _AVFMVNSample.apply(self.loc, self.scale_tril, self.control_var, sample_shape + self.loc.shape) + return _AVFMVNSample.apply( + self.loc, self.scale_tril, 
self.control_var, sample_shape + self.loc.shape + ) class _AVFMVNSample(Function): @@ -86,7 +98,9 @@ def backward(ctx, grad_output): # compute control_var grads diff_B = (L_grad.unsqueeze(0) * C.unsqueeze(-2) * xi_ab.unsqueeze(0)).sum(2) - diff_C = (L_grad.t().unsqueeze(0) * B.unsqueeze(-2) * xi_ab.t().unsqueeze(0)).sum(2) + diff_C = ( + L_grad.t().unsqueeze(0) * B.unsqueeze(-2) * xi_ab.t().unsqueeze(0) + ).sum(2) diff_CV = torch.stack([diff_B, diff_C]) return loc_grad, L_grad, diff_CV, None diff --git a/pyro/distributions/coalescent.py b/pyro/distributions/coalescent.py index 87b2e362e7..817bc6486e 100644 --- a/pyro/distributions/coalescent.py +++ b/pyro/distributions/coalescent.py @@ -59,12 +59,11 @@ class CoalescentTimes(TorchDistribution): :param torch.Tensor rate: Base coalescent rate (pairwise rate of coalescence) under a constant population size model. Defaults to 1. """ - arg_constraints = {"leaf_times": constraints.real, - "rate": constraints.positive} - def __init__(self, leaf_times, rate=1., *, validate_args=None): - rate = torch.as_tensor(rate, dtype=leaf_times.dtype, - device=leaf_times.device) + arg_constraints = {"leaf_times": constraints.real, "rate": constraints.positive} + + def __init__(self, leaf_times, rate=1.0, *, validate_args=None): + rate = torch.as_tensor(rate, dtype=leaf_times.dtype, device=leaf_times.device) batch_shape = broadcast_shape(rate.shape, leaf_times.shape[:-1]) event_shape = (leaf_times.size(-1) - 1,) self.leaf_times = leaf_times @@ -85,8 +84,9 @@ def log_prob(self, value): # in the number of lineages, which changes at each event. binomial = phylogeny.binomial[..., :-1] interval = phylogeny.times[..., :-1] - phylogeny.times[..., 1:] - log_prob = (self.rate.log() * coal_times.size(-1) - - self.rate * (binomial * interval).sum(-1)) + log_prob = self.rate.log() * coal_times.size(-1) - self.rate * ( + binomial * interval + ).sum(-1) # Scaling by those rates and accounting for log|jacobian|, the density # is that of a collection of independent Exponential intervals. @@ -140,8 +140,11 @@ class CoalescentTimesWithRate(TorchDistribution): ``beta S / I``. The rightmost dimension is time, and this tensor represents a (batch of) rates that are piecwise constant in time. """ - arg_constraints = {"leaf_times": constraints.real, - "rate_grid": constraints.positive} + + arg_constraints = { + "leaf_times": constraints.real, + "rate_grid": constraints.positive, + } def __init__(self, leaf_times, rate_grid, *, validate_args=None): batch_shape = broadcast_shape(leaf_times.shape[:-1], rate_grid.shape[:-1]) @@ -163,7 +166,8 @@ def expand(self, batch_shape, _instance=None): new.leaf_times = self.leaf_times new.rate_grid = self.rate_grid super(CoalescentTimesWithRate, new).__init__( - batch_shape, self.event_shape, validate_args=False) + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self.__dict__.get("_validate_args") return new @@ -189,7 +193,9 @@ def log_prob(self, value): # Compute survival factors for closed intervals. 
cumsum = self.rate_grid.cumsum(-1) cumsum = torch.nn.functional.pad(cumsum, (1, 0), value=0) - integral = _interpolate_gather(cumsum, phylogeny.times[..., 1:]) # ignore the final lonely leaf + integral = _interpolate_gather( + cumsum, phylogeny.times[..., 1:] + ) # ignore the final lonely leaf integral = integral[..., :-1] - integral[..., 1:] integral = integral.clamp(min=torch.finfo(integral.dtype).tiny) # avoid nan log_prob = -(phylogeny.binomial[..., 1:-1] * integral).sum(-1) @@ -239,6 +245,7 @@ class CoalescentRateLikelihood: and should be sorted along that dimension. :param int duration: Size of the rate grid, ``rate_grid.size(-1)``. """ + def __init__(self, leaf_times, coal_times, duration, *, validate_args=None): assert leaf_times.size(-1) == 1 + coal_times.size(-1) assert isinstance(duration, int) and duration >= 2 @@ -259,9 +266,14 @@ def __init__(self, leaf_times, coal_times, duration, *, validate_args=None): times = phylogeny.times.clamp(min=duration) intervals = times[..., 1:] - times[..., :-1] post_linear = (phylogeny.binomial[..., :-1] * intervals).sum(-1, keepdim=True) - self._linear = torch.cat([pre_linear, - new_zeros(pre_linear.shape[:-1] + (duration - 2,)), - post_linear], dim=-1) + self._linear = torch.cat( + [ + pre_linear, + new_zeros(pre_linear.shape[:-1] + (duration - 2,)), + post_linear, + ], + dim=-1, + ) # Construct linear part from intervals of survival within [0, duration]. times = phylogeny.times.clamp(min=0, max=duration) @@ -304,7 +316,10 @@ def __call__(self, rate_grid, t=slice(None)): """ const = self._const[..., t] linear = self._linear[..., t] * rate_grid - log = self._log[..., t] * rate_grid.clamp(min=torch.finfo(rate_grid.dtype).tiny).log() + log = ( + self._log[..., t] + * rate_grid.clamp(min=torch.finfo(rate_grid.dtype).tiny).log() + ) return const + linear + log @@ -328,6 +343,7 @@ def bio_phylo_to_times(tree, *, get_time=None): def get_branch_length(clade): branch_length = clade.branch_length return 1.0 if branch_length is None else branch_length + times = {tree.root: get_branch_length(tree.root)} leaf_times = [] @@ -417,9 +433,16 @@ def memoized_fn(*args): # This helper data structure has only timing information. -_Phylogeny = namedtuple("_Phylogeny", ( - "times", "signs", "lineages", "binomial", "coal_binomial", -)) +_Phylogeny = namedtuple( + "_Phylogeny", + ( + "times", + "signs", + "lineages", + "binomial", + "coal_binomial", + ), +) @_weak_memoize @@ -439,7 +462,9 @@ def _make_phylogeny(leaf_times, coal_times): # (coal_times) into a pair (times, signs) of arrays of length 2N-1, where # leaf sample sign is +1 and coalescent sign is -1. times = torch.cat([coal_times, leaf_times], dim=-1) - signs = torch.linspace(1.5 - N, N - 0.5, 2 * N - 1).sign() # e.g. [-1, -1, +1, +1, +1] + signs = torch.linspace( + 1.5 - N, N - 0.5, 2 * N - 1 + ).sign() # e.g. [-1, -1, +1, +1, +1] # Sort the events reverse-ordered in time, i.e. latest to earliest. times, index = times.sort(dim=-1, descending=True) @@ -453,7 +478,7 @@ def _make_phylogeny(leaf_times, coal_times): binomial = lineages * (lineages - 1) / 2 # Compute the binomial coefficient following each coalescent event. - coal_index = inv_index[..., :N - 1] + coal_index = inv_index[..., : N - 1] coal_binomial = binomial.gather(-1, coal_index - 1) return _Phylogeny(times, signs, lineages, binomial, coal_binomial) @@ -469,7 +494,9 @@ def _sample_coalescent_times(leaf_times): # instead we simply sequentially sample and stack. 
if batch_shape: flat_leaf_times = leaf_times.reshape(-1, N) - flat_coal_times = torch.stack(list(map(_sample_coalescent_times, flat_leaf_times))) + flat_coal_times = torch.stack( + list(map(_sample_coalescent_times, flat_leaf_times)) + ) return flat_coal_times.reshape(batch_shape + (N - 1,)) assert leaf_times.shape == (N,) diff --git a/pyro/distributions/conditional.py b/pyro/distributions/conditional.py index 82918d40ae..ba5e045d6f 100644 --- a/pyro/distributions/conditional.py +++ b/pyro/distributions/conditional.py @@ -60,10 +60,17 @@ def clear_cache(self): class ConditionalTransformedDistribution(ConditionalDistribution): def __init__(self, base_dist, transforms): - self.base_dist = base_dist if isinstance( - base_dist, ConditionalDistribution) else ConstantConditionalDistribution(base_dist) - self.transforms = [t if isinstance(t, ConditionalTransform) - else ConstantConditionalTransform(t) for t in transforms] + self.base_dist = ( + base_dist + if isinstance(base_dist, ConditionalDistribution) + else ConstantConditionalDistribution(base_dist) + ) + self.transforms = [ + t + if isinstance(t, ConditionalTransform) + else ConstantConditionalTransform(t) + for t in transforms + ] def condition(self, context): base_dist = self.base_dist.condition(context) diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py index 33f2092864..cb2cbefffb 100644 --- a/pyro/distributions/conjugate.py +++ b/pyro/distributions/conjugate.py @@ -16,15 +16,19 @@ def _log_beta_1(alpha, value, is_sparse): if is_sparse: - mask = (value != 0) + mask = value != 0 value, alpha, mask = torch.broadcast_tensors(value, alpha, mask) result = torch.zeros_like(value) value = value[mask] alpha = alpha[mask] - result[mask] = torch.lgamma(1 + value) + torch.lgamma(alpha) - torch.lgamma(value + alpha) + result[mask] = ( + torch.lgamma(1 + value) + torch.lgamma(alpha) - torch.lgamma(value + alpha) + ) return result else: - return torch.lgamma(1 + value) + torch.lgamma(alpha) - torch.lgamma(value + alpha) + return ( + torch.lgamma(1 + value) + torch.lgamma(alpha) - torch.lgamma(value + alpha) + ) class BetaBinomial(TorchDistribution): @@ -43,8 +47,11 @@ class BetaBinomial(TorchDistribution): :param total_count: Number of Bernoulli trials. :type total_count: float or torch.Tensor """ - arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive, - 'total_count': constraints.nonnegative_integer} + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + "total_count": constraints.nonnegative_integer, + } has_enumerate_support = True support = Binomial.support @@ -52,11 +59,14 @@ class BetaBinomial(TorchDistribution): # a shifted Sterling's approximation to the Beta function, reducing # computational cost from 9 lgamma() evaluations to 12 log() evaluations # plus arithmetic. Recommended values are between 0.1 and 0.01. - approx_log_prob_tol = 0. 
+ approx_log_prob_tol = 0.0 - def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): + def __init__( + self, concentration1, concentration0, total_count=1, validate_args=None + ): concentration1, concentration0, total_count = broadcast_all( - concentration1, concentration0, total_count) + concentration1, concentration0, total_count + ) self._beta = Beta(concentration1, concentration0) self.total_count = total_count super().__init__(self._beta._batch_shape, validate_args=validate_args) @@ -91,7 +101,11 @@ def log_prob(self, value): a = self.concentration1 b = self.concentration0 tol = self.approx_log_prob_tol - return log_binomial(n, k, tol) + log_beta(k + a, n - k + b, tol) - log_beta(a, b, tol) + return ( + log_binomial(n, k, tol) + + log_beta(k + a, n - k + b, tol) + - log_beta(a, b, tol) + ) @property def mean(self): @@ -99,13 +113,23 @@ def mean(self): @property def variance(self): - return self._beta.variance * self.total_count * (self.concentration0 + self.concentration1 + self.total_count) + return ( + self._beta.variance + * self.total_count + * (self.concentration0 + self.concentration1 + self.total_count) + ) def enumerate_support(self, expand=True): total_count = int(self.total_count.max()) if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.") - values = torch.arange(1 + total_count, dtype=self.concentration1.dtype, device=self.concentration1.device) + raise NotImplementedError( + "Inhomogeneous total count not supported by `enumerate_support`." + ) + values = torch.arange( + 1 + total_count, + dtype=self.concentration1.dtype, + device=self.concentration1.device, + ) values = values.view((-1,) + (1,) * len(self._batch_shape)) if expand: values = values.expand((-1,) + self._batch_shape) @@ -126,11 +150,15 @@ class DirichletMultinomial(TorchDistribution): :param bool is_sparse: Whether to assume value is mostly zero when computing :meth:`log_prob`, which can speed up computation when data is sparse. """ - arg_constraints = {'concentration': constraints.independent(constraints.positive, 1), - 'total_count': constraints.nonnegative_integer} + arg_constraints = { + "concentration": constraints.independent(constraints.positive, 1), + "total_count": constraints.nonnegative_integer, + } support = Multinomial.support - def __init__(self, concentration, total_count=1, is_sparse=False, validate_args=None): + def __init__( + self, concentration, total_count=1, is_sparse=False, validate_args=None + ): batch_shape = concentration.shape[:-1] event_shape = concentration.shape[-1:] if isinstance(total_count, numbers.Number): @@ -161,7 +189,8 @@ def expand(self, batch_shape, _instance=None): new.total_count = self.total_count.expand(batch_shape) new.is_sparse = self.is_sparse super(DirichletMultinomial, new).__init__( - new._dirichlet.batch_shape, new._dirichlet.event_shape, validate_args=False) + new._dirichlet.batch_shape, new._dirichlet.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -169,15 +198,18 @@ def sample(self, sample_shape=()): probs = self._dirichlet.sample(sample_shape) total_count = int(self.total_count.max()) if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `sample`.") + raise NotImplementedError( + "Inhomogeneous total count not supported by `sample`." 
+ ) return Multinomial(total_count, probs).sample() def log_prob(self, value): if self._validate_args: self._validate_sample(value) alpha = self.concentration - return (_log_beta_1(alpha.sum(-1), value.sum(-1), self.is_sparse) - - _log_beta_1(alpha, value, self.is_sparse).sum(-1)) + return _log_beta_1(alpha.sum(-1), value.sum(-1), self.is_sparse) - _log_beta_1( + alpha, value, self.is_sparse + ).sum(-1) @property def mean(self): @@ -209,7 +241,10 @@ class GammaPoisson(TorchDistribution): distribution. """ - arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive} + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } support = Poisson.support def __init__(self, concentration, rate, validate_args=None): @@ -241,8 +276,12 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) post_value = self.concentration + value - return -log_beta(self.concentration, value + 1) - post_value.log() + \ - self.concentration * self.rate.log() - post_value * (1 + self.rate).log() + return ( + -log_beta(self.concentration, value + 1) + - post_value.log() + + self.concentration * self.rate.log() + - post_value * (1 + self.rate).log() + ) @property def mean(self): diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index 60db57f9b6..d6bd399459 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -21,6 +21,7 @@ class _Integer(Constraint): """ Constrain to integers. """ + is_discrete = True def check(self, value): @@ -34,8 +35,9 @@ class _Sphere(Constraint): """ Constrain to the Euclidean sphere of any dimension. """ + event_dim = 1 - reltol = 10. # Relative to finfo.eps. + reltol = 10.0 # Relative to finfo.eps. def check(self, value): eps = torch.finfo(value.dtype).eps @@ -51,11 +53,14 @@ class _CorrMatrix(Constraint): """ Constrains to a correlation matrix. """ + event_dim = 2 def check(self, value): # check for diagonal equal to 1 - unit_variance = torch.all(torch.abs(torch.diagonal(value, dim1=-2, dim2=-1) - 1) < 1e-6, dim=-1) + unit_variance = torch.all( + torch.abs(torch.diagonal(value, dim1=-2, dim2=-1) - 1) < 1e-6, dim=-1 + ) # TODO: fix upstream - positive_definite has an extra dimension in front of output shape return positive_definite.check(value) & unit_variance @@ -65,6 +70,7 @@ class _OrderedVector(Constraint): Constrains to a real-valued tensor where the elements are monotonically increasing along the `event_shape` dimension. """ + event_dim = 1 def check(self, value): @@ -105,14 +111,14 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED __all__ = [ - 'corr_cholesky_constraint', - 'corr_matrix', - 'integer', - 'ordered_vector', - 'positive_ordered_vector', - 'softplus_lower_cholesky', - 'softplus_positive', - 'sphere', + "corr_cholesky_constraint", + "corr_matrix", + "integer", + "ordered_vector", + "positive_ordered_vector", + "softplus_lower_cholesky", + "softplus_positive", + "sphere", ] __all__.extend(torch_constraints) @@ -125,17 +131,22 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): Pyro's constraints library extends :mod:`torch.distributions.constraints`. """ -__doc__ += "\n".join([ - """ +__doc__ += "\n".join( + [ + """ {} ---------------------------------------------------------------- {} """.format( - _name, - "alias of :class:`torch.distributions.constraints.{}`".format(_name) - if globals()[_name].__module__.startswith("torch") else - ".. 
autoclass:: {}".format(_name if type(globals()[_name]) is type else - type(globals()[_name]).__name__) - ) - for _name in sorted(__all__) -]) + _name, + "alias of :class:`torch.distributions.constraints.{}`".format(_name) + if globals()[_name].__module__.startswith("torch") + else ".. autoclass:: {}".format( + _name + if type(globals()[_name]) is type + else type(globals()[_name]).__name__ + ), + ) + for _name in sorted(__all__) + ] +) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index affaf630e7..34afb9508f 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -25,21 +25,30 @@ class Delta(TorchDistribution): under differentiable transformation. :param int event_dim: Optional event dimension, defaults to zero. """ + has_rsample = True - arg_constraints = {'v': constraints.dependent, - 'log_density': constraints.real} + arg_constraints = {"v": constraints.dependent, "log_density": constraints.real} def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): if event_dim > v.dim(): - raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'.format(event_dim, v.dim())) + raise ValueError( + "Expected event_dim <= v.dim(), actual {} vs {}".format( + event_dim, v.dim() + ) + ) batch_dim = v.dim() - event_dim batch_shape = v.shape[:batch_dim] event_shape = v.shape[batch_dim:] if isinstance(log_density, numbers.Number): - log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device) + log_density = torch.full( + batch_shape, log_density, dtype=v.dtype, device=v.device + ) elif validate_args and log_density.shape != batch_shape: - raise ValueError('Expected log_density.shape = {}, actual {}'.format( - log_density.shape, batch_shape)) + raise ValueError( + "Expected log_density.shape = {}, actual {}".format( + log_density.shape, batch_shape + ) + ) self.v = v self.log_density = log_density super().__init__(batch_shape, event_shape, validate_args=validate_args) diff --git a/pyro/distributions/diag_normal_mixture.py b/pyro/distributions/diag_normal_mixture.py index ad36b227ea..4a7d09b1db 100644 --- a/pyro/distributions/diag_normal_mixture.py +++ b/pyro/distributions/diag_normal_mixture.py @@ -39,28 +39,37 @@ class MixtureOfDiagNormals(TorchDistribution): :param torch.Tensor coord_scale: K x D scale matrix :param torch.Tensor component_logits: K-dimensional vector of softmax logits """ + has_rsample = True - arg_constraints = {"locs": constraints.real, "coord_scale": constraints.positive, - "component_logits": constraints.real} + arg_constraints = { + "locs": constraints.real, + "coord_scale": constraints.positive, + "component_logits": constraints.real, + } def __init__(self, locs, coord_scale, component_logits): - self.batch_mode = (locs.dim() > 2) - assert(coord_scale.shape == locs.shape) - assert(self.batch_mode or locs.dim() == 2), \ - "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)" + self.batch_mode = locs.dim() > 2 + assert coord_scale.shape == locs.shape + assert ( + self.batch_mode or locs.dim() == 2 + ), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)" if not self.batch_mode: - assert(coord_scale.dim() == 2), \ - "The coord_scale parameter in MixtureOfDiagNormals should be K x D dimensional" - assert(component_logits.dim() == 1), \ - "The component_logits parameter in MixtureOfDiagNormals should be K dimensional" - assert(component_logits.size(-1) == locs.size(-2)) + assert ( + 
coord_scale.dim() == 2 + ), "The coord_scale parameter in MixtureOfDiagNormals should be K x D dimensional" + assert ( + component_logits.dim() == 1 + ), "The component_logits parameter in MixtureOfDiagNormals should be K dimensional" + assert component_logits.size(-1) == locs.size(-2) batch_shape = () else: - assert(coord_scale.dim() > 2), \ - "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional" - assert(component_logits.dim() > 1), \ - "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional" - assert(component_logits.size(-1) == locs.size(-2)) + assert ( + coord_scale.dim() > 2 + ), "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional" + assert ( + component_logits.dim() > 1 + ), "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional" + assert component_logits.size(-1) == locs.size(-2) batch_shape = tuple(locs.shape[:-2]) self.locs = locs @@ -69,8 +78,9 @@ def __init__(self, locs, coord_scale, component_logits): self.dim = locs.size(-1) self.categorical = Categorical(logits=component_logits) self.probs = self.categorical.probs - super().__init__(batch_shape=torch.Size(batch_shape), - event_shape=torch.Size((self.dim,))) + super().__init__( + batch_shape=torch.Size(batch_shape), event_shape=torch.Size((self.dim,)) + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(MixtureOfDiagNormals, _instance) @@ -78,11 +88,17 @@ def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new.dim = self.dim new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) - new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[-2:]) - new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.coord_scale = self.coord_scale.expand( + batch_shape + self.coord_scale.shape[-2:] + ) + new.component_logits = self.component_logits.expand( + batch_shape + self.component_logits.shape[-1:] + ) new.categorical = self.categorical.expand(batch_shape) new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) - super(MixtureOfDiagNormals, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(MixtureOfDiagNormals, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -91,7 +107,11 @@ def log_prob(self, value): eps_sqr = 0.5 * torch.pow(epsilon, 2.0).sum(-1) # L B K eps_sqr_min = torch.min(eps_sqr, -1)[0] # L B K coord_scale_prod_log_sum = self.coord_scale.log().sum(-1) # B K - result = self.categorical.logits + (-eps_sqr + eps_sqr_min.unsqueeze(-1)) - coord_scale_prod_log_sum # L B K + result = ( + self.categorical.logits + + (-eps_sqr + eps_sqr_min.unsqueeze(-1)) + - coord_scale_prod_log_sum + ) # L B K result = torch.logsumexp(result, dim=-1) # L B result = result - 0.5 * math.log(2.0 * math.pi) * float(self.dim) result = result - eps_sqr_min @@ -99,9 +119,14 @@ def log_prob(self, value): def rsample(self, sample_shape=torch.Size()): which = self.categorical.sample(sample_shape) - return _MixDiagNormalSample.apply(self.locs, self.coord_scale, - self.component_logits, self.categorical.probs, which, - sample_shape + self.locs.shape[:-2] + (self.dim,)) + return _MixDiagNormalSample.apply( + self.locs, + self.coord_scale, + self.component_logits, + self.categorical.probs, + which, + sample_shape + self.locs.shape[:-2] + (self.dim,), + ) class _MixDiagNormalSample(Function): @@ 
-144,7 +169,9 @@ def backward(ctx, grad_output): mu_ll_cd = (locs.unsqueeze(-2) * mu_cd).sum(-1) # b c d z_ll_cd = (z.unsqueeze(-2).unsqueeze(-2) * mu_cd).sum(-1) # l b c d - z_perp_cd = z.unsqueeze(-2).unsqueeze(-2) - z_ll_cd.unsqueeze(-1) * mu_cd # l b c d i + z_perp_cd = ( + z.unsqueeze(-2).unsqueeze(-2) - z_ll_cd.unsqueeze(-1) * mu_cd + ) # l b c d i z_perp_cd_sqr = torch.pow(z_perp_cd, 2.0).sum(-1) # l b c d shift_indices = torch.empty((dim,), dtype=torch.long, device=z.device) @@ -153,7 +180,9 @@ def backward(ctx, grad_output): shift_indices[0] = 0 z_shift_cumsum = torch.pow(z_shift, 2.0) - z_shift_cumsum = z_shift_cumsum.sum(-1, keepdim=True) - torch.cumsum(z_shift_cumsum, dim=-1) # l b j i + z_shift_cumsum = z_shift_cumsum.sum(-1, keepdim=True) - torch.cumsum( + z_shift_cumsum, dim=-1 + ) # l b j i z_tilde_cumsum = torch.cumsum(torch.pow(z_tilde, 2.0), dim=-1) # l b j i z_tilde_cumsum = torch.index_select(z_tilde_cumsum, -1, shift_indices) z_tilde_cumsum[..., 0] = 0.0 @@ -161,7 +190,9 @@ def backward(ctx, grad_output): log_scales = torch.log(scales) # b j i epsilons_sqr = torch.pow(z_tilde, 2.0) # l b j i - log_qs = -0.5 * epsilons_sqr - 0.5 * math.log(2.0 * math.pi) - log_scales # l b j i + log_qs = ( + -0.5 * epsilons_sqr - 0.5 * math.log(2.0 * math.pi) - log_scales + ) # l b j i log_q_j = log_qs.sum(-1, keepdim=True) # l b j 1 q_j = torch.exp(log_q_j) # l b j 1 q_tot = (pis * q_j.squeeze(-1)).sum(-1) # l b @@ -172,13 +203,19 @@ def backward(ctx, grad_output): shift_log_scales[..., 0] = 0.0 sigma_products = torch.cumsum(shift_log_scales, dim=-1).exp() # b j i - reverse_indices = torch.tensor(range(dim - 1, -1, -1), dtype=torch.long, device=z.device) + reverse_indices = torch.tensor( + range(dim - 1, -1, -1), dtype=torch.long, device=z.device + ) reverse_log_sigma_0 = sigma_0.log()[..., reverse_indices] # b 1 i - sigma_0_products = torch.cumsum(reverse_log_sigma_0, dim=-1).exp()[..., reverse_indices - 1] # b 1 i + sigma_0_products = torch.cumsum(reverse_log_sigma_0, dim=-1).exp()[ + ..., reverse_indices - 1 + ] # b 1 i sigma_0_products[..., -1] = 1.0 sigma_products *= sigma_0_products - logits_grad = torch.erf(z_tilde / root_two) - torch.erf(z_shift / root_two) # l b j i + logits_grad = torch.erf(z_tilde / root_two) - torch.erf( + z_shift / root_two + ) # l b j i logits_grad *= torch.exp(-0.5 * r_sqr_ji) # l b j i logits_grad = (logits_grad * g / sigma_products).sum(-1) # l b j logits_grad = sum_leftmost(logits_grad / q_tot, -1 - batch_dims) # b j @@ -187,10 +224,17 @@ def backward(ctx, grad_output): logits_grad = logits_grad - logits_grad.sum(-1, keepdim=True) * pis mu_ll_dc = torch.transpose(mu_ll_cd, -1, -2) - v_cd = torch.erf((z_ll_cd - mu_ll_cd) / root_two) - torch.erf((z_ll_cd + mu_ll_dc) / root_two) + v_cd = torch.erf((z_ll_cd - mu_ll_cd) / root_two) - torch.erf( + (z_ll_cd + mu_ll_dc) / root_two + ) v_cd *= torch.exp(-0.5 * z_perp_cd_sqr) # l b c d mu_cd_g = (g.unsqueeze(-2) * mu_cd).sum(-1) # l b c d - v_cd *= -mu_cd_g * pis.unsqueeze(-2) * 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1)) # l b c d + v_cd *= ( + -mu_cd_g + * pis.unsqueeze(-2) + * 0.5 + * math.pow(2.0 * math.pi, -0.5 * (dim - 1)) + ) # l b c d v_cd = pis * sum_leftmost(v_cd.sum(-1) / q_tot, -1 - batch_dims) logits_grad += v_cd diff --git a/pyro/distributions/diag_normal_mixture_shared_cov.py b/pyro/distributions/diag_normal_mixture_shared_cov.py index de6a2b5c26..f8ee2ed357 100644 --- a/pyro/distributions/diag_normal_mixture_shared_cov.py +++ b/pyro/distributions/diag_normal_mixture_shared_cov.py @@ -38,58 
+38,80 @@ class MixtureOfDiagNormalsSharedCovariance(TorchDistribution): :param torch.Tensor coord_scale: shared D-dimensional scale vector :param torch.Tensor component_logits: K-dimensional vector of softmax logits """ + has_rsample = True - arg_constraints = {"locs": constraints.real, "coord_scale": constraints.positive, - "component_logits": constraints.real} + arg_constraints = { + "locs": constraints.real, + "coord_scale": constraints.positive, + "component_logits": constraints.real, + } def __init__(self, locs, coord_scale, component_logits): - self.batch_mode = (locs.dim() > 2) - assert(self.batch_mode or locs.dim() == 2), \ - "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or ... x B x K x D in batch mode)" + self.batch_mode = locs.dim() > 2 + assert ( + self.batch_mode or locs.dim() == 2 + ), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or ... x B x K x D in batch mode)" if not self.batch_mode: - assert(coord_scale.dim() == 1), "The coord_scale parameter in MixtureOfDiagNormals should be D dimensional" - assert(component_logits.dim() == 1), \ - "The component_logits parameter in MixtureOfDiagNormals should be K dimensional" - assert(component_logits.size(0) == locs.size(0)) + assert ( + coord_scale.dim() == 1 + ), "The coord_scale parameter in MixtureOfDiagNormals should be D dimensional" + assert ( + component_logits.dim() == 1 + ), "The component_logits parameter in MixtureOfDiagNormals should be K dimensional" + assert component_logits.size(0) == locs.size(0) batch_shape = () else: - assert(coord_scale.dim() > 1), \ - "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x D dimensional" - assert(component_logits.dim() > 1), \ - "The component_logits parameter in MixtureOfDiagNormals should be ... x B x K dimensional" - assert(component_logits.size(-1) == locs.size(-2)) + assert ( + coord_scale.dim() > 1 + ), "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x D dimensional" + assert ( + component_logits.dim() > 1 + ), "The component_logits parameter in MixtureOfDiagNormals should be ... 
x B x K dimensional" + assert component_logits.size(-1) == locs.size(-2) batch_shape = tuple(locs.shape[:-2]) self.locs = locs self.coord_scale = coord_scale self.component_logits = component_logits self.dim = locs.size(-1) if self.dim < 2: - raise NotImplementedError('This distribution does not support D = 1') + raise NotImplementedError("This distribution does not support D = 1") self.categorical = Categorical(logits=component_logits) self.probs = self.categorical.probs super().__init__(batch_shape=batch_shape, event_shape=(self.dim,)) def expand(self, batch_shape, _instance=None): - new = self._get_checked_instance(MixtureOfDiagNormalsSharedCovariance, _instance) + new = self._get_checked_instance( + MixtureOfDiagNormalsSharedCovariance, _instance + ) new.batch_mode = True batch_shape = torch.Size(batch_shape) new.dim = self.dim new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) coord_scale_shape = -1 if self.batch_mode else -2 - new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[coord_scale_shape:]) - new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.coord_scale = self.coord_scale.expand( + batch_shape + self.coord_scale.shape[coord_scale_shape:] + ) + new.component_logits = self.component_logits.expand( + batch_shape + self.component_logits.shape[-1:] + ) new.categorical = self.categorical.expand(batch_shape) new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) - super(MixtureOfDiagNormalsSharedCovariance, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(MixtureOfDiagNormalsSharedCovariance, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new def log_prob(self, value): - coord_scale = self.coord_scale.unsqueeze(-2) if self.batch_mode else self.coord_scale + coord_scale = ( + self.coord_scale.unsqueeze(-2) if self.batch_mode else self.coord_scale + ) epsilon = (value.unsqueeze(-2) - self.locs) / coord_scale # L B K D eps_sqr = 0.5 * torch.pow(epsilon, 2.0).sum(-1) # L B K eps_sqr_min = torch.min(eps_sqr, -1)[0] # L B - result = self.categorical.logits + (-eps_sqr + eps_sqr_min.unsqueeze(-1)) # L B K + result = self.categorical.logits + ( + -eps_sqr + eps_sqr_min.unsqueeze(-1) + ) # L B K result = torch.logsumexp(result, dim=-1) # L B result = result - (0.5 * math.log(2.0 * math.pi) * float(self.dim)) result = result - (torch.log(self.coord_scale).sum(-1)) @@ -98,8 +120,14 @@ def log_prob(self, value): def rsample(self, sample_shape=torch.Size()): which = self.categorical.sample(sample_shape) - return _MixDiagNormalSharedCovarianceSample.apply(self.locs, self.coord_scale, self.component_logits, - self.probs, which, sample_shape + self.coord_scale.shape) + return _MixDiagNormalSharedCovarianceSample.apply( + self.locs, + self.coord_scale, + self.component_logits, + self.probs, + which, + sample_shape + self.coord_scale.shape, + ) class _MixDiagNormalSharedCovarianceSample(Function): @@ -136,12 +164,14 @@ def backward(ctx, grad_output): mu_ll_ab = (locs_tilde.unsqueeze(-2) * mu_ab).sum(-1) # b k j z_ll_ab = (z_tilde.unsqueeze(-2).unsqueeze(-2) * mu_ab).sum(-1) # l b k j - z_perp_ab = z_tilde.unsqueeze(-2).unsqueeze(-2) - z_ll_ab.unsqueeze(-1) * mu_ab # l b k j i + z_perp_ab = ( + z_tilde.unsqueeze(-2).unsqueeze(-2) - z_ll_ab.unsqueeze(-1) * mu_ab + ) # l b k j i z_perp_ab_sqr = torch.pow(z_perp_ab, 2.0).sum(-1) # l b k j epsilons = z_tilde.unsqueeze(-2) - locs_tilde # l b j i - log_qs 
= -0.5 * torch.pow(epsilons, 2.0) # l b j i - log_q_j = log_qs.sum(-1, keepdim=True) # l b j 1 + log_qs = -0.5 * torch.pow(epsilons, 2.0) # l b j i + log_q_j = log_qs.sum(-1, keepdim=True) # l b j 1 log_q_j_max = torch.max(log_q_j, -2, keepdim=True)[0] q_j_prime = torch.exp(log_q_j - log_q_j_max) # l b j 1 q_j = torch.exp(log_q_j) # l b j 1 @@ -151,18 +181,28 @@ def backward(ctx, grad_output): root_two = math.sqrt(2.0) mu_ll_ba = torch.transpose(mu_ll_ab, -1, -2) - logits_grad = torch.erf((z_ll_ab - mu_ll_ab) / root_two) - torch.erf((z_ll_ab + mu_ll_ba) / root_two) + logits_grad = torch.erf((z_ll_ab - mu_ll_ab) / root_two) - torch.erf( + (z_ll_ab + mu_ll_ba) / root_two + ) logits_grad *= torch.exp(-0.5 * z_perp_ab_sqr) # l b k j # bi lbi bkji - mu_ab_sigma_g = ((coord_scale * g).unsqueeze(-2).unsqueeze(-2) * mu_ab).sum(-1) # l b k j + mu_ab_sigma_g = ((coord_scale * g).unsqueeze(-2).unsqueeze(-2) * mu_ab).sum( + -1 + ) # l b k j logits_grad *= -mu_ab_sigma_g * pis.unsqueeze(-2) # l b k j - logits_grad = pis * sum_leftmost(logits_grad.sum(-1) / q_tot, -(1 + batch_dims)) # b k + logits_grad = pis * sum_leftmost( + logits_grad.sum(-1) / q_tot, -(1 + batch_dims) + ) # b k logits_grad *= math.sqrt(0.5 * math.pi) # b j l b j 1 l b i l b 1 1 - prefactor = pis.unsqueeze(-1) * q_j_prime * g.unsqueeze(-2) / q_tot_prime # l b j i + prefactor = ( + pis.unsqueeze(-1) * q_j_prime * g.unsqueeze(-2) / q_tot_prime + ) # l b j i locs_grad = sum_leftmost(prefactor, -(2 + batch_dims)) # b j i - coord_scale_grad = sum_leftmost(prefactor * epsilons, -(2 + batch_dims)).sum(-2) # b i + coord_scale_grad = sum_leftmost(prefactor * epsilons, -(2 + batch_dims)).sum( + -2 + ) # b i return locs_grad, coord_scale_grad, logits_grad, None, None, None diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index 5d18e37d9b..dfaf0595c7 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -44,6 +44,7 @@ class Distribution(metaclass=DistributionMeta): Take a look at the `examples `_ to see how they interact with inference algorithms. """ + has_rsample = False has_enumerate_support = False @@ -109,11 +110,15 @@ def score_parts(self, x, *args, **kwargs): """ log_prob = self.log_prob(x, *args, **kwargs) if self.has_rsample: - return ScoreParts(log_prob=log_prob, score_function=0, entropy_term=log_prob) + return ScoreParts( + log_prob=log_prob, score_function=0, entropy_term=log_prob + ) else: # XXX should the user be able to control inclusion of the entropy term? # See Roeder, Wu, Duvenaud (2017) "Sticking the Landing" https://arxiv.org/abs/1703.09194 - return ScoreParts(log_prob=log_prob, score_function=log_prob, entropy_term=0) + return ScoreParts( + log_prob=log_prob, score_function=log_prob, entropy_term=0 + ) def enumerate_support(self, expand=True): """ @@ -131,7 +136,9 @@ def enumerate_support(self, expand=True): :return: An iterator over the distribution's discrete support. :rtype: iterator """ - raise NotImplementedError("Support not implemented for {}".format(type(self).__name__)) + raise NotImplementedError( + "Support not implemented for {}".format(type(self).__name__) + ) def conjugate_update(self, other): """ @@ -162,8 +169,9 @@ def conjugate_update(self, other): updated distribution of type ``type(self)``, and ``log_normalizer`` is a :class:`~torch.Tensor` representing the normalization factor. 
""" - raise NotImplementedError("{} does not support .conjugate_update()" - .format(type(self).__name__)) + raise NotImplementedError( + "{} does not support .conjugate_update()".format(type(self).__name__) + ) def has_rsample_(self, value): """ @@ -206,4 +214,5 @@ def rv(self): :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable """ from pyro.contrib.randomvariable import RandomVariable + return RandomVariable(self) diff --git a/pyro/distributions/empirical.py b/pyro/distributions/empirical.py index 5deb432d31..7383412aa9 100644 --- a/pyro/distributions/empirical.py +++ b/pyro/distributions/empirical.py @@ -54,15 +54,24 @@ def __init__(self, samples, log_weights, validate_args=None): self._samples = samples self._log_weights = log_weights sample_shape, weight_shape = samples.size(), log_weights.size() - if weight_shape > sample_shape or weight_shape != sample_shape[:len(weight_shape)]: - raise ValueError("The shape of ``log_weights`` ({}) must match " - "the leftmost shape of ``samples`` ({})".format(weight_shape, sample_shape)) + if ( + weight_shape > sample_shape + or weight_shape != sample_shape[: len(weight_shape)] + ): + raise ValueError( + "The shape of ``log_weights`` ({}) must match " + "the leftmost shape of ``samples`` ({})".format( + weight_shape, sample_shape + ) + ) self._aggregation_dim = log_weights.dim() - 1 - event_shape = sample_shape[len(weight_shape):] + event_shape = sample_shape[len(weight_shape) :] self._categorical = Categorical(logits=self._log_weights) - super().__init__(batch_shape=weight_shape[:-1], - event_shape=event_shape, - validate_args=validate_args) + super().__init__( + batch_shape=weight_shape[:-1], + event_shape=event_shape, + validate_args=validate_args, + ) @property def sample_size(self): @@ -74,12 +83,20 @@ def sample_size(self): return self._log_weights.numel() def sample(self, sample_shape=torch.Size()): - sample_idx = self._categorical.sample(sample_shape) # sample_shape x batch_shape + sample_idx = self._categorical.sample( + sample_shape + ) # sample_shape x batch_shape # reorder samples to bring aggregation_dim to the front: # batch_shape x num_samples x event_shape -> num_samples x batch_shape x event_shape - samples = self._samples.unsqueeze(0).transpose(0, self._aggregation_dim + 1).squeeze(self._aggregation_dim + 1) + samples = ( + self._samples.unsqueeze(0) + .transpose(0, self._aggregation_dim + 1) + .squeeze(self._aggregation_dim + 1) + ) # make sample_idx.shape compatible with samples.shape: sample_shape_numel x batch_shape x event_shape - sample_idx = sample_idx.reshape((-1,) + self.batch_shape + (1,) * len(self.event_shape)) + sample_idx = sample_idx.reshape( + (-1,) + self.batch_shape + (1,) * len(self.event_shape) + ) sample_idx = sample_idx.expand((-1,) + samples.shape[1:]) return samples.gather(0, sample_idx).reshape(sample_shape + samples.shape[1:]) @@ -93,7 +110,11 @@ def log_prob(self, value): """ if self._validate_args: if value.shape != self.batch_shape + self.event_shape: - raise ValueError("``value.shape`` must be {}".format(self.batch_shape + self.event_shape)) + raise ValueError( + "``value.shape`` must be {}".format( + self.batch_shape + self.event_shape + ) + ) if self.batch_shape: value = value.unsqueeze(self._aggregation_dim) selection_mask = self._samples.eq(value) @@ -105,12 +126,16 @@ def log_prob(self, value): return (self._categorical.probs * selection_mask).sum(dim=-1).log() def _weighted_mean(self, value, keepdim=False): - weights = self._log_weights.reshape(self._log_weights.size() + - 
torch.Size([1] * (value.dim() - self._log_weights.dim()))) + weights = self._log_weights.reshape( + self._log_weights.size() + + torch.Size([1] * (value.dim() - self._log_weights.dim())) + ) dim = self._aggregation_dim max_weight = weights.max(dim=dim, keepdim=True)[0] relative_probs = (weights - max_weight).exp() - return (value * relative_probs).sum(dim=dim, keepdim=keepdim) / relative_probs.sum(dim=dim, keepdim=keepdim) + return (value * relative_probs).sum( + dim=dim, keepdim=keepdim + ) / relative_probs.sum(dim=dim, keepdim=keepdim) @property def event_shape(self): @@ -119,21 +144,25 @@ def event_shape(self): @property def mean(self): if self._samples.dtype in (torch.int32, torch.int64): - raise ValueError("Mean for discrete empirical distribution undefined. " + - "Consider converting samples to ``torch.float32`` " + - "or ``torch.float64``. If these are samples from a " + - "`Categorical` distribution, consider converting to a " + - "`OneHotCategorical` distribution.") + raise ValueError( + "Mean for discrete empirical distribution undefined. " + + "Consider converting samples to ``torch.float32`` " + + "or ``torch.float64``. If these are samples from a " + + "`Categorical` distribution, consider converting to a " + + "`OneHotCategorical` distribution." + ) return self._weighted_mean(self._samples) @property def variance(self): if self._samples.dtype in (torch.int32, torch.int64): - raise ValueError("Variance for discrete empirical distribution undefined. " + - "Consider converting samples to ``torch.float32`` " + - "or ``torch.float64``. If these are samples from a " + - "`Categorical` distribution, consider converting to a " + - "`OneHotCategorical` distribution.") + raise ValueError( + "Variance for discrete empirical distribution undefined. " + + "Consider converting samples to ``torch.float32`` " + + "or ``torch.float64``. If these are samples from a " + + "`Categorical` distribution, consider converting to a " + + "`OneHotCategorical` distribution." + ) mean = self.mean.unsqueeze(self._aggregation_dim) deviation_squared = torch.pow(self._samples - mean, 2) return self._weighted_mean(deviation_squared) diff --git a/pyro/distributions/extended.py b/pyro/distributions/extended.py index eb3720481f..16c32c6657 100644 --- a/pyro/distributions/extended.py +++ b/pyro/distributions/extended.py @@ -16,9 +16,12 @@ class ExtendedBinomial(Binomial): ``total_count``. Numerical support is still the integer interval ``[0, total_count]``. """ - arg_constraints = {"total_count": constraints.integer, - "probs": constraints.unit_interval, - "logits": constraints.real} + + arg_constraints = { + "total_count": constraints.integer, + "probs": constraints.unit_interval, + "logits": constraints.real, + } support = constraints.integer def log_prob(self, value): @@ -34,9 +37,12 @@ class ExtendedBetaBinomial(BetaBinomial): integer ``total_count``. Numerical support is still the integer interval ``[0, total_count]``. 
""" - arg_constraints = {"concentration1": constraints.positive, - "concentration0": constraints.positive, - "total_count": constraints.integer} + + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + "total_count": constraints.integer, + } support = constraints.integer def log_prob(self, value): diff --git a/pyro/distributions/folded.py b/pyro/distributions/folded.py index 311685c08d..89f3fae661 100644 --- a/pyro/distributions/folded.py +++ b/pyro/distributions/folded.py @@ -15,6 +15,7 @@ class FoldedDistribution(TransformedDistribution): :param ~torch.distributions.Distribution base_dist: The distribution to reflect. """ + support = constraints.positive def __init__(self, base_dist, validate_args=None): @@ -30,5 +31,5 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) dim = max(len(self.batch_shape), value.dim()) - plus_minus = value.new_tensor([1., -1.]).reshape((2,) + (1,) * dim) + plus_minus = value.new_tensor([1.0, -1.0]).reshape((2,) + (1,) * dim) return self.base_dist.log_prob(plus_minus * value).logsumexp(0) diff --git a/pyro/distributions/gaussian_scale_mixture.py b/pyro/distributions/gaussian_scale_mixture.py index c2b15717bd..f1dc47cf64 100644 --- a/pyro/distributions/gaussian_scale_mixture.py +++ b/pyro/distributions/gaussian_scale_mixture.py @@ -49,21 +49,30 @@ class GaussianScaleMixture(TorchDistribution): :param torch.tensor component_logits: K-dimensional vector of logits :param torch.tensor component_scale: K-dimensional vector of scale multipliers """ + has_rsample = True - arg_constraints = {"component_scale": constraints.positive, "coord_scale": constraints.positive, - "component_logits": constraints.real} + arg_constraints = { + "component_scale": constraints.positive, + "coord_scale": constraints.positive, + "component_logits": constraints.real, + } def __init__(self, coord_scale, component_logits, component_scale): self.dim = coord_scale.size(0) if self.dim < 2: - raise NotImplementedError('This distribution does not support D = 1') - assert(coord_scale.dim() == 1), "The coord_scale parameter in GaussianScaleMixture should be D dimensional" - assert(component_scale.dim() == 1), \ - "The component_scale parameter in GaussianScaleMixture should be K dimensional" - assert(component_logits.dim() == 1), \ - "The component_logits parameter in GaussianScaleMixture should be K dimensional" - assert(component_logits.shape == component_scale.shape), \ - "The component_logits and component_scale parameters in GaussianScaleMixture should be K dimensional" + raise NotImplementedError("This distribution does not support D = 1") + assert ( + coord_scale.dim() == 1 + ), "The coord_scale parameter in GaussianScaleMixture should be D dimensional" + assert ( + component_scale.dim() == 1 + ), "The component_scale parameter in GaussianScaleMixture should be K dimensional" + assert ( + component_logits.dim() == 1 + ), "The component_logits parameter in GaussianScaleMixture should be K dimensional" + assert ( + component_logits.shape == component_scale.shape + ), "The component_logits and component_scale parameters in GaussianScaleMixture should be K dimensional" self.coord_scale = coord_scale self.component_logits = component_logits self.component_scale = component_scale @@ -78,7 +87,7 @@ def _compute_coeffs(self): dimov2 = int(self.dim / 2) # this is correct for both even and odd dimensions coeffs = torch.ones(dimov2) for k in range(dimov2 - 1): - coeffs[k + 1:] *= self.dim - 2 * (k + 1) + coeffs[k + 1 
:] *= self.dim - 2 * (k + 1) return coeffs def log_prob(self, value): @@ -87,31 +96,52 @@ def log_prob(self, value): component_scale_log_power = self.component_scale.log() * -self.dim # logits in Categorical is already normalized result = torch.logsumexp( - component_scale_log_power + self.categorical.logits + - -0.5 * epsilon_sqr / torch.pow(self.component_scale, 2.0), dim=-1) # K + component_scale_log_power + + self.categorical.logits + + -0.5 * epsilon_sqr / torch.pow(self.component_scale, 2.0), + dim=-1, + ) # K result -= 0.5 * math.log(2.0 * math.pi) * float(self.dim) result -= self.coord_scale.log().sum() return result def rsample(self, sample_shape=torch.Size()): which = self.categorical.sample(sample_shape) - return _GSMSample.apply(self.coord_scale, self.component_logits, self.component_scale, self.categorical.probs, - which, sample_shape + torch.Size((self.dim,)), self.coeffs) + return _GSMSample.apply( + self.coord_scale, + self.component_logits, + self.component_scale, + self.categorical.probs, + which, + sample_shape + torch.Size((self.dim,)), + self.coeffs, + ) class _GSMSample(Function): @staticmethod - def forward(ctx, coord_scale, component_logits, component_scale, pis, which, shape, coeffs): + def forward( + ctx, coord_scale, component_logits, component_scale, pis, which, shape, coeffs + ): white = coord_scale.new(shape).normal_() which_component_scale = component_scale[which].unsqueeze(-1) z = coord_scale * which_component_scale * white - ctx.save_for_backward(z, coord_scale, component_logits, component_scale, pis, coeffs) + ctx.save_for_backward( + z, coord_scale, component_logits, component_scale, pis, coeffs + ) return z @staticmethod @once_differentiable def backward(ctx, grad_output): - z, coord_scale, component_logits, component_scale, pis, coeffs = ctx.saved_tensors + ( + z, + coord_scale, + component_logits, + component_scale, + pis, + coeffs, + ) = ctx.saved_tensors dim = coord_scale.size(0) g = grad_output # l i g = g.unsqueeze(-2) # l 1 i @@ -124,31 +154,51 @@ def backward(ctx, grad_output): coord_scale_product = coord_scale.prod() component_scale_power = torch.pow(component_scale, float(dim)) - q_j = torch.exp(-0.5 * r_sqr_j) / math.pow(2.0 * math.pi, 0.5 * float(dim)) # l j + q_j = torch.exp(-0.5 * r_sqr_j) / math.pow( + 2.0 * math.pi, 0.5 * float(dim) + ) # l j q_j /= coord_scale_product * component_scale_power # l j q_tot = (pis * q_j).sum(-1, keepdim=True) # l Phi_j = torch.exp(-0.5 * r_sqr_j) # l j - exponents = - torch.arange(1., int(dim/2) + 1., 1.) 
+ exponents = -torch.arange(1.0, int(dim / 2) + 1.0, 1.0) if z.dim() > 1: - r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim/2)) # l j d/2 + r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim / 2)) # l j d/2 else: - r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, int(dim/2)) # l j d/2 + r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, int(dim / 2)) # l j d/2 r_j_poly = coeffs * torch.pow(r_j_poly, exponents) Phi_j *= r_j_poly.sum(-1) if dim % 2 == 1: root_two = math.sqrt(2.0) - extra_term = coeffs[-1] * math.sqrt(0.5 * math.pi) * (1.0 - torch.erf(r_sqr_j.sqrt() / root_two)) # l j + extra_term = ( + coeffs[-1] + * math.sqrt(0.5 * math.pi) + * (1.0 - torch.erf(r_sqr_j.sqrt() / root_two)) + ) # l j Phi_j += extra_term * torch.pow(r_sqr_j, -0.5 * float(dim)) logits_grad = (z.unsqueeze(-2) * Phi_j.unsqueeze(-1) * g).sum(-1) # l j logits_grad /= q_tot - logits_grad = sum_leftmost(logits_grad, -1) * math.pow(2.0 * math.pi, -0.5 * float(dim)) + logits_grad = sum_leftmost(logits_grad, -1) * math.pow( + 2.0 * math.pi, -0.5 * float(dim) + ) logits_grad = pis * logits_grad / (component_scale_power * coord_scale_product) logits_grad = logits_grad - logits_grad.sum() * pis - prefactor = pis.unsqueeze(-1) * q_j.unsqueeze(-1) * g / q_tot.unsqueeze(-1) # l j i + prefactor = ( + pis.unsqueeze(-1) * q_j.unsqueeze(-1) * g / q_tot.unsqueeze(-1) + ) # l j i coord_scale_grad = sum_leftmost(prefactor * epsilons.unsqueeze(-2), -1) - component_scale_grad = sum_leftmost((prefactor * z.unsqueeze(-2)).sum(-1) / component_scale, -1) - - return coord_scale_grad, logits_grad, component_scale_grad, None, None, None, None + component_scale_grad = sum_leftmost( + (prefactor * z.unsqueeze(-2)).sum(-1) / component_scale, -1 + ) + + return ( + coord_scale_grad, + logits_grad, + component_scale_grad, + None, + None, + None, + None, + ) diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 0cb663b4f9..28e065429a 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -119,9 +119,13 @@ def _sequential_gaussian_filter_sample(init, trans, sample_shape): assert _is_subshape(trans.batch_shape[:-1], init.batch_shape) state_dim = trans.dim() // 2 device = trans.precision.device - perm = torch.cat([torch.arange(1 * state_dim, 2 * state_dim, device=device), - torch.arange(0 * state_dim, 1 * state_dim, device=device), - torch.arange(2 * state_dim, 3 * state_dim, device=device)]) + perm = torch.cat( + [ + torch.arange(1 * state_dim, 2 * state_dim, device=device), + torch.arange(0 * state_dim, 1 * state_dim, device=device), + torch.arange(2 * state_dim, 3 * state_dim, device=device), + ] + ) # Forward filter, similar to _sequential_gaussian_tensordot(). tape = [] @@ -157,10 +161,13 @@ def _sequential_gaussian_filter_sample(init, trans, sample_shape): cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2, z2z4] sample = joint.condition(cond).rsample() # [z1, z3] sample = torch.nn.functional.pad(sample, (0, 0, 0, 1)) # [z1, z3, 0] - result = torch.stack([ - result, # [z0, z2, z4] - sample, # [z1, z3, 0] - ], dim=-2) # [[z0, z1], [z2, z3], [z4, 0]] + result = torch.stack( + [ + result, # [z0, z2, z4] + sample, # [z1, z3, 0] + ], + dim=-2, + ) # [[z0, z1], [z2, z3], [z4, 0]] result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3, z4, 0] result = result[..., :-1, :] # [z0, z1, z2, z3, z4] else: # ODD case. 
@@ -171,10 +178,13 @@ def _sequential_gaussian_filter_sample(init, trans, sample_shape): cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2] sample = joint.condition(cond).rsample() # [z1] sample = torch.cat([sample, result[..., -1:, :]], dim=-2) # [z1, z3] - result = torch.stack([ - result[..., :-1, :], # [z0, z2] - sample, # [z1, z3] - ], dim=-2) # [[z0, z1], [z2, z3]] + result = torch.stack( + [ + result[..., :-1, :], # [z0, z2] + sample, # [z1, z3] + ], + dim=-2, + ) # [[z0, z1], [z2, z3]] result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3] return result[..., 1:, :] # [z1, z2, z3, ...] @@ -198,7 +208,9 @@ def _sequential_gamma_gaussian_tensordot(gamma_gaussian): x, y = x_y[..., 0], x_y[..., 1] contracted = gamma_gaussian_tensordot(x, y, state_dim) if time > even_time: - contracted = GammaGaussian.cat((contracted, gamma_gaussian[..., -1:]), dim=-1) + contracted = GammaGaussian.cat( + (contracted, gamma_gaussian[..., -1:]), dim=-1 + ) gamma_gaussian = contracted return gamma_gaussian[..., 0] @@ -213,6 +225,7 @@ class HiddenMarkovModel(TorchDistribution): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ + def __init__(self, duration, batch_shape, event_shape, validate_args=None): if duration is None: if event_shape[0] != 1: @@ -220,8 +233,11 @@ def __init__(self, duration, batch_shape, event_shape, validate_args=None): duration = event_shape[0] elif duration != event_shape[0]: if event_shape[0] != 1: - raise ValueError("duration, event_shape mismatch: {} vs {}" - .format(duration, event_shape)) + raise ValueError( + "duration, event_shape mismatch: {} vs {}".format( + duration, event_shape + ) + ) # Infer event_shape from duration. event_shape = torch.Size((duration,) + event_shape[1:]) self._duration = duration @@ -295,29 +311,50 @@ class DiscreteHMM(HiddenMarkovModel): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. 
""" - arg_constraints = {"initial_logits": constraints.real, - "transition_logits": constraints.real} - def __init__(self, initial_logits, transition_logits, observation_dist, - validate_args=None, duration=None): + arg_constraints = { + "initial_logits": constraints.real, + "transition_logits": constraints.real, + } + + def __init__( + self, + initial_logits, + transition_logits, + observation_dist, + validate_args=None, + duration=None, + ): if initial_logits.dim() < 1: - raise ValueError("expected initial_logits to have at least one dim, " - "actual shape = {}".format(initial_logits.shape)) + raise ValueError( + "expected initial_logits to have at least one dim, " + "actual shape = {}".format(initial_logits.shape) + ) if transition_logits.dim() < 2: - raise ValueError("expected transition_logits to have at least two dims, " - "actual shape = {}".format(transition_logits.shape)) + raise ValueError( + "expected transition_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape) + ) if len(observation_dist.batch_shape) < 1: - raise ValueError("expected observation_dist to have at least one batch dim, " - "actual .batch_shape = {}".format(observation_dist.batch_shape)) - shape = broadcast_shape(initial_logits.shape[:-1] + (1,), - transition_logits.shape[:-2], - observation_dist.batch_shape[:-1]) + raise ValueError( + "expected observation_dist to have at least one batch dim, " + "actual .batch_shape = {}".format(observation_dist.batch_shape) + ) + shape = broadcast_shape( + initial_logits.shape[:-1] + (1,), + transition_logits.shape[:-2], + observation_dist.batch_shape[:-1], + ) batch_shape, time_shape = shape[:-1], shape[-1:] event_shape = time_shape + observation_dist.event_shape self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True) - self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True) + self.transition_logits = transition_logits - transition_logits.logsumexp( + -1, True + ) self.observation_dist = observation_dist - super().__init__(duration, batch_shape, event_shape, validate_args=validate_args) + super().__init__( + duration, batch_shape, event_shape, validate_args=validate_args + ) @constraints.dependent_property(event_dim=2) def support(self): @@ -333,8 +370,10 @@ def expand(self, batch_shape, _instance=None): new.initial_logits = self.initial_logits.expand(batch_shape + (-1,)) new.transition_logits = self.transition_logits new.observation_dist = self.observation_dist - super(DiscreteHMM, new).__init__(self.duration, batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(DiscreteHMM, new).__init__( + self.duration, batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def log_prob(self, value): @@ -444,36 +483,52 @@ class GaussianHMM(HiddenMarkovModel): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. 
""" + has_rsample = True arg_constraints = {} support = constraints.independent(constraints.real, 2) - def __init__(self, initial_dist, transition_matrix, transition_dist, - observation_matrix, observation_dist, validate_args=None, duration=None): - assert (isinstance(initial_dist, torch.distributions.MultivariateNormal) or - (isinstance(initial_dist, torch.distributions.Independent) and - isinstance(initial_dist.base_dist, torch.distributions.Normal))) + def __init__( + self, + initial_dist, + transition_matrix, + transition_dist, + observation_matrix, + observation_dist, + validate_args=None, + duration=None, + ): + assert isinstance(initial_dist, torch.distributions.MultivariateNormal) or ( + isinstance(initial_dist, torch.distributions.Independent) + and isinstance(initial_dist.base_dist, torch.distributions.Normal) + ) assert isinstance(transition_matrix, torch.Tensor) - assert (isinstance(transition_dist, torch.distributions.MultivariateNormal) or - (isinstance(transition_dist, torch.distributions.Independent) and - isinstance(transition_dist.base_dist, torch.distributions.Normal))) + assert isinstance(transition_dist, torch.distributions.MultivariateNormal) or ( + isinstance(transition_dist, torch.distributions.Independent) + and isinstance(transition_dist.base_dist, torch.distributions.Normal) + ) assert isinstance(observation_matrix, torch.Tensor) - assert (isinstance(observation_dist, torch.distributions.MultivariateNormal) or - (isinstance(observation_dist, torch.distributions.Independent) and - isinstance(observation_dist.base_dist, torch.distributions.Normal))) + assert isinstance(observation_dist, torch.distributions.MultivariateNormal) or ( + isinstance(observation_dist, torch.distributions.Independent) + and isinstance(observation_dist.base_dist, torch.distributions.Normal) + ) hidden_dim, obs_dim = observation_matrix.shape[-2:] assert initial_dist.event_shape == (hidden_dim,) assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim) assert transition_dist.event_shape == (hidden_dim,) assert observation_dist.event_shape == (obs_dim,) - shape = broadcast_shape(initial_dist.batch_shape + (1,), - transition_matrix.shape[:-2], - transition_dist.batch_shape, - observation_matrix.shape[:-2], - observation_dist.batch_shape) + shape = broadcast_shape( + initial_dist.batch_shape + (1,), + transition_matrix.shape[:-2], + transition_dist.batch_shape, + observation_matrix.shape[:-2], + observation_dist.batch_shape, + ) batch_shape, time_shape = shape[:-1], shape[-1:] event_shape = time_shape + (obs_dim,) - super().__init__(duration, batch_shape, event_shape, validate_args=validate_args) + super().__init__( + duration, batch_shape, event_shape, validate_args=validate_args + ) self.hidden_dim = hidden_dim self.obs_dim = obs_dim @@ -494,8 +549,10 @@ def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape)) new._init = self._init.expand(batch_shape) - super(GaussianHMM, new).__init__(self.duration, batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(GaussianHMM, new).__init__( + self.duration, batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def log_prob(self, value): @@ -503,7 +560,9 @@ def log_prob(self, value): self._validate_sample(value) # Combine observation and transition factors. 
- result = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) + result = self._trans + self._obs.condition(value).event_pad( + left=self.hidden_dim + ) # Eliminate time dimension. result = _sequential_gaussian_tensordot(result.expand(result.batch_shape)) @@ -518,7 +577,9 @@ def log_prob(self, value): def rsample(self, sample_shape=torch.Size()): assert self.duration is not None sample_shape = torch.Size(sample_shape) - trans = self._trans + self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim) + trans = self._trans + self._obs.marginalize(right=self.obs_dim).event_pad( + left=self.hidden_dim + ) trans = trans.expand(trans.batch_shape[:-1] + (self.duration,)) z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape) x = self._obs.left_condition(z).rsample() @@ -558,9 +619,12 @@ def filter(self, value): # Convert to a distribution precision = logp.precision - loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze(-1) - return MultivariateNormal(loc, precision_matrix=precision, - validate_args=self._validate_args) + loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze( + -1 + ) + return MultivariateNormal( + loc, precision_matrix=precision, validate_args=self._validate_args + ) def conjugate_update(self, other): """ @@ -580,9 +644,10 @@ def conjugate_update(self, other): updated :class:`GaussianHMM` , and ``log_normalizer`` is a :class:`~torch.Tensor` representing the normalization factor. """ - assert (isinstance(other, torch.distributions.Independent) and - (isinstance(other.base_dist, torch.distributions.Normal) or - isinstance(other.base_dist, torch.distributions.MultivariateNormal))) + assert isinstance(other, torch.distributions.Independent) and ( + isinstance(other.base_dist, torch.distributions.Normal) + or isinstance(other.base_dist, torch.distributions.MultivariateNormal) + ) duration = other.event_shape[0] if self.duration is None else self.duration event_shape = torch.Size((duration, self.obs_dim)) assert other.event_shape == event_shape @@ -592,19 +657,25 @@ def conjugate_update(self, other): new.obs_dim = self.obs_dim new._init = self._init new._trans = self._trans - new._obs = self._obs + mvn_to_gaussian(other.to_event(-1)).event_pad(left=self.hidden_dim) + new._obs = self._obs + mvn_to_gaussian(other.to_event(-1)).event_pad( + left=self.hidden_dim + ) # Normalize. # TODO cache this computation for the forward pass of .rsample(). 
- logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad(left=new.hidden_dim) + logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad( + left=new.hidden_dim + ) logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape)) logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim) log_normalizer = logp.event_logsumexp() new._init = new._init - log_normalizer batch_shape = log_normalizer.shape - super(GaussianHMM, new).__init__(duration, batch_shape, event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(GaussianHMM, new).__init__( + duration, batch_shape, event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new, log_normalizer def prefix_condition(self, data): @@ -640,20 +711,24 @@ def prefix_condition(self, data): left._obs = self._obs[..., :t] right._obs = self._obs[..., t:] - if self._trans.batch_shape == () or self._trans.batch_shape[-1] == 1: # homogeneous + if ( + self._trans.batch_shape == () or self._trans.batch_shape[-1] == 1 + ): # homogeneous left._trans = self._trans right._trans = self._trans else: # heterogeneous left._trans = self._trans[..., :t] right._trans = self._trans[..., t:] - super(GaussianHMM, left).__init__(t, self.batch_shape, (t, self.obs_dim), - validate_args=self._validate_args) + super(GaussianHMM, left).__init__( + t, self.batch_shape, (t, self.obs_dim), validate_args=self._validate_args + ) initial_dist = left.filter(data) right._init = mvn_to_gaussian(initial_dist) batch_shape = broadcast_shape(right._init.batch_shape, self.batch_shape) - super(GaussianHMM, right).__init__(f, batch_shape, (f, self.obs_dim), - validate_args=self._validate_args) + super(GaussianHMM, right).__init__( + f, batch_shape, (f, self.obs_dim), validate_args=self._validate_args + ) return right @@ -726,11 +801,21 @@ class GammaGaussianHMM(HiddenMarkovModel): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. 
""" + arg_constraints = {} support = constraints.independent(constraints.real, 2) - def __init__(self, scale_dist, initial_dist, transition_matrix, transition_dist, - observation_matrix, observation_dist, validate_args=None, duration=None): + def __init__( + self, + scale_dist, + initial_dist, + transition_matrix, + transition_dist, + observation_matrix, + observation_dist, + validate_args=None, + duration=None, + ): assert isinstance(scale_dist, Gamma) assert isinstance(initial_dist, MultivariateNormal) assert isinstance(transition_matrix, torch.Tensor) @@ -742,20 +827,28 @@ def __init__(self, scale_dist, initial_dist, transition_matrix, transition_dist, assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim) assert transition_dist.event_shape == (hidden_dim,) assert observation_dist.event_shape == (obs_dim,) - shape = broadcast_shape(scale_dist.batch_shape + (1,), - initial_dist.batch_shape + (1,), - transition_matrix.shape[:-2], - transition_dist.batch_shape, - observation_matrix.shape[:-2], - observation_dist.batch_shape) + shape = broadcast_shape( + scale_dist.batch_shape + (1,), + initial_dist.batch_shape + (1,), + transition_matrix.shape[:-2], + transition_dist.batch_shape, + observation_matrix.shape[:-2], + observation_dist.batch_shape, + ) batch_shape, time_shape = shape[:-1], shape[-1:] event_shape = time_shape + (obs_dim,) - super().__init__(duration, batch_shape, event_shape, validate_args=validate_args) + super().__init__( + duration, batch_shape, event_shape, validate_args=validate_args + ) self.hidden_dim = hidden_dim self.obs_dim = obs_dim self._init = gamma_and_mvn_to_gamma_gaussian(scale_dist, initial_dist) - self._trans = matrix_and_mvn_to_gamma_gaussian(transition_matrix, transition_dist) - self._obs = matrix_and_mvn_to_gamma_gaussian(observation_matrix, observation_dist) + self._trans = matrix_and_mvn_to_gamma_gaussian( + transition_matrix, transition_dist + ) + self._obs = matrix_and_mvn_to_gamma_gaussian( + observation_matrix, observation_dist + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(GammaGaussianHMM, _instance) @@ -768,8 +861,10 @@ def expand(self, batch_shape, _instance=None): new._init = self._init.expand(batch_shape) new._trans = self._trans new._obs = self._obs - super(GammaGaussianHMM, new).__init__(self.duration, batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(GammaGaussianHMM, new).__init__( + self.duration, batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def log_prob(self, value): @@ -777,7 +872,9 @@ def log_prob(self, value): self._validate_sample(value) # Combine observation and transition factors. - result = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim) + result = self._trans + self._obs.condition(value).event_pad( + left=self.hidden_dim + ) # Eliminate time dimension. 
result = _sequential_gamma_gaussian_tensordot(result.expand(result.batch_shape)) @@ -818,13 +915,15 @@ def filter(self, value): # Posterior of the scale gamma_dist = logp.event_logsumexp() - scale_post = Gamma(gamma_dist.concentration, gamma_dist.rate, - validate_args=self._validate_args) + scale_post = Gamma( + gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args + ) # Conditional of last state on unit scale scale_tril = cholesky(logp.precision) loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1) - mvn = MultivariateNormal(loc, scale_tril=scale_tril, - validate_args=self._validate_args) + mvn = MultivariateNormal( + loc, scale_tril=scale_tril, validate_args=self._validate_args + ) return scale_post, mvn @@ -899,17 +998,27 @@ class LinearHMM(HiddenMarkovModel): support = constraints.independent(constraints.real, 2) has_rsample = True - def __init__(self, initial_dist, transition_matrix, transition_dist, - observation_matrix, observation_dist, - validate_args=None, duration=None): + def __init__( + self, + initial_dist, + transition_matrix, + transition_dist, + observation_matrix, + observation_dist, + validate_args=None, + duration=None, + ): assert initial_dist.has_rsample assert initial_dist.event_dim == 1 - assert (isinstance(transition_matrix, torch.Tensor) and - transition_matrix.dim() >= 2) + assert ( + isinstance(transition_matrix, torch.Tensor) and transition_matrix.dim() >= 2 + ) assert transition_dist.has_rsample assert transition_dist.event_dim == 1 - assert (isinstance(observation_matrix, torch.Tensor) and - observation_matrix.dim() >= 2) + assert ( + isinstance(observation_matrix, torch.Tensor) + and observation_matrix.dim() >= 2 + ) assert observation_dist.has_rsample assert observation_dist.event_dim == 1 @@ -918,26 +1027,32 @@ def __init__(self, initial_dist, transition_matrix, transition_dist, assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim) assert transition_dist.event_shape == (hidden_dim,) assert observation_dist.event_shape == (obs_dim,) - shape = broadcast_shape(initial_dist.batch_shape + (1,), - transition_matrix.shape[:-2], - transition_dist.batch_shape, - observation_matrix.shape[:-2], - observation_dist.batch_shape) + shape = broadcast_shape( + initial_dist.batch_shape + (1,), + transition_matrix.shape[:-2], + transition_dist.batch_shape, + observation_matrix.shape[:-2], + observation_dist.batch_shape, + ) batch_shape, time_shape = shape[:-1], shape[-1:] event_shape = time_shape + (obs_dim,) - super().__init__(duration, batch_shape, event_shape, validate_args=validate_args) + super().__init__( + duration, batch_shape, event_shape, validate_args=validate_args + ) # Expand eagerly. 
if initial_dist.batch_shape != batch_shape: initial_dist = initial_dist.expand(batch_shape) if transition_matrix.shape[:-2] != batch_shape + time_shape: transition_matrix = transition_matrix.expand( - batch_shape + time_shape + (hidden_dim, hidden_dim)) + batch_shape + time_shape + (hidden_dim, hidden_dim) + ) if transition_dist.batch_shape != batch_shape + time_shape: transition_dist = transition_dist.expand(batch_shape + time_shape) if observation_matrix.shape[:-2] != batch_shape + time_shape: observation_matrix = observation_matrix.expand( - batch_shape + time_shape + (hidden_dim, obs_dim)) + batch_shape + time_shape + (hidden_dim, obs_dim) + ) if observation_dist.batch_shape != batch_shape + time_shape: observation_dist = observation_dist.expand(batch_shape + time_shape) @@ -946,7 +1061,9 @@ def __init__(self, initial_dist, transition_matrix, transition_dist, while True: if isinstance(observation_dist, torch.distributions.Independent): observation_dist = observation_dist.base_dist - elif isinstance(observation_dist, torch.distributions.TransformedDistribution): + elif isinstance( + observation_dist, torch.distributions.TransformedDistribution + ): transforms = observation_dist.transforms + transforms observation_dist = observation_dist.base_dist else: @@ -975,14 +1092,18 @@ def expand(self, batch_shape, _instance=None): new.obs_dim = self.obs_dim new.initial_dist = self.initial_dist.expand(batch_shape) new.transition_matrix = self.transition_matrix.expand( - batch_shape + time_shape + (self.hidden_dim, self.hidden_dim)) + batch_shape + time_shape + (self.hidden_dim, self.hidden_dim) + ) new.transition_dist = self.transition_dist.expand(batch_shape + time_shape) new.observation_matrix = self.observation_matrix.expand( - batch_shape + time_shape + (self.hidden_dim, self.obs_dim)) + batch_shape + time_shape + (self.hidden_dim, self.obs_dim) + ) new.observation_dist = self.observation_dist.expand(batch_shape + time_shape) new.transforms = self.transforms - super(LinearHMM, new).__init__(self.duration, batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(LinearHMM, new).__init__( + self.duration, batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def log_prob(self, value): @@ -991,9 +1112,15 @@ def log_prob(self, value): def rsample(self, sample_shape=torch.Size()): assert self.duration is not None init = self.initial_dist.rsample(sample_shape) - trans = self.transition_dist.expand(self.batch_shape + (self.duration,)).rsample(sample_shape) - obs = self.observation_dist.expand(self.batch_shape + (self.duration,)).rsample(sample_shape) - trans_matrix = self.transition_matrix.expand(self.batch_shape + (self.duration, -1, -1)) + trans = self.transition_dist.expand( + self.batch_shape + (self.duration,) + ).rsample(sample_shape) + obs = self.observation_dist.expand(self.batch_shape + (self.duration,)).rsample( + sample_shape + ) + trans_matrix = self.transition_matrix.expand( + self.batch_shape + (self.duration, -1, -1) + ) z = _linear_integrate(init, trans_matrix, trans) x = (z.unsqueeze(-2) @ self.observation_matrix).squeeze(-2) + obs for t in self.transforms: @@ -1016,6 +1143,7 @@ class IndependentHMM(TorchDistribution): :param HiddenMarkovModel base_dist: A base hidden Markov model instance. 
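# Usage sketch (illustrative only, not taken from the Pyro sources): wrapping a
# batch of univariate GaussianHMMs in IndependentHMM, which is expected to fold
# the rightmost batch dimension of the base HMM into the event shape. Shapes and
# parameter values below are hypothetical.
import torch
import pyro.distributions as dist

num_steps, num_series = 10, 4
base = dist.GaussianHMM(
    initial_dist=dist.Normal(torch.zeros(num_series, 1), 1.0).to_event(1),
    transition_matrix=0.9 * torch.eye(1),
    transition_dist=dist.Normal(torch.zeros(1), 0.1).to_event(1),
    observation_matrix=torch.eye(1),
    observation_dist=dist.Normal(torch.zeros(1), 0.5).to_event(1),
    duration=num_steps,
)  # base: batch_shape (num_series,), event_shape (num_steps, 1)
hmm = dist.IndependentHMM(base)
x = hmm.rsample()  # expected shape (num_steps, num_series)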
""" + arg_constraints = {} def __init__(self, base_dist): @@ -1042,9 +1170,13 @@ def duration(self): def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new = self._get_checked_instance(IndependentHMM, _instance) - new.base_dist = self.base_dist.expand(batch_shape + self.base_dist.batch_shape[-1:]) - super(IndependentHMM, new).__init__(batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + new.base_dist = self.base_dist.expand( + batch_shape + self.base_dist.batch_shape[-1:] + ) + super(IndependentHMM, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def rsample(self, sample_shape=torch.Size()): @@ -1096,9 +1228,12 @@ class GaussianMRF(TorchDistribution): have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``. This should have event_shape ``(hidden_dim + obs_dim,)``. """ + arg_constraints = {} - def __init__(self, initial_dist, transition_dist, observation_dist, validate_args=None): + def __init__( + self, initial_dist, transition_dist, observation_dist, validate_args=None + ): assert isinstance(initial_dist, torch.distributions.MultivariateNormal) assert isinstance(transition_dist, torch.distributions.MultivariateNormal) assert isinstance(observation_dist, torch.distributions.MultivariateNormal) @@ -1106,9 +1241,11 @@ def __init__(self, initial_dist, transition_dist, observation_dist, validate_arg assert transition_dist.event_shape[0] == hidden_dim + hidden_dim obs_dim = observation_dist.event_shape[0] - hidden_dim - shape = broadcast_shape(initial_dist.batch_shape + (1,), - transition_dist.batch_shape, - observation_dist.batch_shape) + shape = broadcast_shape( + initial_dist.batch_shape + (1,), + transition_dist.batch_shape, + observation_dist.batch_shape, + ) batch_shape, time_shape = shape[:-1], shape[-1:] event_shape = time_shape + (obs_dim,) super().__init__(batch_shape, event_shape, validate_args=validate_args) @@ -1135,8 +1272,10 @@ def expand(self, batch_shape, _instance=None): new._trans = self._trans new._obs = self._obs new._support = self._support - super(GaussianMRF, new).__init__(batch_shape, self.event_shape, validate_args=False) - new._validate_args = self.__dict__.get('_validate_args') + super(GaussianMRF, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self.__dict__.get("_validate_args") return new def log_prob(self, value): @@ -1146,13 +1285,16 @@ def log_prob(self, value): # Combine observation and transition factors. logp_oh += self._obs.condition(value).event_pad(left=self.hidden_dim) - logp_h += self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim) + logp_h += self._obs.marginalize(right=self.obs_dim).event_pad( + left=self.hidden_dim + ) # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian. batch_dim = 1 + max(len(self._init.batch_shape) + 1, len(logp_oh.batch_shape)) - batch_shape = (1,) * (batch_dim - len(logp_oh.batch_shape)) + logp_oh.batch_shape - logp = Gaussian.cat([logp_oh.expand(batch_shape), - logp_h.expand(batch_shape)]) + batch_shape = (1,) * ( + batch_dim - len(logp_oh.batch_shape) + ) + logp_oh.batch_shape + logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)]) # Eliminate time dimension. 
logp = _sequential_gaussian_tensordot(logp) diff --git a/pyro/distributions/improper_uniform.py b/pyro/distributions/improper_uniform.py index 3c4ad0c5dc..10071c21a9 100644 --- a/pyro/distributions/improper_uniform.py +++ b/pyro/distributions/improper_uniform.py @@ -40,6 +40,7 @@ class ImproperUniform(TorchDistribution): :param torch.Size batch_shape: The batch shape. :param torch.Size event_shape: The event shape. """ + arg_constraints = {} def __init__(self, support, batch_shape, event_shape): @@ -59,7 +60,7 @@ def expand(self, batch_shape, _instance=None): return new def log_prob(self, value): - batch_shape = value.shape[:value.dim() - self.event_dim] + batch_shape = value.shape[: value.dim() - self.event_dim] batch_shape = broadcast_shape(batch_shape, self.batch_shape) return torch.zeros(()).expand(batch_shape) diff --git a/pyro/distributions/inverse_gamma.py b/pyro/distributions/inverse_gamma.py index 24dba5fa86..6235aa3000 100644 --- a/pyro/distributions/inverse_gamma.py +++ b/pyro/distributions/inverse_gamma.py @@ -18,14 +18,20 @@ class InverseGamma(TransformedDistribution): :param torch.Tensor concentration: the concentration parameter (i.e. alpha). :param torch.Tensor rate: the rate parameter (i.e. beta). """ - arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive} + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } support = constraints.positive has_rsample = True def __init__(self, concentration, rate, validate_args=None): base_dist = Gamma(concentration, rate) - super().__init__(base_dist, PowerTransform(-base_dist.rate.new_ones(())), - validate_args=validate_args) + super().__init__( + base_dist, + PowerTransform(-base_dist.rate.new_ones(())), + validate_args=validate_args, + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(InverseGamma, _instance) diff --git a/pyro/distributions/kl.py b/pyro/distributions/kl.py index 6fcef4a1a5..2ee9f66ebf 100644 --- a/pyro/distributions/kl.py +++ b/pyro/distributions/kl.py @@ -43,10 +43,12 @@ def _kl_independent_mvn(p, q): dim = q.event_shape[0] p_cov = p.base_dist.scale ** 2 q_precision = q.precision_matrix.diagonal(dim1=-2, dim2=-1) - return (0.5 * (p_cov * q_precision).sum(-1) - - 0.5 * dim * (1 + math.log(2 * math.pi)) - - q.log_prob(p.base_dist.loc) - - p.base_dist.scale.log().sum(-1)) + return ( + 0.5 * (p_cov * q_precision).sum(-1) + - 0.5 * dim * (1 + math.log(2 * math.pi)) + - q.log_prob(p.base_dist.loc) + - p.base_dist.scale.log().sum(-1) + ) raise NotImplementedError diff --git a/pyro/distributions/lkj.py b/pyro/distributions/lkj.py index f6cd9fe100..720c1eb0cf 100644 --- a/pyro/distributions/lkj.py +++ b/pyro/distributions/lkj.py @@ -14,8 +14,8 @@ class LKJCorrCholesky(LKJCholesky): # DEPRECATED def __init__(self, d, eta, validate_args=None): warnings.warn( - 'class LKJCorrCholesky(d, eta, validate_args=None) is deprecated ' - 'in favor of LKJCholesky(dim, concentration, validate_args=None).', + "class LKJCorrCholesky(d, eta, validate_args=None) is deprecated " + "in favor of LKJCholesky(dim, concentration, validate_args=None).", FutureWarning, ) super().__init__(d, concentration=eta, validate_args=validate_args) @@ -42,14 +42,15 @@ class LKJ(TransformedDistribution): [1] `Generating random correlation matrices based on vines and extended onion method`, Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ - arg_constraints = {'concentration': constraints.positive} + arg_constraints = {"concentration": 
constraints.positive} support = constraints.corr_matrix - def __init__(self, dim, concentration=1., validate_args=None): + def __init__(self, dim, concentration=1.0, validate_args=None): base_dist = LKJCholesky(dim, concentration) self.dim, self.concentration = base_dist.dim, base_dist.concentration - super(LKJ, self).__init__(base_dist, CorrMatrixCholeskyTransform().inv, - validate_args=validate_args) + super(LKJ, self).__init__( + base_dist, CorrMatrixCholeskyTransform().inv, validate_args=validate_args + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(LKJCholesky, _instance) diff --git a/pyro/distributions/logistic.py b/pyro/distributions/logistic.py index ffdd0c26f2..b98ea10139 100644 --- a/pyro/distributions/logistic.py +++ b/pyro/distributions/logistic.py @@ -40,12 +40,15 @@ class SkewLogistic(TorchDistribution): distribution. """ - arg_constraints = {"loc": constraints.real, "scale": constraints.positive, - "asymmetry": constraints.positive} + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "asymmetry": constraints.positive, + } support = constraints.real has_rsample = True - def __init__(self, loc, scale, asymmetry=1., *, validate_args=None): + def __init__(self, loc, scale, asymmetry=1.0, *, validate_args=None): self.loc, self.scale, self.asymmetry = broadcast_all(loc, scale, asymmetry) super().__init__(self.loc.shape, validate_args=validate_args) diff --git a/pyro/distributions/mixture.py b/pyro/distributions/mixture.py index 6eb8378c08..9f9ed501f1 100644 --- a/pyro/distributions/mixture.py +++ b/pyro/distributions/mixture.py @@ -19,6 +19,7 @@ class MaskedConstraint(constraints.Constraint): :param torch.constraints.Constraint constraint1: constraint that holds wherever ``mask == 1`` """ + def __init__(self, mask, constraint0, constraint1): self.mask = mask self.constraint0 = constraint0 @@ -26,7 +27,11 @@ def __init__(self, mask, constraint0, constraint1): def check(self, value): result = self.constraint0.check(value) - mask = self.mask.expand(result.shape) if result.shape != self.mask.shape else self.mask + mask = ( + self.mask.expand(result.shape) + if result.shape != self.mask.shape + else self.mask + ) result[mask] = self.constraint1.check(value)[mask] return result @@ -55,15 +60,23 @@ class MaskedMixture(TorchDistribution): :param pyro.distributions.TorchDistribution component1: a distribution for batch elements ``mask == True``. 
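# Usage sketch (illustrative only, not taken from the Pyro sources): a
# per-element mixture that evaluates component0 where the mask is False and
# component1 where the mask is True.
import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, False, True])
mixture = dist.MaskedMixture(
    mask,
    dist.Normal(-2.0, 1.0),  # used where mask == False
    dist.Normal(2.0, 1.0),   # used where mask == True
)
x = mixture.sample()         # shape (4,): each element drawn from its own component
log_p = mixture.log_prob(x)  # shape (4,)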
""" + arg_constraints = {} # nothing can be constrained def __init__(self, mask, component0, component1, validate_args=None): if not torch.is_tensor(mask) or mask.dtype != torch.bool: - raise ValueError('Expected mask to be a BoolTensor but got {}'.format(type(mask))) + raise ValueError( + "Expected mask to be a BoolTensor but got {}".format(type(mask)) + ) if component0.event_shape != component1.event_shape: - raise ValueError('components event_shape disagree: {} vs {}' - .format(component0.event_shape, component1.event_shape)) - batch_shape = broadcast_shape(mask.shape, component0.batch_shape, component1.batch_shape) + raise ValueError( + "components event_shape disagree: {} vs {}".format( + component0.event_shape, component1.event_shape + ) + ) + batch_shape = broadcast_shape( + mask.shape, component0.batch_shape, component1.batch_shape + ) if mask.shape != batch_shape: mask = mask.expand(batch_shape) if component0.batch_shape != batch_shape: @@ -89,7 +102,9 @@ def has_rsample(self): def support(self): if self.component0.support is self.component1.support: return self.component0.support - return MaskedConstraint(self.mask, self.component0.support, self.component1.support) + return MaskedConstraint( + self.mask, self.component0.support, self.component1.support + ) def expand(self, batch_shape): try: @@ -103,17 +118,21 @@ def expand(self, batch_shape): def sample(self, sample_shape=torch.Size()): mask = self.mask.reshape(self.mask.shape + (1,) * self.event_dim) mask = mask.expand(sample_shape + self.shape()) - result = torch.where(mask, - self.component1.sample(sample_shape), - self.component0.sample(sample_shape)) + result = torch.where( + mask, + self.component1.sample(sample_shape), + self.component0.sample(sample_shape), + ) return result def rsample(self, sample_shape=torch.Size()): mask = self.mask.reshape(self.mask.shape + (1,) * self.event_dim) mask = mask.expand(sample_shape + self.shape()) - result = torch.where(mask, - self.component1.rsample(sample_shape), - self.component0.rsample(sample_shape)) + result = torch.where( + mask, + self.component1.rsample(sample_shape), + self.component0.rsample(sample_shape), + ) return result def log_prob(self, value): @@ -122,13 +141,13 @@ def log_prob(self, value): value = value.expand(value_shape) if self._validate_args: self._validate_sample(value) - mask_shape = value_shape[:len(value_shape) - len(self.event_shape)] + mask_shape = value_shape[: len(value_shape) - len(self.event_shape)] mask = self.mask if mask.shape != mask_shape: mask = mask.expand(mask_shape) - result = torch.where(mask, - self.component1.log_prob(value), - self.component0.log_prob(value)) + result = torch.where( + mask, self.component1.log_prob(value), self.component0.log_prob(value) + ) return result @lazy_property diff --git a/pyro/distributions/multivariate_studentt.py b/pyro/distributions/multivariate_studentt.py index 8b6a2d77c4..895ff66182 100644 --- a/pyro/distributions/multivariate_studentt.py +++ b/pyro/distributions/multivariate_studentt.py @@ -22,9 +22,12 @@ class MultivariateStudentT(TorchDistribution): :param ~torch.Tensor scale_tril: scale of the distribution, which is a lower triangular matrix with positive diagonal entries """ - arg_constraints = {'df': constraints.positive, - 'loc': constraints.real_vector, - 'scale_tril': constraints.lower_cholesky} + + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real_vector, + "scale_tril": constraints.lower_cholesky, + } support = constraints.real_vector has_rsample = True @@ -44,21 +47,26 
@@ def __init__(self, df, loc, scale_tril, validate_args=None): @lazy_property def scale_tril(self): return self._unbroadcasted_scale_tril.expand( - self._batch_shape + self._event_shape + self._event_shape) + self._batch_shape + self._event_shape + self._event_shape + ) @lazy_property def covariance_matrix(self): # NB: this is not covariance of this distribution; # the actual covariance is df / (df - 2) * covariance_matrix - return (torch.matmul(self._unbroadcasted_scale_tril, - self._unbroadcasted_scale_tril.transpose(-1, -2)) - .expand(self._batch_shape + self._event_shape + self._event_shape)) + return torch.matmul( + self._unbroadcasted_scale_tril, + self._unbroadcasted_scale_tril.transpose(-1, -2), + ).expand(self._batch_shape + self._event_shape + self._event_shape) @lazy_property def precision_matrix(self): - identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype) + identity = torch.eye( + self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype + ) return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( - self._batch_shape + self._event_shape + self._event_shape) + self._batch_shape + self._event_shape + self._event_shape + ) @staticmethod def infer_shapes(df, loc, scale_tril): @@ -74,14 +82,16 @@ def expand(self, batch_shape, _instance=None): new.df = self.df.expand(batch_shape) new.loc = self.loc.expand(loc_shape) new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril - if 'scale_tril' in self.__dict__: + if "scale_tril" in self.__dict__: new.scale_tril = self.scale_tril.expand(scale_shape) - if 'covariance_matrix' in self.__dict__: + if "covariance_matrix" in self.__dict__: new.covariance_matrix = self.covariance_matrix.expand(scale_shape) - if 'precision_matrix' in self.__dict__: + if "precision_matrix" in self.__dict__: new.precision_matrix = self.precision_matrix.expand(scale_shape) new._chi2 = self._chi2.expand(batch_shape) - super(MultivariateStudentT, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(MultivariateStudentT, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -96,23 +106,30 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) n = self.loc.size(-1) - y = (value - self.loc).unsqueeze(-1).triangular_solve(self.scale_tril, upper=False).solution.squeeze(-1) - Z = (self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + - 0.5 * n * self.df.log() + - 0.5 * n * math.log(math.pi) + - torch.lgamma(0.5 * self.df) - - torch.lgamma(0.5 * (self.df + n))) + y = ( + (value - self.loc) + .unsqueeze(-1) + .triangular_solve(self.scale_tril, upper=False) + .solution.squeeze(-1) + ) + Z = ( + self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + + 0.5 * n * self.df.log() + + 0.5 * n * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + n)) + ) return -0.5 * (self.df + n) * torch.log1p(y.pow(2).sum(-1) / self.df) - Z @property def mean(self): m = self.loc.clone() - m[self.df <= 1, :] = float('nan') + m[self.df <= 1, :] = float("nan") return m @property def variance(self): m = self.scale_tril.pow(2).sum(-1) * (self.df / (self.df - 2)).unsqueeze(-1) - m[(self.df <= 2) & (self.df > 1), :] = float('inf') - m[self.df <= 1, :] = float('nan') + m[(self.df <= 2) & (self.df > 1), :] = float("inf") + m[self.df <= 1, :] = float("nan") return m diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index a0a77e10fd..c87d52f561 100644 --- 
a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -21,7 +21,11 @@ class OMTMultivariateNormal(MultivariateNormal): :param torch.Tensor loc: Mean. :param torch.Tensor scale_tril: Cholesky of Covariance matrix. """ - arg_constraints = {"loc": constraints.real, "scale_tril": constraints.lower_triangular} + + arg_constraints = { + "loc": constraints.real, + "scale_tril": constraints.lower_triangular, + } def __init__(self, loc, scale_tril): if loc.dim() != 1: @@ -31,7 +35,9 @@ def __init__(self, loc, scale_tril): super().__init__(loc, scale_tril=scale_tril) def rsample(self, sample_shape=torch.Size()): - return _OMTMVNSample.apply(self.loc, self.scale_tril, sample_shape + self.loc.shape) + return _OMTMVNSample.apply( + self.loc, self.scale_tril, sample_shape + self.loc.shape + ) class _OMTMVNSample(Function): @@ -69,11 +75,15 @@ def backward(ctx, grad_output): z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple) g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple) - Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2) + Y = sum_leftmost( + torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2 + ) Y = torch.mm(V, torch.mm(Y, V.t())) Y = Y + Y.t() - Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv)) + Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm( + Y, torch.mm(Sigma_inv, R_inv) + ) diff_L_ab += 0.5 * Tr_xi_Y L_grad = torch.tril(diff_L_ab) diff --git a/pyro/distributions/one_one_matching.py b/pyro/distributions/one_one_matching.py index c8ba769054..2d808611e8 100644 --- a/pyro/distributions/one_one_matching.py +++ b/pyro/distributions/one_one_matching.py @@ -149,8 +149,9 @@ def sample(self, sample_shape=torch.Size()): return d[sample] if sample_shape: - return torch.stack([self.sample(sample_shape[1:]) - for _ in range(sample_shape[0])]) + return torch.stack( + [self.sample(sample_shape[1:]) for _ in range(sample_shape[0])] + ) # TODO initialize via .mode(), then perform a small number of MCMC steps # https://www.cc.gatech.edu/~vigoda/Permanent.pdf # https://papers.nips.cc/paper/2012/file/4c27cea8526af8cfee3be5e183ac9605-Paper.pdf diff --git a/pyro/distributions/one_two_matching.py b/pyro/distributions/one_two_matching.py index 9737cff659..f423769279 100644 --- a/pyro/distributions/one_two_matching.py +++ b/pyro/distributions/one_two_matching.py @@ -158,8 +158,9 @@ def sample(self, sample_shape=torch.Size()): return d[sample] if sample_shape: - return torch.stack([self.sample(sample_shape[1:]) - for _ in range(sample_shape[0])]) + return torch.stack( + [self.sample(sample_shape[1:]) for _ in range(sample_shape[0])] + ) # TODO initialize via .mode(), then perform a small number of MCMC steps # https://www.cc.gatech.edu/~vigoda/Permanent.pdf # https://papers.nips.cc/paper/2012/file/4c27cea8526af8cfee3be5e183ac9605-Paper.pdf @@ -182,21 +183,21 @@ def enumerate_one_two_matchings(num_destins): num_sources = num_destins * 2 subproblem = enumerate_one_two_matchings(num_destins - 1) subsize = subproblem.size(0) - result = torch.empty(subsize * num_sources * (num_sources - 1) // 2, - num_sources, - dtype=torch.long) + result = torch.empty( + subsize * num_sources * (num_sources - 1) // 2, num_sources, dtype=torch.long + ) # Iterate over pairs of sources s0 0] curr_b0 = b0[missing > 0] - x = torch.distributions.Normal(0., torch.sqrt(1 + 2 * curr_eig / curr_b0)).sample( - (missing[missing > 0].min(),)).view(2, -1, missing[missing > 0].min()) + x = 
( + torch.distributions.Normal(0.0, torch.sqrt(1 + 2 * curr_eig / curr_b0)) + .sample((missing[missing > 0].min(),)) + .view(2, -1, missing[missing > 0].min()) + ) x /= x.norm(dim=0)[None, ...] # Angular Central Gaussian distribution - lf = curr_conc[0] * (x[0] - 1) + eigmin[missing > 0] + log_I1(0, torch.sqrt( - curr_conc[1] ** 2 + (curr_corr * x[1]) ** 2)).squeeze(0) - phi_den[missing > 0] + lf = ( + curr_conc[0] * (x[0] - 1) + + eigmin[missing > 0] + + log_I1( + 0, torch.sqrt(curr_conc[1] ** 2 + (curr_corr * x[1]) ** 2) + ).squeeze(0) + - phi_den[missing > 0] + ) assert lf.shape == ((missing > 0).sum(), missing[missing > 0].min()) - lg_inv = 1. - curr_b0.view(-1, 1) / 2 + torch.log( - curr_b0.view(-1, 1) / 2 + (curr_eig.view(2, -1, 1) * x ** 2).sum(0)) + lg_inv = ( + 1.0 + - curr_b0.view(-1, 1) / 2 + + torch.log( + curr_b0.view(-1, 1) / 2 + (curr_eig.view(2, -1, 1) * x ** 2).sum(0) + ) + ) assert lg_inv.shape == lf.shape - accepted = torch.distributions.Uniform(0., torch.ones((), device=conc.device)).sample(lf.shape) < ( - lf + lg_inv).exp() - - phi_mask = torch.zeros((*missing.shape, total), dtype=torch.bool, device=conc.device) - phi_mask[missing > 0] = torch.logical_and(lengths < (start[missing > 0] + accepted.sum(-1)).view(-1, 1), - lengths >= start[missing > 0].view(-1, 1)) + accepted = ( + torch.distributions.Uniform( + 0.0, torch.ones((), device=conc.device) + ).sample(lf.shape) + < (lf + lg_inv).exp() + ) + + phi_mask = torch.zeros( + (*missing.shape, total), dtype=torch.bool, device=conc.device + ) + phi_mask[missing > 0] = torch.logical_and( + lengths < (start[missing > 0] + accepted.sum(-1)).view(-1, 1), + lengths >= start[missing > 0].view(-1, 1), + ) phi[:, phi_mask] = x[:, accepted] @@ -187,8 +251,10 @@ def sample(self, sample_shape=torch.Size()): max_iter -= 1 if max_iter == 0 or torch.any(missing > 0): - raise ValueError("maximum number of iterations exceeded; " - "try increasing `SineBivariateVonMises.max_sample_iter`") + raise ValueError( + "maximum number of iterations exceeded; " + "try increasing `SineBivariateVonMises.max_sample_iter`" + ) phi = torch.atan2(phi[1], phi[0]) @@ -197,8 +263,13 @@ def sample(self, sample_shape=torch.Size()): psi = VonMises(beta, alpha).sample() - phi_psi = torch.stack(((phi + self.phi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi, - (psi + self.psi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi), dim=-1).permute(1, 0, 2) + phi_psi = torch.stack( + ( + (phi + self.phi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi, + (psi + self.psi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi, + ), + dim=-1, + ).permute(1, 0, 2) return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape) @property @@ -216,12 +287,18 @@ def expand(self, batch_shape, _instance=None): for k in SineBivariateVonMises.arg_constraints.keys(): setattr(new, k, getattr(self, k).expand(batch_shape)) new.norm_const = self.norm_const.expand(batch_shape) - super(SineBivariateVonMises, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(SineBivariateVonMises, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new def _bfind(self, eig): - b = eig.shape[0] / 2 * torch.ones(self.batch_shape, dtype=eig.dtype, device=eig.device) + b = ( + eig.shape[0] + / 2 + * torch.ones(self.batch_shape, dtype=eig.dtype, device=eig.device) + ) g1 = torch.sum(1 / (b + 2 * eig) ** 2, dim=0) g2 = torch.sum(-2 / (b + 2 * eig) ** 3, dim=0) return torch.where(eig.norm(0) != 0, b - g1 / g2, b) diff --git 
a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index 6156c14250..e705e7de64 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -82,15 +82,19 @@ def model(obs): :class:`~pyro.distributions.Uniform` (-pi, pi). :param torch.tensor skewness: skewness of the distribution. """ - arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)} + + arg_constraints = { + "skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1) + } support = constraints.independent(constraints.real, 1) def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): - assert base_dist.event_shape == skewness.shape[-1:], \ - 'Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`.' + assert ( + base_dist.event_shape == skewness.shape[-1:] + ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`." - if (skewness.abs().sum(-1) > 1.).any(): + if (skewness.abs().sum(-1) > 1.0).any(): warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning) batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1]) @@ -100,22 +104,42 @@ def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): super().__init__(batch_shape, event_shape, validate_args=validate_args) if self._validate_args and base_dist.mean.device != skewness.device: - raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " - f"must be on same device.") + raise ValueError( + f"base_density: {base_dist.__class__.__name__} and SineSkewed " + f"must be on same device." + ) def __repr__(self): - args_string = ', '.join(['{}: {}'.format(p, getattr(self, p) - if getattr(self, p).numel() == 1 - else getattr(self, p).size()) for p in self.arg_constraints.keys()]) - return self.__class__.__name__ + '(' + f'base_density: {str(self.base_dist)}, ' + args_string + ')' + args_string = ", ".join( + [ + "{}: {}".format( + p, + getattr(self, p) + if getattr(self, p).numel() == 1 + else getattr(self, p).size(), + ) + for p in self.arg_constraints.keys() + ] + ) + return ( + self.__class__.__name__ + + "(" + + f"base_density: {str(self.base_dist)}, " + + args_string + + ")" + ) def sample(self, sample_shape=torch.Size()): bd = self.base_dist ys = bd.sample(sample_shape) - u = Uniform(0., self.skewness.new_ones(())).sample(sample_shape + self.batch_shape) + u = Uniform(0.0, self.skewness.new_ones(())).sample( + sample_shape + self.batch_shape + ) # Section 2.3 step 3 in [1] - mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) + mask = u <= 0.5 + 0.5 * ( + self.skewness * torch.sin((ys - bd.mean) % (2 * pi)) + ).sum(-1) mask = mask[..., None] samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi return samples @@ -125,7 +149,11 @@ def log_prob(self, value): self._validate_sample(value) # Eq. 
2.1 in [1] - skew_prob = torch.log1p((self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1)) + skew_prob = torch.log1p( + (self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum( + -1 + ) + ) return self.base_dist.log_prob(value) + skew_prob def expand(self, batch_shape, _instance=None): @@ -134,6 +162,8 @@ def expand(self, batch_shape, _instance=None): base_dist = self.base_dist.expand(batch_shape) new.base_dist = base_dist new.skewness = self.skewness.expand(batch_shape + (-1,)) - super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(SineSkewed, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py index 87a5e94040..b8c38714f7 100644 --- a/pyro/distributions/spanning_tree.py +++ b/pyro/distributions/spanning_tree.py @@ -49,7 +49,8 @@ class SpanningTree(TorchDistribution): ``initial_edges`` defaulting to a cheap approximate sample; ``backend`` one of "python" or "cpp", defaulting to "python". """ - arg_constraints = {'edge_logits': constraints.real} + + arg_constraints = {"edge_logits": constraints.real} support = constraints.nonnegative_integer has_enumerate_support = True @@ -57,7 +58,7 @@ def __init__(self, edge_logits, sampler_options=None, validate_args=None): if edge_logits.is_cuda: raise NotImplementedError("SpanningTree does not support cuda tensors") K = len(edge_logits) - V = int(round(0.5 + (0.25 + 2 * K)**0.5)) + V = int(round(0.5 + (0.25 + 2 * K) ** 0.5)) assert K == V * (V - 1) // 2 E = V - 1 event_shape = (E, 2) @@ -66,8 +67,11 @@ def __init__(self, edge_logits, sampler_options=None, validate_args=None): super().__init__(batch_shape, event_shape, validate_args=validate_args) if self._validate_args: if edge_logits.shape != (K,): - raise ValueError("Expected edge_logits of shape ({},), but got shape {}" - .format(K, edge_logits.shape)) + raise ValueError( + "Expected edge_logits of shape ({},), but got shape {}".format( + K, edge_logits.shape + ) + ) self.num_vertices = V self.sampler_options = {} if sampler_options is None else sampler_options @@ -88,10 +92,16 @@ def validate_edges(self, edges): raise ValueError("Invalid vertex ids:\n{}".format(edges)) if not (edges[..., 0] < edges[..., 1]).all(): raise ValueError("Vertices are not sorted in each edge:\n{}".format(edges)) - if not ((edges[..., :-1, 1] < edges[..., 1:, 1]) | - ((edges[..., :-1, 1] == edges[..., 1:, 1]) & - (edges[..., :-1, 0] < edges[..., 1:, 0]))).all(): - raise ValueError("Edges are not sorted colexicographically:\n{}".format(edges)) + if not ( + (edges[..., :-1, 1] < edges[..., 1:, 1]) + | ( + (edges[..., :-1, 1] == edges[..., 1:, 1]) + & (edges[..., :-1, 0] < edges[..., 1:, 0]) + ) + ).all(): + raise ValueError( + "Edges are not sorted colexicographically:\n{}".format(edges) + ) # Verify tree property, i.e. connectivity. 
V = self.num_vertices @@ -123,6 +133,7 @@ def log_partition_function(self): truncated = laplacian[:-1, :-1] try: import gpytorch + log_det = gpytorch.lazy.NonLazyTensor(truncated).logdet() except ImportError: log_det = torch.linalg.cholesky(truncated).diag().log().sum() * 2 @@ -220,11 +231,13 @@ def _get_cpp_module(): import os from torch.utils.cpp_extension import load - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp") - _cpp_module = load(name="cpp_spanning_tree", - sources=[path], - extra_cflags=['-O2'], - verbose=True) + + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "spanning_tree.cpp" + ) + _cpp_module = load( + name="cpp_spanning_tree", sources=[path], extra_cflags=["-O2"], verbose=True + ) return _cpp_module @@ -247,7 +260,7 @@ def make_complete_graph(num_vertices, backend="python"): def _make_complete_graph(num_vertices): if num_vertices < 2: - raise ValueError('PyTorch cannot handle zero-sized multidimensional tensors') + raise ValueError("PyTorch cannot handle zero-sized multidimensional tensors") V = num_vertices K = V * (V - 1) // 2 v1 = torch.arange(V) @@ -255,7 +268,7 @@ def _make_complete_graph(num_vertices): v1, v2 = torch.broadcast_tensors(v1, v2) v1 = v1.contiguous().view(-1) v2 = v2.contiguous().view(-1) - mask = (v1 < v2) + mask = v1 < v2 grid = torch.stack((v1[mask], v2[mask])) assert grid.shape == (2, K) return grid @@ -401,7 +414,7 @@ def sample_tree_mcmc(edge_logits, edges, backend="python"): @torch.no_grad() def _sample_tree_approx(edge_logits): K = len(edge_logits) - V = int(round(0.5 + (0.25 + 2 * K)**0.5)) + V = int(round(0.5 + (0.25 + 2 * K) ** 0.5)) assert K == V * (V - 1) // 2 E = V - 1 grid = make_complete_graph(V) @@ -421,7 +434,7 @@ def _sample_tree_approx(edge_logits): # Sample edges connecting the cumulative tree to a new leaf. for e in range(1, E): c1, c2 = components[grid] - mask = (c1 != c2) + mask = c1 != c2 valid_logits = edge_logits[mask] probs = (valid_logits - valid_logits.max()).exp() k = mask.nonzero(as_tuple=False)[torch.multinomial(probs, 1)[0]] @@ -469,7 +482,7 @@ def sample_tree(edge_logits, init_edges=None, mcmc_steps=1, backend="python"): @torch.no_grad() def _find_best_tree(edge_logits): K = len(edge_logits) - V = int(round(0.5 + (0.25 + 2 * K)**0.5)) + V = int(round(0.5 + (0.25 + 2 * K) ** 0.5)) assert K == V * (V - 1) // 2 E = V - 1 grid = make_complete_graph(V) @@ -488,7 +501,7 @@ def _find_best_tree(edge_logits): # Find edges connecting the cumulative tree to a new leaf. 
for e in range(1, E): c1, c2 = components[grid] - mask = (c1 != c2) + mask = c1 != c2 valid_logits = edge_logits[mask] k = valid_logits.argmax(0).item() k = mask.nonzero(as_tuple=False)[k] @@ -528,9 +541,25 @@ def find_best_tree(edge_logits, backend="python"): # See https://oeis.org/A000272 NUM_SPANNING_TREES = [ - 1, 1, 1, 3, 16, 125, 1296, 16807, 262144, 4782969, 100000000, 2357947691, - 61917364224, 1792160394037, 56693912375296, 1946195068359375, - 72057594037927936, 2862423051509815793, 121439531096594251776, + 1, + 1, + 1, + 3, + 16, + 125, + 1296, + 16807, + 262144, + 4782969, + 100000000, + 2357947691, + 61917364224, + 1792160394037, + 56693912375296, + 1946195068359375, + 72057594037927936, + 2862423051509815793, + 121439531096594251776, 5480386857784802185939, ] @@ -571,8 +600,9 @@ def _close_under_permutations(V, tree_generators): vertices = list(range(V)) trees = [] for tree in tree_generators: - trees.extend(set(_permute_tree(perm, tree) - for perm in itertools.permutations(vertices))) + trees.extend( + set(_permute_tree(perm, tree) for perm in itertools.permutations(vertices)) + ) trees.sort() return trees @@ -583,8 +613,10 @@ def enumerate_spanning_trees(V): """ if V >= len(_TREE_GENERATORS): raise NotImplementedError( - "enumerate_spanning_trees() is implemented only for trees with up to {} vertices" - .format(len(_TREE_GENERATORS) - 1)) + "enumerate_spanning_trees() is implemented only for trees with up to {} vertices".format( + len(_TREE_GENERATORS) - 1 + ) + ) all_trees = _close_under_permutations(V, _TREE_GENERATORS[V]) assert len(all_trees) == NUM_SPANNING_TREES[V] return all_trees diff --git a/pyro/distributions/stable.py b/pyro/distributions/stable.py index 9047a1472b..d1602eef93 100644 --- a/pyro/distributions/stable.py +++ b/pyro/distributions/stable.py @@ -26,8 +26,11 @@ def _unsafe_standard_stable(alpha, beta, V, W, coords): b = beta * ha.tan() # +/- `ha` term to keep the precision of alpha * (V + half_pi) when V ~ -half_pi v = b.atan() - ha + alpha * (V + half_pi) - Z = v.sin() / ((1 + b * b).rsqrt() * V.cos()).pow(inv_alpha) \ + Z = ( + v.sin() + / ((1 + b * b).rsqrt() * V.cos()).pow(inv_alpha) * ((v - V).cos().clamp(min=eps) / W).pow(inv_alpha - 1) + ) Z.data[Z.data != Z.data] = 0 # drop occasional NANs # Optionally convert to Nolan's parametrization S^0 where samples depend @@ -55,10 +58,12 @@ def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords): """ # Determine whether a hole workaround is needed. with torch.no_grad(): - hole = 1. + hole = 1.0 near_hole = (alpha - hole).abs() <= RADIUS if not torch._C._get_tracing_state() and not near_hole.any(): - return _unsafe_standard_stable(alpha, beta, aux_uniform, aux_exponential, coords=coords) + return _unsafe_standard_stable( + alpha, beta, aux_uniform, aux_exponential, coords=coords + ) if coords == "S": # S coords are discontinuous, so interpolate instead in S0 coords. Z = _standard_stable(alpha, beta, aux_uniform, aux_exponential, "S0") @@ -81,7 +86,9 @@ def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords): # 2 * RADIUS weights = (alpha_ - alpha.unsqueeze(-1)).abs_().mul_(-1 / (2 * RADIUS)).add_(1) weights[~near_hole] = 0.5 - pairs = _unsafe_standard_stable(alpha_, beta_, aux_uniform_, aux_exponential_, coords=coords) + pairs = _unsafe_standard_stable( + alpha_, beta_, aux_uniform_, aux_exponential_, coords=coords + ) return (pairs * weights).sum(-1) @@ -131,16 +138,21 @@ class Stable(TorchDistribution): parametrization, or "S" to use the discontinuous parameterization. 
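# Usage sketch (illustrative only, not taken from the Pyro sources): drawing
# reparameterized samples from a heavy-tailed stable law in Nolan's S0
# parametrization, using the _standard_stable helper above (a
# Chambers-Mallows-Stuck-style construction). Parameter values are hypothetical.
import torch
import pyro.distributions as dist

d = dist.Stable(stability=torch.tensor(1.7), skew=torch.tensor(0.5), scale=1.0, loc=0.0)
x = d.rsample(torch.Size([1000]))  # pathwise samples; stability in (0, 2], skew in [-1, 1]
print(x.median())  # note: .mean is NaN whenever stability <= 1, per the property further below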
""" has_rsample = True - arg_constraints = {"stability": constraints.interval(0, 2), # half-open (0, 2] - "skew": constraints.interval(-1, 1), # closed [-1, 1] - "scale": constraints.positive, - "loc": constraints.real} + arg_constraints = { + "stability": constraints.interval(0, 2), # half-open (0, 2] + "skew": constraints.interval(-1, 1), # closed [-1, 1] + "scale": constraints.positive, + "loc": constraints.real, + } support = constraints.real - def __init__(self, stability, skew, scale=1.0, loc=0.0, coords="S0", validate_args=None): + def __init__( + self, stability, skew, scale=1.0, loc=0.0, coords="S0", validate_args=None + ): assert coords in ("S", "S0"), coords self.stability, self.skew, self.scale, self.loc = broadcast_all( - stability, skew, scale, loc) + stability, skew, scale, loc + ) self.coords = coords super().__init__(self.loc.shape, validate_args=validate_args) @@ -166,14 +178,18 @@ def rsample(self, sample_shape=torch.Size()): aux_exponential = new_empty(shape).exponential_() # Differentiably transform. - x = _standard_stable(self.stability, self.skew, aux_uniform, aux_exponential, coords=self.coords) + x = _standard_stable( + self.stability, self.skew, aux_uniform, aux_exponential, coords=self.coords + ) return self.loc + self.scale * x @property def mean(self): result = self.loc if self.coords == "S0": - result = result - self.scale * self.skew * (math.pi / 2 * self.stability).tan() + result = ( + result - self.scale * self.skew * (math.pi / 2 * self.stability).tan() + ) return result.masked_fill(self.stability <= 1, math.nan) @property diff --git a/pyro/distributions/testing/gof.py b/pyro/distributions/testing/gof.py index 4d544b923c..31c9d8c5fb 100644 --- a/pyro/distributions/testing/gof.py +++ b/pyro/distributions/testing/gof.py @@ -71,10 +71,10 @@ class InvalidTest(ValueError): def print_histogram(probs, counts): max_count = max(counts) - print('{: >8} {: >8}'.format('Prob', 'Count')) + print("{: >8} {: >8}".format("Prob", "Count")) for prob, count in sorted(zip(probs, counts), reverse=True): width = int(round(HISTOGRAM_WIDTH * count / max_count)) - print('{: >8.3f} {: >8d} {}'.format(prob, count, '-' * width)) + print("{: >8.3f} {: >8d} {}".format(prob, count, "-" * width)) @torch.no_grad() @@ -113,16 +113,16 @@ def multinomial_goodness_of_fit( for p, c in zip(probs.tolist(), counts.tolist()): if abs(p - 1) < 1e-8: return 1 if c == total_count else 0 - assert p < 1, f'bad probability: {p:g}' + assert p < 1, f"bad probability: {p:g}" if p > 0: mean = total_count * p variance = total_count * p * (1 - p) if not (variance > 1): - raise InvalidTest('Goodness of fit is inaccurate; use more samples') + raise InvalidTest("Goodness of fit is inaccurate; use more samples") chi_squared += (c - mean) ** 2 / variance dof += 1 else: - warnings.warn('Zero probability in goodness-of-fit test') + warnings.warn("Zero probability in goodness-of-fit test") if c > 0: return math.inf @@ -148,7 +148,7 @@ def unif01_goodness_of_fit(samples, *, plot=False): assert samples.max() <= 1 bin_count = int(round(len(samples) ** 0.333)) if bin_count < 7: - raise InvalidTest('imprecise test, use more samples') + raise InvalidTest("imprecise test, use more samples") probs = torch.ones(bin_count) / bin_count binned = samples.mul(bin_count).long().clamp(min=0, max=bin_count - 1) counts = torch.zeros(bin_count) @@ -188,7 +188,7 @@ def density_goodness_of_fit(samples, probs, plot=False): """ assert samples.shape == probs.shape if len(samples) <= 100: - raise InvalidTest('imprecision; use more samples') + 
raise InvalidTest("imprecision; use more samples") samples, index = samples.sort(0) probs = probs[index] @@ -210,6 +210,7 @@ def get_nearest_neighbor_distances(samples): try: # This version scales as O(N log(N)). from scipy.spatial import cKDTree + samples = samples.cpu().numpy() distances, indices = cKDTree(samples).query(samples, k=2) return torch.from_numpy(distances[:, 1]) @@ -253,7 +254,7 @@ def vector_density_goodness_of_fit(samples, probs, *, dim=None, plot=False): dim = samples.shape[-1] assert dim if len(samples) <= 1000 * dim: - raise InvalidTest('imprecision; use more samples') + raise InvalidTest("imprecision; use more samples") radii = get_nearest_neighbor_distances(samples) density = len(samples) * probs volume = volume_of_sphere(dim, radii) diff --git a/pyro/distributions/testing/naive_dirichlet.py b/pyro/distributions/testing/naive_dirichlet.py index 080df5243a..b9676f1b99 100644 --- a/pyro/distributions/testing/naive_dirichlet.py +++ b/pyro/distributions/testing/naive_dirichlet.py @@ -15,9 +15,12 @@ class NaiveDirichlet(Dirichlet): This naive implementation has stochastic reparameterized gradients, which have higher variance than PyTorch's ``Dirichlet`` implementation. """ + def __init__(self, concentration, validate_args=None): super().__init__(concentration) - self._gamma = Gamma(concentration, torch.ones_like(concentration), validate_args=validate_args) + self._gamma = Gamma( + concentration, torch.ones_like(concentration), validate_args=validate_args + ) def rsample(self, sample_shape=torch.Size()): gammas = self._gamma.rsample(sample_shape) @@ -32,6 +35,7 @@ class NaiveBeta(Beta): This naive implementation has stochastic reparameterized gradients, which have higher variance than PyTorch's ``Beta`` implementation. """ + def __init__(self, concentration1, concentration0, validate_args=None): super().__init__(concentration1, concentration0, validate_args=validate_args) alpha_beta = torch.stack([concentration1, concentration0], -1) diff --git a/pyro/distributions/testing/rejection_exponential.py b/pyro/distributions/testing/rejection_exponential.py index b227f21798..1e26120108 100644 --- a/pyro/distributions/testing/rejection_exponential.py +++ b/pyro/distributions/testing/rejection_exponential.py @@ -12,8 +12,7 @@ @copy_docs_from(Exponential) class RejectionExponential(Rejector): - arg_constraints = {"rate": constraints.positive, - "factor": constraints.positive} + arg_constraints = {"rate": constraints.positive, "factor": constraints.positive} support = constraints.positive def __init__(self, rate, factor): diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index f829854ebb..4f83701ae9 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -15,19 +15,29 @@ class RejectionStandardGamma(Rejector): Naive Marsaglia & Tsang rejection sampler for standard Gamma distibution. This assumes `concentration >= 1` and does not boost `concentration` or augment shape. 
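# Usage sketch (illustrative only, not taken from the Pyro sources): the rejection
# sampler above is intended to give reparameterized Gamma(concentration, 1) samples
# for concentration >= 1; the concentration values below are hypothetical.
import torch
from pyro.distributions.testing.rejection_gamma import RejectionStandardGamma

d = RejectionStandardGamma(torch.tensor([1.5, 4.0]))
x = d.rsample()  # accepted Marsaglia & Tsang proposals, one per batch element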
""" + def __init__(self, concentration): if concentration.data.min() < 1: - raise NotImplementedError('concentration < 1 is not supported') + raise NotImplementedError("concentration < 1 is not supported") self.concentration = concentration - self._standard_gamma = Gamma(concentration, concentration.new([1.]).squeeze().expand_as(concentration)) + self._standard_gamma = Gamma( + concentration, concentration.new([1.0]).squeeze().expand_as(concentration) + ) # The following are Marsaglia & Tsang's variable names. self._d = self.concentration - 1.0 / 3.0 self._c = 1.0 / torch.sqrt(9.0 * self._d) # Compute log scale using Gamma.log_prob(). x = self._d.detach() # just an arbitrary x. - log_scale = self.propose_log_prob(x) + self.log_prob_accept(x) - self.log_prob(x) - super().__init__(self.propose, self.log_prob_accept, log_scale, - batch_shape=concentration.shape, event_shape=()) + log_scale = ( + self.propose_log_prob(x) + self.log_prob_accept(x) - self.log_prob(x) + ) + super().__init__( + self.propose, + self.log_prob_accept, + log_scale, + batch_shape=concentration.shape, + event_shape=(), + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(RejectionStandardGamma, _instance) @@ -39,17 +49,24 @@ def expand(self, batch_shape, _instance=None): # Compute log scale using Gamma.log_prob(). x = new._d.detach() # just an arbitrary x. log_scale = new.propose_log_prob(x) + new.log_prob_accept(x) - new.log_prob(x) - super(RejectionStandardGamma, new).__init__(new.propose, new.log_prob_accept, log_scale, - batch_shape=batch_shape, event_shape=()) + super(RejectionStandardGamma, new).__init__( + new.propose, + new.log_prob_accept, + log_scale, + batch_shape=batch_shape, + event_shape=(), + ) new._validate_args = self._validate_args return new @weakmethod def propose(self, sample_shape=torch.Size()): # Marsaglia & Tsang's x == Naesseth's epsilon` - x = torch.randn(sample_shape + self.concentration.shape, - dtype=self.concentration.dtype, - device=self.concentration.device) + x = torch.randn( + sample_shape + self.concentration.shape, + dtype=self.concentration.dtype, + device=self.concentration.device, + ) y = 1.0 + self._c * x v = y * y * y return (self._d * v).clamp_(1e-30, 1e30) @@ -61,7 +78,9 @@ def propose_log_prob(self, value): result -= torch.log(3 * y ** 2) x = (y - 1) / self._c result -= self._c.log() - result += Normal(torch.zeros_like(self.concentration), torch.ones_like(self.concentration)).log_prob(x) + result += Normal( + torch.zeros_like(self.concentration), torch.ones_like(self.concentration) + ).log_prob(x) return result @weakmethod @@ -70,7 +89,7 @@ def log_prob_accept(self, value): y = torch.pow(v, 1.0 / 3.0) x = (y - 1.0) / self._c log_prob_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v)) - log_prob_accept[y <= 0] = -float('inf') + log_prob_accept[y <= 0] = -float("inf") return log_prob_accept def log_prob(self, x): @@ -111,11 +130,12 @@ class ShapeAugmentedGamma(Gamma): This implements the shape augmentation trick of Naesseth, Ruiz, Linderman, Blei (2017) https://arxiv.org/abs/1610.05683 """ + has_rsample = True def __init__(self, concentration, rate, boost=1, validate_args=None): if concentration.min() + boost < 1: - raise ValueError('Need to boost at least once for concentration < 1') + raise ValueError("Need to boost at least once for concentration < 1") super().__init__(concentration, rate, validate_args=validate_args) self.concentration = concentration self._boost = boost @@ -159,15 +179,20 @@ class ShapeAugmentedDirichlet(Dirichlet): 
This naive implementation has stochastic reparameterized gradients, which have higher variance than PyTorch's ``Dirichlet`` implementation. """ + def __init__(self, concentration, boost=1, validate_args=None): super().__init__(concentration, validate_args=validate_args) - self._gamma = ShapeAugmentedGamma(concentration, torch.ones_like(concentration), boost) + self._gamma = ShapeAugmentedGamma( + concentration, torch.ones_like(concentration), boost + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(ShapeAugmentedDirichlet, _instance) new = super().expand(batch_shape, new) batch_shape = torch.Size(batch_shape) - new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._gamma = self._gamma.expand( + batch_shape + self._gamma.concentration.shape[-1:] + ) new._validate_args = self._validate_args return new @@ -184,16 +209,21 @@ class ShapeAugmentedBeta(Beta): This naive implementation has stochastic reparameterized gradients, which have higher variance than PyTorch's ``rate`` implementation. """ + def __init__(self, concentration1, concentration0, boost=1, validate_args=None): super().__init__(concentration1, concentration0, validate_args=validate_args) alpha_beta = torch.stack([concentration1, concentration0], -1) - self._gamma = ShapeAugmentedGamma(alpha_beta, torch.ones_like(alpha_beta), boost) + self._gamma = ShapeAugmentedGamma( + alpha_beta, torch.ones_like(alpha_beta), boost + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(ShapeAugmentedBeta, _instance) new = super().expand(batch_shape, new) batch_shape = torch.Size(batch_shape) - new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._gamma = self._gamma.expand( + batch_shape + self._gamma.concentration.shape[-1:] + ) new._validate_args = self._validate_args return new diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 9a918c11f0..9459803212 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -32,7 +32,9 @@ def _log_normalizer(d): y = d.concentration0 return (x + y).lgamma() - x.lgamma() - y.lgamma() - log_normalizer = _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + log_normalizer = ( + _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + ) return updated, log_normalizer @@ -46,7 +48,7 @@ class Binomial(torch.distributions.Binomial, TorchDistributionMixin): # a shifted Sterling's approximation to the Beta function, reducing # computational cost from 3 lgamma() evaluations to 4 log() evaluations # plus arithmetic. Recommended values are between 0.1 and 0.01. - approx_log_prob_tol = 0. + approx_log_prob_tol = 0.0 def sample(self, sample_shape=torch.Size()): if self.approx_sample_thresh < math.inf: @@ -65,10 +67,12 @@ def sample(self, sample_shape=torch.Size()): sample = torch.where(p < q, result, self.total_count - result) # Draw exact samples for remaining items. 
if exact.any(): - total_count = torch.where(exact, self.total_count, - torch.zeros_like(self.total_count)) + total_count = torch.where( + exact, self.total_count, torch.zeros_like(self.total_count) + ) exact_sample = torch.distributions.Binomial( - total_count, self.probs, validate_args=False).sample(sample_shape) + total_count, self.probs, validate_args=False + ).sample(sample_shape) sample = torch.where(exact, exact_sample, sample) return sample return super().sample(sample_shape) @@ -84,9 +88,14 @@ def log_prob(self, value): # (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p) # = k * logit - n * logit - n * log1p(e^-logit) # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|) - normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p()) - return (k * self.logits - normalize_term - + log_binomial(n, k, tol=self.approx_log_prob_tol)) + normalize_term = n * ( + _clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p() + ) + return ( + k * self.logits + - normalize_term + + log_binomial(n, k, tol=self.approx_log_prob_tol) + ) # This overloads .log_prob() and .enumerate_support() to speed up evaluating @@ -94,11 +103,10 @@ def log_prob(self, value): # and merely reshape the self.logits tensor. This is especially important for # Pyro models that use enumeration. class Categorical(torch.distributions.Categorical, TorchDistributionMixin): - arg_constraints = {"probs": constraints.simplex, - "logits": constraints.real_vector} + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} def log_prob(self, value): - if getattr(value, '_pyro_categorical_support', None) == id(self): + if getattr(value, "_pyro_categorical_support", None) == id(self): # Assume value is a reshaped torch.arange(event_shape[0]). # In this case we can call .reshape() rather than torch.gather(). if not torch._C._get_tracing_state(): @@ -107,7 +115,9 @@ def log_prob(self, value): assert value.size(0) == self.logits.size(-1) logits = self.logits if logits.dim() <= value.dim(): - logits = logits.reshape((1,) * (1 + value.dim() - logits.dim()) + logits.shape) + logits = logits.reshape( + (1,) * (1 + value.dim() - logits.dim()) + logits.shape + ) if not torch._C._get_tracing_state(): assert logits.size(-1 - value.dim()) == 1 return logits.transpose(-1 - value.dim(), -1).squeeze(-1) @@ -121,7 +131,6 @@ def enumerate_support(self, expand=True): class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): - @staticmethod def infer_shapes(concentration): batch_shape = concentration[:-1] @@ -140,7 +149,9 @@ def _log_normalizer(d): c = d.concentration return c.sum(-1).lgamma() - c.lgamma().sum(-1) - log_normalizer = _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + log_normalizer = ( + _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + ) return updated, log_normalizer @@ -158,7 +169,9 @@ def _log_normalizer(d): c = d.concentration return d.rate.log() * c - c.lgamma() - log_normalizer = _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + log_normalizer = ( + _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated) + ) return updated, log_normalizer @@ -176,14 +189,21 @@ def __init__(self, loc, scale, validate_args=None): # This differs from torch.distributions.LogNormal only in that base_dist is # a pyro.distributions.Normal rather than a torch.distributions.Normal. 
super(torch.distributions.LogNormal, self).__init__( - base_dist, torch.distributions.transforms.ExpTransform(), validate_args=validate_args) + base_dist, + torch.distributions.transforms.ExpTransform(), + validate_args=validate_args, + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(LogNormal, _instance) - return super(torch.distributions.LogNormal, self).expand(batch_shape, _instance=new) + return super(torch.distributions.LogNormal, self).expand( + batch_shape, _instance=new + ) -class LowRankMultivariateNormal(torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin): +class LowRankMultivariateNormal( + torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin +): @staticmethod def infer_shapes(loc, cov_factor, cov_diag): event_shape = loc[-1:] @@ -191,9 +211,13 @@ def infer_shapes(loc, cov_factor, cov_diag): return batch_shape, event_shape -class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): +class MultivariateNormal( + torch.distributions.MultivariateNormal, TorchDistributionMixin +): @staticmethod - def infer_shapes(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None): + def infer_shapes( + loc, covariance_matrix=None, precision_matrix=None, scale_tril=None + ): batch_shape, event_shape = loc[:-1], loc[-1:] for matrix in [covariance_matrix, precision_matrix, scale_tril]: if matrix is not None: @@ -242,13 +266,16 @@ def log_prob(self, value): rate, value, nonzero = torch.broadcast_tensors(self.rate, value, value > 0) sparse_rate = rate[nonzero] sparse_value = value[nonzero] - return torch.zeros_like(rate).masked_scatter( - nonzero, (sparse_rate.log() * sparse_value) - (sparse_value + 1).lgamma() - ) - rate + return ( + torch.zeros_like(rate).masked_scatter( + nonzero, + (sparse_rate.log() * sparse_value) - (sparse_value + 1).lgamma(), + ) + - rate + ) class Independent(torch.distributions.Independent, TorchDistributionMixin): - @staticmethod def infer_shapes(**kwargs): raise NotImplementedError @@ -307,21 +334,26 @@ def support(self): _PyroDist.__module__ = __name__ locals()[_name] = _PyroDist - _PyroDist.__doc__ = ''' + _PyroDist.__doc__ = """ Wraps :class:`{}.{}` with :class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`. - '''.format(_Dist.__module__, _Dist.__name__) + """.format( + _Dist.__module__, _Dist.__name__ + ) __all__.append(_name) # Create sphinx documentation. -__doc__ = '\n\n'.join([ - - ''' +__doc__ = "\n\n".join( + [ + """ {0} ---------------------------------------------------------------- .. autoclass:: pyro.distributions.{0} - '''.format(_name) - for _name in sorted(__all__) -]) + """.format( + _name + ) + for _name in sorted(__all__) + ] +) diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index f98cca66a9..94c16b19eb 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -26,6 +26,7 @@ class TorchDistributionMixin(Distribution): :class:`torch.distributions.distribution.Distribution` and then inherit from :class:`TorchDistributionMixin`. """ + def __call__(self, sample_shape=torch.Size()): """ Samples a random value. @@ -43,7 +44,11 @@ def __call__(self, sample_shape=torch.Size()): batched). The shape of the result should be `self.shape()`. 
:rtype: torch.Tensor """ - return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape) + return ( + self.rsample(sample_shape) + if self.has_rsample + else self.sample(sample_shape) + ) @property def event_dim(self): @@ -92,7 +97,7 @@ def infer_shapes(cls, **arg_shapes): batch_shapes = [] for name, shape in arg_shapes.items(): event_dim = cls.arg_constraints.get(name, constraints.real).event_dim - batch_shapes.append(shape[:len(shape) - event_dim]) + batch_shapes.append(shape[: len(shape) - event_dim]) batch_shape = torch.Size(broadcast_shape(*batch_shapes)) event_shape = torch.Size() return batch_shape, event_shape @@ -126,13 +131,17 @@ def expand_by(self, sample_shape): try: expanded_dist = self.expand(torch.Size(sample_shape) + self.batch_shape) except NotImplementedError: - expanded_dist = TorchDistributionMixin.expand(self, torch.Size(sample_shape) + self.batch_shape) + expanded_dist = TorchDistributionMixin.expand( + self, torch.Size(sample_shape) + self.batch_shape + ) return expanded_dist def reshape(self, sample_shape=None, extra_event_dims=None): - raise Exception(''' + raise Exception( + """ .reshape(sample_shape=s, extra_event_dims=n) was renamed and split into - .expand_by(sample_shape=s).to_event(reinterpreted_batch_ndims=n).''') + .expand_by(sample_shape=s).to_event(reinterpreted_batch_ndims=n).""" + ) def to_event(self, reinterpreted_batch_ndims=None): """ @@ -179,11 +188,17 @@ def to_event(self, reinterpreted_batch_ndims=None): if reinterpreted_batch_ndims == 0: return base_dist if reinterpreted_batch_ndims < 0: - raise ValueError("Cannot remove event dimensions from {}".format(type(self))) - return pyro.distributions.torch.Independent(base_dist, reinterpreted_batch_ndims) + raise ValueError( + "Cannot remove event dimensions from {}".format(type(self)) + ) + return pyro.distributions.torch.Independent( + base_dist, reinterpreted_batch_ndims + ) def independent(self, reinterpreted_batch_ndims=None): - warnings.warn("independent is deprecated; use to_event instead", DeprecationWarning) + warnings.warn( + "independent is deprecated; use to_event instead", DeprecationWarning + ) return self.to_event(reinterpreted_batch_ndims=reinterpreted_batch_ndims) def mask(self, mask): @@ -261,6 +276,7 @@ class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin method to improve gradient estimates and set ``.has_enumerate_support = True``. """ + # Provides a default `.expand` method for Pyro distributions which overrides # torch.distributions.Distribution.expand (throws a NotImplementedError). expand = TorchDistributionMixin.expand @@ -278,6 +294,7 @@ class MaskedDistribution(TorchDistribution): :param mask: A boolean or boolean-valued tensor. 
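# Editor's aside (illustrative sketch, not part of this patch): the expand_by,
# to_event and mask methods touched above compose as follows; the shapes are the
# standard Pyro semantics, and the concrete numbers are only an example.
import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(3), torch.ones(3))   # batch_shape (3,)
d2 = d.expand_by((5,))                           # batch_shape (5, 3)
d3 = d2.to_event(1)                              # batch_shape (5,), event_shape (3,)
assert d3.batch_shape == (5,) and d3.event_shape == (3,)

masked = d3.mask(torch.tensor([True, False, True, False, True]))
lp = masked.log_prob(torch.zeros(5, 3))          # zero log-density where mask is False
assert lp.shape == (5,) and lp[1].item() == 0.0 and lp[3].item() == 0.0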
:type mask: torch.Tensor or bool """ + arg_constraints = {} def __init__(self, base_dist, mask): @@ -300,7 +317,9 @@ def expand(self, batch_shape, _instance=None): new._mask = self._mask if isinstance(new._mask, torch.Tensor): new._mask = new._mask.expand(batch_shape) - super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(MaskedDistribution, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -324,8 +343,9 @@ def rsample(self, sample_shape=torch.Size()): def log_prob(self, value): if self._mask is False: - shape = broadcast_shape(self.base_dist.batch_shape, - value.shape[:value.dim() - self.event_dim]) + shape = broadcast_shape( + self.base_dist.batch_shape, value.shape[: value.dim() - self.event_dim] + ) return torch.zeros((), device=value.device).expand(shape) if self._mask is True: return self.base_dist.log_prob(value) @@ -353,7 +373,9 @@ def conjugate_update(self, other): """ updated, log_normalizer = self.base_dist.conjugate_update(other) updated = updated.mask(self._mask) - log_normalizer = torch.where(self._mask, log_normalizer, torch.zeros_like(log_normalizer)) + log_normalizer = torch.where( + self._mask, log_normalizer, torch.zeros_like(log_normalizer) + ) return updated, log_normalizer @@ -370,8 +392,9 @@ def expand(self, batch_shape, _instance=None): # Do basic validation. e.g. we should not "unexpand" distributions even if that is possible. new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape) # Record interstitial and expanded dims/sizes w.r.t. the base distribution - new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape, - new_shape) + new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape( + self.base_dist.batch_shape, new_shape + ) self._batch_shape = new_shape self._expanded_sizes = expanded_sizes self._interstitial_sizes = interstitial_sizes @@ -380,8 +403,11 @@ def expand(self, batch_shape, _instance=None): @staticmethod def _broadcast_shape(existing_shape, new_shape): if len(new_shape) < len(existing_shape): - raise ValueError("Cannot broadcast distribution of shape {} to shape {}" - .format(existing_shape, new_shape)) + raise ValueError( + "Cannot broadcast distribution of shape {} to shape {}".format( + existing_shape, new_shape + ) + ) reversed_shape = list(reversed(existing_shape)) expanded_sizes, interstitial_sizes = [], [] for i, size in enumerate(reversed(new_shape)): @@ -393,9 +419,16 @@ def _broadcast_shape(existing_shape, new_shape): reversed_shape[i] = size interstitial_sizes.append((-i - 1, size)) elif reversed_shape[i] != size: - raise ValueError("Cannot broadcast distribution of shape {} to shape {}" - .format(existing_shape, new_shape)) - return tuple(reversed(reversed_shape)), OrderedDict(expanded_sizes), OrderedDict(interstitial_sizes) + raise ValueError( + "Cannot broadcast distribution of shape {} to shape {}".format( + existing_shape, new_shape + ) + ) + return ( + tuple(reversed(reversed_shape)), + OrderedDict(expanded_sizes), + OrderedDict(interstitial_sizes), + ) @property def has_rsample(self): @@ -417,7 +450,9 @@ def _sample(self, sample_fn, sample_shape): batch_shape = expanded_sizes + interstitial_sizes samples = sample_fn(sample_shape + batch_shape) interstitial_idx = len(sample_shape) + len(expanded_sizes) - interstitial_sample_dims = tuple(range(interstitial_idx, interstitial_idx + len(interstitial_sizes))) + interstitial_sample_dims = tuple( + 
range(interstitial_idx, interstitial_idx + len(interstitial_sizes)) + ) for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims): samples = samples.transpose(dim1, dim2) return samples.reshape(sample_shape + self.batch_shape + self.event_shape) @@ -429,12 +464,16 @@ def rsample(self, sample_shape=torch.Size()): return self._sample(self.base_dist.rsample, sample_shape) def log_prob(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) + shape = broadcast_shape( + self.batch_shape, value.shape[: value.dim() - self.event_dim] + ) log_prob = self.base_dist.log_prob(value) return log_prob.expand(shape) def score_parts(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) + shape = broadcast_shape( + self.batch_shape, value.shape[: value.dim() - self.event_dim] + ) log_prob, score_function, entropy_term = self.base_dist.score_parts(value) if self.batch_shape != self.base_dist.batch_shape: log_prob = log_prob.expand(shape) @@ -484,7 +523,7 @@ def _kl_masked_masked(p, q): mask = p._mask & q._mask if mask is False: - return 0. # Return a float, since we cannot determine device. + return 0.0 # Return a float, since we cannot determine device. if mask is True: return kl_divergence(p.base_dist, q.base_dist) kl = kl_divergence(p.base_dist, q.base_dist) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index ddb3601b8c..55d98f6650 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -7,18 +7,18 @@ import torch -assert torch.__version__.startswith('1.') +assert torch.__version__.startswith("1.") def patch_dependency(target, root_module=torch): - parts = target.split('.') + parts = target.split(".") assert parts[0] == root_module.__name__ module = root_module for part in parts[1:-1]: module = getattr(module, part) name = parts[-1] old_fn = getattr(module, name, None) - old_fn = getattr(old_fn, '_pyro_unpatched', old_fn) # ensure patching is idempotent + old_fn = getattr(old_fn, "_pyro_unpatched", old_fn) # ensure patching is idempotent def decorator(new_fn): try: @@ -35,7 +35,7 @@ def decorator(new_fn): # TODO: Move upstream to allow for pickle serialization of transforms -@patch_dependency('torch.distributions.transforms.Transform.__getstate__') +@patch_dependency("torch.distributions.transforms.Transform.__getstate__") def _Transform__getstate__(self): attrs = {} for k, v in self.__dict__.items(): @@ -47,50 +47,52 @@ def _Transform__getstate__(self): # TODO move upstream -@patch_dependency('torch.distributions.transforms.Transform.clear_cache') +@patch_dependency("torch.distributions.transforms.Transform.clear_cache") def _Transform_clear_cache(self): if self._cache_size == 1: self._cached_x_y = None, None # TODO move upstream -@patch_dependency('torch.distributions.TransformedDistribution.clear_cache') +@patch_dependency("torch.distributions.TransformedDistribution.clear_cache") def _TransformedDistribution_clear_cache(self): for t in self.transforms: t.clear_cache() # TODO fix https://github.com/pytorch/pytorch/issues/48054 upstream -@patch_dependency('torch.distributions.HalfCauchy.log_prob') +@patch_dependency("torch.distributions.HalfCauchy.log_prob") def _HalfCauchy_logprob(self, value): if self._validate_args: self._validate_sample(value) - value = torch.as_tensor(value, dtype=self.base_dist.scale.dtype, - device=self.base_dist.scale.device) + value = torch.as_tensor( + value, dtype=self.base_dist.scale.dtype, 
device=self.base_dist.scale.device + ) log_prob = self.base_dist.log_prob(value) + math.log(2) log_prob.masked_fill_(value.expand(log_prob.shape) < 0, -float("inf")) return log_prob # TODO fix batch_shape have an extra singleton dimension upstream -@patch_dependency('torch.distributions.constraints._PositiveDefinite.check') +@patch_dependency("torch.distributions.constraints._PositiveDefinite.check") def _PositiveDefinite_check(self, value): matrix_shape = value.shape[-2:] batch_shape = value.shape[:-2] flattened_value = value.reshape((-1,) + matrix_shape) - return torch.stack([torch.linalg.eigvalsh(v)[:1] > 0.0 - for v in flattened_value]).view(batch_shape) + return torch.stack( + [torch.linalg.eigvalsh(v)[:1] > 0.0 for v in flattened_value] + ).view(batch_shape) -@patch_dependency('torch.distributions.constraints._CorrCholesky.check') +@patch_dependency("torch.distributions.constraints._CorrCholesky.check") def _CorrCholesky_check(self, value): row_norm = torch.linalg.norm(value.detach(), dim=-1) - unit_row_norm = (row_norm - 1.).abs().le(1e-4).all(dim=-1) + unit_row_norm = (row_norm - 1.0).abs().le(1e-4).all(dim=-1) return torch.distributions.constraints.lower_cholesky.check(value) & unit_row_norm # This adds a __call__ method to satisfy sphinx. -@patch_dependency('torch.distributions.utils.lazy_property.__call__') +@patch_dependency("torch.distributions.utils.lazy_property.__call__") def _lazy_property__call__(self): raise NotImplementedError diff --git a/pyro/distributions/torch_transform.py b/pyro/distributions/torch_transform.py index 7ff1082970..5c33b4eda1 100644 --- a/pyro/distributions/torch_transform.py +++ b/pyro/distributions/torch_transform.py @@ -24,6 +24,7 @@ class ComposeTransformModule(torch.distributions.ComposeTransform, torch.nn.Modu so that transform parameters are automatically registered by Pyro's param store when used in :class:`~pyro.nn.module.PyroModule` instances. 
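# Editor's aside (illustrative sketch, not part of this patch): as the docstring
# above says, ComposeTransformModule registers its parts as submodules, so their
# parameters become visible to optimizers and Pyro's param store. Hypothetical usage:
import torch
import pyro.distributions.transforms as T

flow = T.ComposeTransformModule([T.Planar(4), T.Planar(4)])
assert sum(p.numel() for p in flow.parameters()) > 0  # Planar weights are registered
y = flow(torch.randn(2, 4))
assert y.shape == (2, 4)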
""" + def __init__(self, parts): super().__init__(parts) for part in parts: diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 398403532e..2be90c0896 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -97,7 +97,9 @@ def _transform_to_corr_cholesky(constraint): @biject_to.register(constraints.corr_matrix) @transform_to.register(constraints.corr_matrix) def _transform_to_corr_matrix(constraint): - return ComposeTransform([CorrLCholeskyTransform(), CorrMatrixCholeskyTransform().inv]) + return ComposeTransform( + [CorrLCholeskyTransform(), CorrMatrixCholeskyTransform().inv] + ) @biject_to.register(constraints.ordered_vector) @@ -145,74 +147,74 @@ def iterated(repeats, base_fn, *args, **kwargs): __all__ = [ - 'iterated', - 'AffineAutoregressive', - 'AffineCoupling', - 'BatchNorm', - 'BlockAutoregressive', - 'CholeskyTransform', - 'ComposeTransformModule', - 'ConditionalAffineAutoregressive', - 'ConditionalAffineCoupling', - 'ConditionalGeneralizedChannelPermute', - 'ConditionalHouseholder', - 'ConditionalMatrixExponential', - 'ConditionalNeuralAutoregressive', - 'ConditionalPlanar', - 'ConditionalRadial', - 'ConditionalSpline', - 'ConditionalSplineAutoregressive', - 'CorrLCholeskyTransform', - 'CorrMatrixCholeskyTransform', - 'DiscreteCosineTransform', - 'ELUTransform', - 'GeneralizedChannelPermute', - 'HaarTransform', - 'Householder', - 'LeakyReLUTransform', - 'LowerCholeskyAffine', - 'MatrixExponential', - 'NeuralAutoregressive', - 'Normalize', - 'OrderedTransform', - 'Permute', - 'Planar', - 'Polynomial', - 'Radial', - 'SoftplusLowerCholeskyTransform', - 'SoftplusTransform', - 'Spline', - 'SplineAutoregressive', - 'SplineCoupling', - 'Sylvester', - 'affine_autoregressive', - 'affine_coupling', - 'batchnorm', - 'block_autoregressive', - 'conditional_affine_autoregressive', - 'conditional_affine_coupling', - 'conditional_generalized_channel_permute', - 'conditional_householder', - 'conditional_matrix_exponential', - 'conditional_neural_autoregressive', - 'conditional_planar', - 'conditional_radial', - 'conditional_spline', - 'conditional_spline_autoregressive', - 'elu', - 'generalized_channel_permute', - 'householder', - 'leaky_relu', - 'matrix_exponential', - 'neural_autoregressive', - 'permute', - 'planar', - 'polynomial', - 'radial', - 'spline', - 'spline_autoregressive', - 'spline_coupling', - 'sylvester', + "iterated", + "AffineAutoregressive", + "AffineCoupling", + "BatchNorm", + "BlockAutoregressive", + "CholeskyTransform", + "ComposeTransformModule", + "ConditionalAffineAutoregressive", + "ConditionalAffineCoupling", + "ConditionalGeneralizedChannelPermute", + "ConditionalHouseholder", + "ConditionalMatrixExponential", + "ConditionalNeuralAutoregressive", + "ConditionalPlanar", + "ConditionalRadial", + "ConditionalSpline", + "ConditionalSplineAutoregressive", + "CorrLCholeskyTransform", + "CorrMatrixCholeskyTransform", + "DiscreteCosineTransform", + "ELUTransform", + "GeneralizedChannelPermute", + "HaarTransform", + "Householder", + "LeakyReLUTransform", + "LowerCholeskyAffine", + "MatrixExponential", + "NeuralAutoregressive", + "Normalize", + "OrderedTransform", + "Permute", + "Planar", + "Polynomial", + "Radial", + "SoftplusLowerCholeskyTransform", + "SoftplusTransform", + "Spline", + "SplineAutoregressive", + "SplineCoupling", + "Sylvester", + "affine_autoregressive", + "affine_coupling", + "batchnorm", + "block_autoregressive", + "conditional_affine_autoregressive", + 
"conditional_affine_coupling", + "conditional_generalized_channel_permute", + "conditional_householder", + "conditional_matrix_exponential", + "conditional_neural_autoregressive", + "conditional_planar", + "conditional_radial", + "conditional_spline", + "conditional_spline_autoregressive", + "elu", + "generalized_channel_permute", + "householder", + "leaky_relu", + "matrix_exponential", + "neural_autoregressive", + "permute", + "planar", + "polynomial", + "radial", + "spline", + "spline_autoregressive", + "spline_coupling", + "sylvester", ] __all__.extend(torch_transforms) diff --git a/pyro/distributions/transforms/affine_autoregressive.py b/pyro/distributions/transforms/affine_autoregressive.py index 97f2d6f431..483fe89aa0 100644 --- a/pyro/distributions/transforms/affine_autoregressive.py +++ b/pyro/distributions/transforms/affine_autoregressive.py @@ -98,12 +98,12 @@ class AffineAutoregressive(TransformModule): autoregressive = True def __init__( - self, - autoregressive_nn, - log_scale_min_clip=-5., - log_scale_max_clip=3., - sigmoid_bias=2.0, - stable=False + self, + autoregressive_nn, + log_scale_min_clip=-5.0, + log_scale_max_clip=3.0, + sigmoid_bias=2.0, + stable=False, ): super().__init__(cache_size=1) self.arn = autoregressive_nn @@ -129,7 +129,9 @@ def _call(self, x): the base distribution (or the output of a previous transform) """ mean, log_scale = self.arn(x) - log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) self._cached_log_scale = log_scale scale = torch.exp(log_scale) @@ -152,13 +154,20 @@ def _inverse(self, y): # NOTE: Inversion is an expensive operation that scales in the dimension of the input for idx in perm: mean, log_scale = self.arn(torch.stack(x, dim=-1)) - inverse_scale = torch.exp(-clamp_preserve_gradients( - log_scale[..., idx], min=self.log_scale_min_clip, max=self.log_scale_max_clip)) + inverse_scale = torch.exp( + -clamp_preserve_gradients( + log_scale[..., idx], + min=self.log_scale_min_clip, + max=self.log_scale_max_clip, + ) + ) mean = mean[..., idx] x[idx] = (y[..., idx] - mean) * inverse_scale x = torch.stack(x, dim=-1) - log_scale = clamp_preserve_gradients(log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip) + log_scale = clamp_preserve_gradients( + log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip + ) self._cached_log_scale = log_scale return x @@ -176,7 +185,9 @@ def log_abs_det_jacobian(self, x, y): log_scale = self._cached_log_scale elif not self.stable: _, log_scale = self.arn(x) - log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) else: _, logit_scale = self.arn(x) log_scale = self.logsigmoid(logit_scale + self.sigmoid_bias) @@ -362,7 +373,9 @@ def affine_autoregressive(input_dim, hidden_dims=None, **kwargs): return AffineAutoregressive(arn, **kwargs) -def conditional_affine_autoregressive(input_dim, context_dim, hidden_dims=None, **kwargs): +def conditional_affine_autoregressive( + input_dim, context_dim, hidden_dims=None, **kwargs +): """ A helper function to create an :class:`~pyro.distributions.transforms.ConditionalAffineAutoregressive` object diff --git a/pyro/distributions/transforms/affine_coupling.py b/pyro/distributions/transforms/affine_coupling.py index f744503ab8..cf768b9626 100644 --- 
a/pyro/distributions/transforms/affine_coupling.py +++ b/pyro/distributions/transforms/affine_coupling.py @@ -88,7 +88,15 @@ class AffineCoupling(TransformModule): bijective = True - def __init__(self, split_dim, hypernet, *, dim=-1, log_scale_min_clip=-5., log_scale_max_clip=3.): + def __init__( + self, + split_dim, + hypernet, + *, + dim=-1, + log_scale_min_clip=-5.0, + log_scale_max_clip=3.0 + ): super().__init__(cache_size=1) if dim >= 0: raise ValueError("'dim' keyword argument must be negative") @@ -117,14 +125,18 @@ def _call(self, x): :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from the base distribution (or the output of a previous transform) """ - x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim) + x1, x2 = x.split( + [self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim + ) # Now that we can split on an arbitrary dimension, we have do a bit of reshaping... - mean, log_scale = self.nn(x1.reshape(x1.shape[:self.dim] + (-1,))) - mean = mean.reshape(mean.shape[:-1] + x2.shape[self.dim:]) - log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[self.dim:]) + mean, log_scale = self.nn(x1.reshape(x1.shape[: self.dim] + (-1,))) + mean = mean.reshape(mean.shape[:-1] + x2.shape[self.dim :]) + log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[self.dim :]) - log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) self._cached_log_scale = log_scale y1 = x1 @@ -139,15 +151,19 @@ def _inverse(self, y): Inverts y => x. Uses a previously cached inverse if available, otherwise performs the inversion afresh. """ - y1, y2 = y.split([self.split_dim, y.size(self.dim) - self.split_dim], dim=self.dim) + y1, y2 = y.split( + [self.split_dim, y.size(self.dim) - self.split_dim], dim=self.dim + ) x1 = y1 # Now that we can split on an arbitrary dimension, we have do a bit of reshaping... 
- mean, log_scale = self.nn(x1.reshape(x1.shape[:self.dim] + (-1,))) - mean = mean.reshape(mean.shape[:-1] + y2.shape[self.dim:]) - log_scale = log_scale.reshape(log_scale.shape[:-1] + y2.shape[self.dim:]) + mean, log_scale = self.nn(x1.reshape(x1.shape[: self.dim] + (-1,))) + mean = mean.reshape(mean.shape[:-1] + y2.shape[self.dim :]) + log_scale = log_scale.reshape(log_scale.shape[:-1] + y2.shape[self.dim :]) - log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) self._cached_log_scale = log_scale x2 = (y2 - mean) * torch.exp(-log_scale) @@ -161,10 +177,14 @@ def log_abs_det_jacobian(self, x, y): if self._cached_log_scale is not None and x is x_old and y is y_old: log_scale = self._cached_log_scale else: - x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim) - _, log_scale = self.nn(x1.reshape(x1.shape[:self.dim] + (-1,))) - log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[self.dim:]) - log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip) + x1, x2 = x.split( + [self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim + ) + _, log_scale = self.nn(x1.reshape(x1.shape[: self.dim] + (-1,))) + log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[self.dim :]) + log_scale = clamp_preserve_gradients( + log_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) return _sum_rightmost(log_scale, self.event_dim) @@ -286,9 +306,13 @@ def affine_coupling(input_dim, hidden_dims=None, split_dim=None, dim=-1, **kwarg """ if not isinstance(input_dim, int): if len(input_dim) != -dim: - raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim)) + raise ValueError( + "event shape {} must have same length as event_dim {}".format( + input_dim, -dim + ) + ) event_shape = input_dim - extra_dims = reduce(operator.mul, event_shape[(dim + 1):], 1) + extra_dims = reduce(operator.mul, event_shape[(dim + 1) :], 1) else: event_shape = [input_dim] extra_dims = 1 @@ -299,14 +323,20 @@ def affine_coupling(input_dim, hidden_dims=None, split_dim=None, dim=-1, **kwarg if hidden_dims is None: hidden_dims = [10 * event_shape[dim] * extra_dims] - hypernet = DenseNN(split_dim * extra_dims, - hidden_dims, - [(event_shape[dim] - split_dim) * extra_dims, - (event_shape[dim] - split_dim) * extra_dims]) + hypernet = DenseNN( + split_dim * extra_dims, + hidden_dims, + [ + (event_shape[dim] - split_dim) * extra_dims, + (event_shape[dim] - split_dim) * extra_dims, + ], + ) return AffineCoupling(split_dim, hypernet, dim=dim, **kwargs) -def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_dim=None, dim=-1, **kwargs): +def conditional_affine_coupling( + input_dim, context_dim, hidden_dims=None, split_dim=None, dim=-1, **kwargs +): """ A helper function to create an :class:`~pyro.distributions.transforms.ConditionalAffineCoupling` object that @@ -336,9 +366,13 @@ def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_ """ if not isinstance(input_dim, int): if len(input_dim) != -dim: - raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim)) + raise ValueError( + "event shape {} must have same length as event_dim {}".format( + input_dim, -dim + ) + ) event_shape = input_dim - extra_dims = reduce(operator.mul, event_shape[(dim + 1):], 1) + extra_dims = 
reduce(operator.mul, event_shape[(dim + 1) :], 1) else: event_shape = [input_dim] extra_dims = 1 @@ -349,6 +383,13 @@ def conditional_affine_coupling(input_dim, context_dim, hidden_dims=None, split_ if hidden_dims is None: hidden_dims = [10 * event_shape[dim] * extra_dims] - nn = ConditionalDenseNN(split_dim * extra_dims, context_dim, hidden_dims, - [(event_shape[dim] - split_dim) * extra_dims, (event_shape[dim] - split_dim) * extra_dims]) + nn = ConditionalDenseNN( + split_dim * extra_dims, + context_dim, + hidden_dims, + [ + (event_shape[dim] - split_dim) * extra_dims, + (event_shape[dim] - split_dim) * extra_dims, + ], + ) return ConditionalAffineCoupling(split_dim, nn, dim=dim, **kwargs) diff --git a/pyro/distributions/transforms/basic.py b/pyro/distributions/transforms/basic.py index d4fa9eae6f..585bbdb877 100644 --- a/pyro/distributions/transforms/basic.py +++ b/pyro/distributions/transforms/basic.py @@ -28,7 +28,9 @@ def _call(self, x): return F.elu(x) def _inverse(self, y, eps=1e-8): - return torch.max(y, torch.zeros_like(y)) + torch.min(torch.log1p(y + eps), torch.zeros_like(y)) + return torch.max(y, torch.zeros_like(y)) + torch.min( + torch.log1p(y + eps), torch.zeros_like(y) + ) def log_abs_det_jacobian(self, x, y): return -F.relu(-x) @@ -42,6 +44,7 @@ def elu(): """ return ELUTransform() + # TODO: Move upstream @@ -64,7 +67,9 @@ def _inverse(self, y): return F.leaky_relu(y, negative_slope=100.0) def log_abs_det_jacobian(self, x, y): - return torch.where(x >= 0., torch.zeros_like(x), torch.ones_like(x) * math.log(0.01)) + return torch.where( + x >= 0.0, torch.zeros_like(x), torch.ones_like(x) * math.log(0.01) + ) def leaky_relu(): diff --git a/pyro/distributions/transforms/batchnorm.py b/pyro/distributions/transforms/batchnorm.py index 686f02abe9..1a92e09a2a 100644 --- a/pyro/distributions/transforms/batchnorm.py +++ b/pyro/distributions/transforms/batchnorm.py @@ -83,8 +83,8 @@ def __init__(self, input_dim, momentum=0.1, epsilon=1e-5): self.momentum = momentum self.epsilon = epsilon - self.register_buffer('moving_mean', torch.zeros(input_dim)) - self.register_buffer('moving_variance', torch.ones(input_dim)) + self.register_buffer("moving_mean", torch.zeros(input_dim)) + self.register_buffer("moving_variance", torch.ones(input_dim)) @property def constrained_gamma(self): @@ -100,8 +100,9 @@ def _call(self, x): the base distribution (or the output of a previous transform) """ # Enforcing the constraint that gamma is positive - return (x - self.beta) / self.constrained_gamma * \ - torch.sqrt(self.moving_variance + self.epsilon) + self.moving_mean + return (x - self.beta) / self.constrained_gamma * torch.sqrt( + self.moving_variance + self.epsilon + ) + self.moving_mean def _inverse(self, y): """ @@ -123,7 +124,9 @@ def _inverse(self, y): else: mean, var = self.moving_mean, self.moving_variance - return (y - mean) * self.constrained_gamma / torch.sqrt(var + self.epsilon) + self.beta + return (y - mean) * self.constrained_gamma / torch.sqrt( + var + self.epsilon + ) + self.beta def log_abs_det_jacobian(self, x, y): """ @@ -134,7 +137,7 @@ def log_abs_det_jacobian(self, x, y): else: # NOTE: You wouldn't typically run this function in eval mode, but included for gradient tests var = self.moving_variance - return (-self.constrained_gamma.log() + 0.5 * torch.log(var + self.epsilon)) + return -self.constrained_gamma.log() + 0.5 * torch.log(var + self.epsilon) def batchnorm(input_dim, **kwargs): diff --git a/pyro/distributions/transforms/block_autoregressive.py 
b/pyro/distributions/transforms/block_autoregressive.py index 2591c97084..e019ea6b83 100644 --- a/pyro/distributions/transforms/block_autoregressive.py +++ b/pyro/distributions/transforms/block_autoregressive.py @@ -74,21 +74,28 @@ class BlockAutoregressive(TransformModule): bijective = True autoregressive = True - def __init__(self, input_dim, hidden_factors=[8, 8], activation='tanh', residual=None): + def __init__( + self, input_dim, hidden_factors=[8, 8], activation="tanh", residual=None + ): super().__init__(cache_size=1) if any([h < 1 for h in hidden_factors]): - raise ValueError('Hidden factors, {}, must all be >= 1'.format(hidden_factors)) + raise ValueError( + "Hidden factors, {}, must all be >= 1".format(hidden_factors) + ) - if residual not in [None, 'normal', 'gated']: - raise ValueError('Invalid value {} for keyword argument "residual"'.format(residual)) + if residual not in [None, "normal", "gated"]: + raise ValueError( + 'Invalid value {} for keyword argument "residual"'.format(residual) + ) # Mix in activation function methods name_to_mixin = { - 'ELU': ELUTransform, - 'LeakyReLU': LeakyReLUTransform, - 'sigmoid': torch.distributions.transforms.SigmoidTransform, - 'tanh': TanhTransform} + "ELU": ELUTransform, + "LeakyReLU": LeakyReLUTransform, + "sigmoid": torch.distributions.transforms.SigmoidTransform, + "tanh": TanhTransform, + } if activation not in name_to_mixin: raise ValueError('Invalid activation function "{}"'.format(activation)) self.T = name_to_mixin[activation]() @@ -96,14 +103,23 @@ def __init__(self, input_dim, hidden_factors=[8, 8], activation='tanh', residual # Initialize modules for each layer in transform self.residual = residual self.input_dim = input_dim - self.layers = nn.ModuleList([MaskedBlockLinear(input_dim, input_dim * hidden_factors[0], input_dim)]) + self.layers = nn.ModuleList( + [MaskedBlockLinear(input_dim, input_dim * hidden_factors[0], input_dim)] + ) for idx in range(1, len(hidden_factors)): - self.layers.append(MaskedBlockLinear( - input_dim * hidden_factors[idx - 1], input_dim * hidden_factors[idx], input_dim)) - self.layers.append(MaskedBlockLinear(input_dim * hidden_factors[-1], input_dim, input_dim)) + self.layers.append( + MaskedBlockLinear( + input_dim * hidden_factors[idx - 1], + input_dim * hidden_factors[idx], + input_dim, + ) + ) + self.layers.append( + MaskedBlockLinear(input_dim * hidden_factors[-1], input_dim, input_dim) + ) self._cached_logDetJ = None - if residual == 'gated': + if residual == "gated": self.gate = torch.nn.Parameter(torch.nn.init.normal_(torch.Tensor(1))) def _call(self, x): @@ -121,14 +137,18 @@ def _call(self, x): if idx == 0: y = self.T(pre_activation) - J_act = self.T.log_abs_det_jacobian((pre_activation).view( - *(list(x.size()) + [-1, 1])), y.view(*(list(x.size()) + [-1, 1]))) + J_act = self.T.log_abs_det_jacobian( + (pre_activation).view(*(list(x.size()) + [-1, 1])), + y.view(*(list(x.size()) + [-1, 1])), + ) logDetJ = dy_dx + J_act elif idx < len(self.layers) - 1: y = self.T(pre_activation) - J_act = self.T.log_abs_det_jacobian((pre_activation).view( - *(list(x.size()) + [-1, 1])), y.view(*(list(x.size()) + [-1, 1]))) + J_act = self.T.log_abs_det_jacobian( + (pre_activation).view(*(list(x.size()) + [-1, 1])), + y.view(*(list(x.size()) + [-1, 1])), + ) logDetJ = log_matrix_product(dy_dx, logDetJ) + J_act else: @@ -137,11 +157,11 @@ def _call(self, x): self._cached_logDetJ = logDetJ.squeeze(-1).squeeze(-1) - if self.residual == 'normal': + if self.residual == "normal": y = y + x self._cached_logDetJ = 
F.softplus(self._cached_logDetJ) - elif self.residual == 'gated': - y = self.gate.sigmoid() * x + (1. - self.gate.sigmoid()) * y + elif self.residual == "gated": + y = self.gate.sigmoid() * x + (1.0 - self.gate.sigmoid()) * y term1 = torch.log(self.gate.sigmoid() + eps) log1p_gate = torch.log1p(eps - self.gate.sigmoid()) log_gate = torch.log(self.gate.sigmoid() + eps) @@ -161,7 +181,9 @@ def _inverse(self, y): cached on the forward call) """ - raise KeyError("BlockAutoregressive object expected to find key in intermediates cache but didn't") + raise KeyError( + "BlockAutoregressive object expected to find key in intermediates cache but didn't" + ) def log_abs_det_jacobian(self, x, y): """ @@ -192,27 +214,45 @@ def __init__(self, in_features, out_features, dim, bias=True): # Fill in non-zero entries of block weight matrix, going from top # to bottom. for i in range(dim): - weight[i * out_features // dim:(i + 1) * out_features // dim, - 0:(i + 1) * in_features // dim] = torch.nn.init.xavier_uniform_( - torch.Tensor(out_features // dim, (i + 1) * in_features // dim)) + weight[ + i * out_features // dim : (i + 1) * out_features // dim, + 0 : (i + 1) * in_features // dim, + ] = torch.nn.init.xavier_uniform_( + torch.Tensor(out_features // dim, (i + 1) * in_features // dim) + ) self._weight = torch.nn.Parameter(weight) - self._diag_weight = torch.nn.Parameter(torch.nn.init.uniform_(torch.Tensor(out_features, 1)).log()) - - self.bias = torch.nn.Parameter( - torch.nn.init.uniform_(torch.Tensor(out_features), - -1 / math.sqrt(out_features), - 1 / math.sqrt(out_features))) if bias else 0 + self._diag_weight = torch.nn.Parameter( + torch.nn.init.uniform_(torch.Tensor(out_features, 1)).log() + ) + + self.bias = ( + torch.nn.Parameter( + torch.nn.init.uniform_( + torch.Tensor(out_features), + -1 / math.sqrt(out_features), + 1 / math.sqrt(out_features), + ) + ) + if bias + else 0 + ) # Diagonal block mask - mask_d = torch.eye(dim).unsqueeze(-1).repeat(1, out_features // dim, - in_features // dim).view(out_features, in_features) - self.register_buffer('mask_d', mask_d) + mask_d = ( + torch.eye(dim) + .unsqueeze(-1) + .repeat(1, out_features // dim, in_features // dim) + .view(out_features, in_features) + ) + self.register_buffer("mask_d", mask_d) # Off-diagonal block mask for lower triangular weight matrix mask_o = torch.tril(torch.ones(dim, dim), diagonal=-1).unsqueeze(-1) - mask_o = mask_o.repeat(1, out_features // dim, in_features // dim).view(out_features, in_features) - self.register_buffer('mask_o', mask_o) + mask_o = mask_o.repeat(1, out_features // dim, in_features // dim).view( + out_features, in_features + ) + self.register_buffer("mask_o", mask_o) def get_weights(self): """ @@ -234,7 +274,9 @@ def get_weights(self): # taking the log gives the right hand side below: wpl = self._diag_weight + self._weight - 0.5 * torch.log(w_squared_norm + eps) - return w, wpl[self.mask_d.bool()].view(self.dim, self.out_features // self.dim, self.in_features // self.dim) + return w, wpl[self.mask_d.bool()].view( + self.dim, self.out_features // self.dim, self.in_features // self.dim + ) def forward(self, x): w, wpl = self.get_weights() diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index 768b6ce3a8..3b890f5a3b 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -17,14 +17,16 @@ def _vector_to_l_cholesky(z): x = torch.zeros(z.shape[:-1] + (D, D), dtype=z.dtype, device=z.device) x[..., 0, 0] = 1 - x[..., 1:, 
0] = z[..., :(D - 1)] + x[..., 1:, 0] = z[..., : (D - 1)] i = D - 1 last_squared_x = torch.zeros(z.shape[:-1] + (D,), dtype=z.dtype, device=z.device) for j in range(1, D): distance_to_copy = D - 1 - j - last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone()**2 + last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone() ** 2 x[..., j, j] = (1 - last_squared_x[..., 0]).sqrt() - x[..., (j + 1):, j] = z[..., i:(i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt() + x[..., (j + 1) :, j] = ( + z[..., i : (i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt() + ) i += distance_to_copy return x @@ -42,6 +44,7 @@ class CorrLCholeskyTransform(Transform): Section 10.12. """ + domain = constraints.real_vector codomain = constraints.corr_cholesky bijective = True @@ -54,17 +57,21 @@ def _call(self, x): return _vector_to_l_cholesky(z) def _inverse(self, y): - if (y.shape[-2] != y.shape[-1]): - raise ValueError("A matrix that isn't square can't be a Cholesky factor of a correlation matrix") + if y.shape[-2] != y.shape[-1]: + raise ValueError( + "A matrix that isn't square can't be a Cholesky factor of a correlation matrix" + ) D = y.shape[-1] - z_tri = torch.zeros(y.shape[:-2] + (D - 2, D - 2), dtype=y.dtype, device=y.device) - z_stack = [ - y[..., 1:, 0] - ] + z_tri = torch.zeros( + y.shape[:-2] + (D - 2, D - 2), dtype=y.dtype, device=y.device + ) + z_stack = [y[..., 1:, 0]] for i in range(2, D): - z_tri[..., i - 2, 0:(i - 1)] = y[..., i, 1:i] / (1 - y[..., i, 0:(i - 1)].pow(2).cumsum(-1)).sqrt() + z_tri[..., i - 2, 0 : (i - 1)] = ( + y[..., i, 1:i] / (1 - y[..., i, 0 : (i - 1)].pow(2).cumsum(-1)).sqrt() + ) for j in range(D - 2): z_stack.append(z_tri[..., j:, j]) @@ -74,7 +81,9 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): # Note dependence on pytorch 1.0.1 for batched tril tanpart = x.cosh().log().sum(-1).mul(-2) - matpart = (1 - y.pow(2).cumsum(-1).tril(diagonal=-2)).log().div(2).sum(-1).sum(-1) + matpart = ( + (1 - y.pow(2).cumsum(-1).tril(diagonal=-2)).log().div(2).sum(-1).sum(-1) + ) return tanpart + matpart @@ -100,7 +109,9 @@ def log_abs_det_jacobian(self, x, y): # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13 n = x.shape[-1] order = torch.arange(n, 0, -1, dtype=x.dtype, device=x.device) - return -n * math.log(2) - (order * torch.diagonal(y, dim1=-2, dim2=-1).log()).sum(-1) + return -n * math.log(2) - ( + order * torch.diagonal(y, dim1=-2, dim2=-1).log() + ).sum(-1) class CorrMatrixCholeskyTransform(CholeskyTransform): diff --git a/pyro/distributions/transforms/discrete_cosine.py b/pyro/distributions/transforms/discrete_cosine.py index bbd039f1dc..22e7945a86 100644 --- a/pyro/distributions/transforms/discrete_cosine.py +++ b/pyro/distributions/transforms/discrete_cosine.py @@ -24,9 +24,10 @@ class DiscreteCosineTransform(Transform): noise; when -1 this transforms violet noise to white noise; etc. Any real number is allowed. https://en.wikipedia.org/wiki/Colors_of_noise. 
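# Editor's aside (illustrative sketch, not part of this patch): the
# DiscreteCosineTransform documented above is a bijective change to a frequency
# basis (the smooth parameter rescales frequency components as described in the
# docstring), so it round-trips up to floating point error:
import torch
from pyro.distributions.transforms import DiscreteCosineTransform

t = DiscreteCosineTransform(smooth=1.0)
x = torch.randn(8)
y = t(x)                  # same shape, expressed in the (rescaled) DCT basis
assert y.shape == x.shape
assert torch.allclose(t.inv(y), x, atol=1e-4)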
""" + bijective = True - def __init__(self, dim=-1, smooth=0., cache_size=0): + def __init__(self, dim=-1, smooth=0.0, cache_size=0): assert isinstance(dim, int) and dim < 0 self.dim = dim self.smooth = float(smooth) @@ -37,8 +38,11 @@ def __hash__(self): return hash((type(self), self.dim, self.smooth)) def __eq__(self, other): - return (type(self) == type(other) and self.dim == other.dim - and self.smooth == other.smooth) + return ( + type(self) == type(other) + and self.dim == other.dim + and self.smooth == other.smooth + ) @constraints.dependent_property(is_discrete=False) def domain(self): @@ -82,7 +86,7 @@ def _inverse(self, y): return x def log_abs_det_jacobian(self, x, y): - return x.new_zeros(x.shape[:self.dim]) + return x.new_zeros(x.shape[: self.dim]) def with_cache(self, cache_size=1): if self._cache_size == cache_size: diff --git a/pyro/distributions/transforms/generalized_channel_permute.py b/pyro/distributions/transforms/generalized_channel_permute.py index 83ff4a97c0..e56fe72f30 100644 --- a/pyro/distributions/transforms/generalized_channel_permute.py +++ b/pyro/distributions/transforms/generalized_channel_permute.py @@ -30,7 +30,9 @@ def U_diag(self): @property def L(self): - return self.LU.tril(diagonal=-1) + torch.eye(self.LU.size(-1), dtype=self.LU.dtype, device=self.LU.device) + return self.LU.tril(diagonal=-1) + torch.eye( + self.LU.size(-1), dtype=self.LU.dtype, device=self.LU.device + ) @property def U(self): @@ -100,7 +102,9 @@ def log_abs_det_jacobian(self, x, y): h, w = x.shape[-2:] log_det = h * w * self.U_diag.abs().log().sum() - return log_det * torch.ones(x.size()[:-3], dtype=x.dtype, layout=x.layout, device=x.device) + return log_det * torch.ones( + x.size()[:-3], dtype=x.dtype, layout=x.layout, device=x.device + ) @copy_docs_from(ConditionedGeneralizedChannelPermute) @@ -164,7 +168,7 @@ class GeneralizedChannelPermute(ConditionedGeneralizedChannelPermute, TransformM def __init__(self, channels=3, permutation=None): super(GeneralizedChannelPermute, self).__init__() - self.__delattr__('permutation') + self.__delattr__("permutation") # Sample a random orthogonal matrix W, _ = torch.linalg.qr(torch.randn(channels, channels)) @@ -179,11 +183,13 @@ def __init__(self, channels=3, permutation=None): if len(permutation) != channels: raise ValueError( 'Keyword argument "permutation" expected to have {} elements but {} found.'.format( - channels, len(permutation))) + channels, len(permutation) + ) + ) P = torch.eye(channels, channels)[permutation.type(dtype=torch.int64)] # We register the permutation matrix so that the model can be serialized - self.register_buffer('permutation', P) + self.register_buffer("permutation", P) # NOTE: For this implementation I have chosen to store the parameters densely, rather than # storing L, U, and s separately @@ -263,9 +269,13 @@ def __init__(self, nn, channels=3, permutation=None): self.nn = nn self.channels = channels if permutation is None: - permutation = torch.randperm(channels, device='cpu').to(torch.Tensor().device) - P = torch.eye(len(permutation), len(permutation))[permutation.type(dtype=torch.int64)] - self.register_buffer('permutation', P) + permutation = torch.randperm(channels, device="cpu").to( + torch.Tensor().device + ) + P = torch.eye(len(permutation), len(permutation))[ + permutation.type(dtype=torch.int64) + ] + self.register_buffer("permutation", P) def condition(self, context): LU = self.nn(context) diff --git a/pyro/distributions/transforms/haar.py b/pyro/distributions/transforms/haar.py index 
74c577b23e..3702ef53a4 100644 --- a/pyro/distributions/transforms/haar.py +++ b/pyro/distributions/transforms/haar.py @@ -24,6 +24,7 @@ class HaarTransform(Transform): :param bool flip: Whether to flip the time axis before applying the Haar transform. Defaults to false. """ + bijective = True def __init__(self, dim=-1, flip=False, cache_size=0): @@ -36,8 +37,11 @@ def __hash__(self): return hash((type(self), self.event_dim, self.flip)) def __eq__(self, other): - return (type(self) == type(other) and self.dim == other.dim and - self.flip == other.flip) + return ( + type(self) == type(other) + and self.dim == other.dim + and self.flip == other.flip + ) @constraints.dependent_property(is_discrete=False) def domain(self): @@ -70,7 +74,7 @@ def _inverse(self, y): return x def log_abs_det_jacobian(self, x, y): - return x.new_zeros(x.shape[:self.dim]) + return x.new_zeros(x.shape[: self.dim]) def with_cache(self, cache_size=1): if self._cache_size == cache_size: diff --git a/pyro/distributions/transforms/householder.py b/pyro/distributions/transforms/householder.py index d920642ed8..33f9f0c4fc 100644 --- a/pyro/distributions/transforms/householder.py +++ b/pyro/distributions/transforms/householder.py @@ -46,7 +46,7 @@ def _call(self, x): u = self.u() for idx in range(u.size(-2)): projection = (u[..., idx, :] * y).sum(dim=-1, keepdim=True) * u[..., idx, :] - y = y - 2. * projection + y = y - 2.0 * projection return y def _inverse(self, y): @@ -64,7 +64,7 @@ def _inverse(self, y): for jdx in reversed(range(u.size(-2))): # NOTE: Need to apply transforms in reverse order from forward operation! projection = (u[..., jdx, :] * x).sum(dim=-1, keepdim=True) * u[..., jdx, :] - x = x - 2. * projection + x = x - 2.0 * projection return x def log_abs_det_jacobian(self, x, y): @@ -73,7 +73,9 @@ def log_abs_det_jacobian(self, x, y): is measure preserving, so :math:`\log(|detJ|) = 0` """ - return torch.zeros(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device) + return torch.zeros( + x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device + ) @copy_docs_from(TransformModule) @@ -131,16 +133,23 @@ def __init__(self, input_dim, count_transforms=1): self.input_dim = input_dim if count_transforms < 1: - raise ValueError('Number of Householder transforms, {}, is less than 1!'.format(count_transforms)) + raise ValueError( + "Number of Householder transforms, {}, is less than 1!".format( + count_transforms + ) + ) elif count_transforms > input_dim: warnings.warn( "Number of Householder transforms, {}, is greater than input dimension {}, which is an \ -over-parametrization!".format(count_transforms, input_dim)) +over-parametrization!".format( + count_transforms, input_dim + ) + ) self.u_unnormed = nn.Parameter(torch.Tensor(count_transforms, input_dim)) self.reset_parameters() def reset_parameters(self): - stdv = 1. 
/ math.sqrt(self.u_unnormed.size(-1)) + stdv = 1.0 / math.sqrt(self.u_unnormed.size(-1)) self.u_unnormed.data.uniform_(-stdv, stdv) @@ -210,11 +219,18 @@ def __init__(self, input_dim, nn, count_transforms=1): self.nn = nn self.input_dim = input_dim if count_transforms < 1: - raise ValueError('Number of Householder transforms, {}, is less than 1!'.format(count_transforms)) + raise ValueError( + "Number of Householder transforms, {}, is less than 1!".format( + count_transforms + ) + ) elif count_transforms > input_dim: warnings.warn( "Number of Householder transforms, {}, is greater than input dimension {}, which is an \ -over-parametrization!".format(count_transforms, input_dim)) +over-parametrization!".format( + count_transforms, input_dim + ) + ) self.count_transforms = count_transforms def _u_unnormed(self, context): @@ -251,7 +267,9 @@ def householder(input_dim, count_transforms=None): return Householder(input_dim, count_transforms=count_transforms) -def conditional_householder(input_dim, context_dim, hidden_dims=None, count_transforms=1): +def conditional_householder( + input_dim, context_dim, hidden_dims=None, count_transforms=1 +): """ A helper function to create a :class:`~pyro.distributions.transforms.ConditionalHouseholder` object that takes diff --git a/pyro/distributions/transforms/lower_cholesky_affine.py b/pyro/distributions/transforms/lower_cholesky_affine.py index 241fbf1bbd..5a17eeb179 100644 --- a/pyro/distributions/transforms/lower_cholesky_affine.py +++ b/pyro/distributions/transforms/lower_cholesky_affine.py @@ -23,6 +23,7 @@ class LowerCholeskyAffine(Transform): :type scale_tril: torch.tensor """ + domain = constraints.real_vector codomain = constraints.real_vector bijective = True @@ -32,9 +33,11 @@ def __init__(self, loc, scale_tril, cache_size=0): super().__init__(cache_size=cache_size) self.loc = loc self.scale_tril = scale_tril - assert loc.size(-1) == scale_tril.size(-1) == scale_tril.size(-2), \ - "loc and scale_tril must be of size D and D x D, respectively (instead: {}, {})".format(loc.shape, - scale_tril.shape) + assert ( + loc.size(-1) == scale_tril.size(-1) == scale_tril.size(-2) + ), "loc and scale_tril must be of size D and D x D, respectively (instead: {}, {})".format( + loc.shape, scale_tril.shape + ) def _call(self, x): """ @@ -54,16 +57,19 @@ def _inverse(self, y): Inverts y => x. """ - return torch.triangular_solve((y - self.loc).unsqueeze(-1), self.scale_tril, - upper=False, transpose=False)[0].squeeze(-1) + return torch.triangular_solve( + (y - self.loc).unsqueeze(-1), self.scale_tril, upper=False, transpose=False + )[0].squeeze(-1) def log_abs_det_jacobian(self, x, y): """ Calculates the elementwise determinant of the log Jacobian, i.e. log(abs(dy/dx)). 
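# Editor's aside (illustrative sketch, not part of this patch): LowerCholeskyAffine
# above maps x to loc + scale_tril @ x (the standard affine form), and its
# log|det J| reduces to the sum of the log-diagonal of scale_tril, as in the
# formula reformatted just below. Quick check:
import torch
from pyro.distributions.transforms import LowerCholeskyAffine

loc = torch.zeros(3)
scale_tril = torch.tril(torch.rand(3, 3)) + torch.eye(3)  # positive diagonal
t = LowerCholeskyAffine(loc, scale_tril)
x = torch.randn(3)
y = t(x)
assert torch.allclose(y, loc + scale_tril @ x)
assert torch.allclose(t.log_abs_det_jacobian(x, y), scale_tril.diag().log().sum())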
""" - return torch.ones(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device) * \ - self.scale_tril.diag().log().sum() + return ( + torch.ones(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device) + * self.scale_tril.diag().log().sum() + ) def with_cache(self, cache_size=1): if self._cache_size == cache_size: diff --git a/pyro/distributions/transforms/matrix_exponential.py b/pyro/distributions/transforms/matrix_exponential.py index 773f87ffe5..8b721f060d 100644 --- a/pyro/distributions/transforms/matrix_exponential.py +++ b/pyro/distributions/transforms/matrix_exponential.py @@ -21,7 +21,7 @@ class ConditionedMatrixExponential(Transform): codomain = constraints.real_vector bijective = True - def __init__(self, weights=None, iterations=8, normalization='none', bound=None): + def __init__(self, weights=None, iterations=8, normalization="none", bound=None): super().__init__(cache_size=1) assert iterations > 0 self.weights = weights @@ -31,10 +31,10 @@ def __init__(self, weights=None, iterations=8, normalization='none', bound=None) # Currently, weight and spectral normalization are unimplemented. This doesn't effect the validity of the # bijection, although applying these norms should improve the numerical conditioning of the approximation. - if normalization == 'weight' or normalization == 'spectral': - raise NotImplementedError('Normalization is currently not implemented.') - elif normalization != 'none': - raise ValueError('Unknown normalization method: {}'.format(normalization)) + if normalization == "weight" or normalization == "spectral": + raise NotImplementedError("Normalization is currently not implemented.") + elif normalization != "none": + raise ValueError("Unknown normalization method: {}".format(normalization)) def _exp(self, x, M): """ @@ -151,14 +151,16 @@ class MatrixExponential(ConditionedMatrixExponential, TransformModule): codomain = constraints.real_vector bijective = True - def __init__(self, input_dim, iterations=8, normalization='none', bound=None): - super().__init__(iterations=iterations, normalization=normalization, bound=bound) + def __init__(self, input_dim, iterations=8, normalization="none", bound=None): + super().__init__( + iterations=iterations, normalization=normalization, bound=bound + ) self.weights = nn.Parameter(torch.Tensor(input_dim, input_dim)) self.reset_parameters() def reset_parameters(self): - stdv = 1. 
/ math.sqrt(self.weights.size(0)) + stdv = 1.0 / math.sqrt(self.weights.size(0)) self.weights.data.uniform_(-stdv, stdv) @@ -230,7 +232,7 @@ class ConditionalMatrixExponential(ConditionalTransformModule): codomain = constraints.real_vector bijective = True - def __init__(self, input_dim, nn, iterations=8, normalization='none', bound=None): + def __init__(self, input_dim, nn, iterations=8, normalization="none", bound=None): super().__init__() self.input_dim = input_dim self.nn = nn @@ -248,11 +250,16 @@ def condition(self, context): def weights(): w = cond_nn() return w.view(w.shape[:-1] + (self.input_dim, self.input_dim)) - return ConditionedMatrixExponential(weights, iterations=self.iterations, normalization=self.normalization, - bound=self.bound) + return ConditionedMatrixExponential( + weights, + iterations=self.iterations, + normalization=self.normalization, + bound=self.bound, + ) -def matrix_exponential(input_dim, iterations=8, normalization='none', bound=None): + +def matrix_exponential(input_dim, iterations=8, normalization="none", bound=None): """ A helper function to create a :class:`~pyro.distributions.transforms.MatrixExponential` object for consistency @@ -277,11 +284,19 @@ def matrix_exponential(input_dim, iterations=8, normalization='none', bound=None """ - return MatrixExponential(input_dim, iterations=iterations, normalization=normalization, bound=bound) + return MatrixExponential( + input_dim, iterations=iterations, normalization=normalization, bound=bound + ) -def conditional_matrix_exponential(input_dim, context_dim, hidden_dims=None, iterations=8, normalization='none', - bound=None): +def conditional_matrix_exponential( + input_dim, + context_dim, + hidden_dims=None, + iterations=8, + normalization="none", + bound=None, +): """ A helper function to create a :class:`~pyro.distributions.transforms.ConditionalMatrixExponential` object for @@ -314,4 +329,6 @@ def conditional_matrix_exponential(input_dim, context_dim, hidden_dims=None, ite if hidden_dims is None: hidden_dims = [input_dim * 10, input_dim * 10] nn = DenseNN(context_dim, hidden_dims, param_dims=[input_dim * input_dim]) - return ConditionalMatrixExponential(input_dim, nn, iterations=iterations, normalization=normalization, bound=bound) + return ConditionalMatrixExponential( + input_dim, nn, iterations=iterations, normalization=normalization, bound=bound + ) diff --git a/pyro/distributions/transforms/neural_autoregressive.py b/pyro/distributions/transforms/neural_autoregressive.py index 967e835fb6..4fb550ea72 100644 --- a/pyro/distributions/transforms/neural_autoregressive.py +++ b/pyro/distributions/transforms/neural_autoregressive.py @@ -65,15 +65,16 @@ class NeuralAutoregressive(TransformModule): eps = 1e-8 autoregressive = True - def __init__(self, autoregressive_nn, hidden_units=16, activation='sigmoid'): + def __init__(self, autoregressive_nn, hidden_units=16, activation="sigmoid"): super().__init__(cache_size=1) # Create the intermediate transform used name_to_mixin = { - 'ELU': ELUTransform, - 'LeakyReLU': LeakyReLUTransform, - 'sigmoid': SigmoidTransform, - 'tanh': TanhTransform} + "ELU": ELUTransform, + "LeakyReLU": LeakyReLUTransform, + "sigmoid": SigmoidTransform, + "tanh": TanhTransform, + } if activation not in name_to_mixin: raise ValueError('Invalid activation function "{}"'.format(activation)) self.T = name_to_mixin[activation]() @@ -129,8 +130,12 @@ def log_abs_det_jacobian(self, x, y): T = self.T log_dydD = self._cached_log_df_inv_dx - log_dDdx = torch.logsumexp(torch.log(A + self.eps) + 
self.logsoftmax(W_pre) + - T.log_abs_det_jacobian(C, T_C), dim=-2) + log_dDdx = torch.logsumexp( + torch.log(A + self.eps) + + self.logsoftmax(W_pre) + + T.log_abs_det_jacobian(C, T_C), + dim=-2, + ) log_det = log_dydD + log_dDdx return log_det.sum(-1) @@ -204,7 +209,7 @@ def condition(self, context): return NeuralAutoregressive(cond_nn, **self.kwargs) -def neural_autoregressive(input_dim, hidden_dims=None, activation='sigmoid', width=16): +def neural_autoregressive(input_dim, hidden_dims=None, activation="sigmoid", width=16): """ A helper function to create a :class:`~pyro.distributions.transforms.NeuralAutoregressive` object that takes @@ -231,7 +236,9 @@ def neural_autoregressive(input_dim, hidden_dims=None, activation='sigmoid', wid return NeuralAutoregressive(arn, hidden_units=width, activation=activation) -def conditional_neural_autoregressive(input_dim, context_dim, hidden_dims=None, activation='sigmoid', width=16): +def conditional_neural_autoregressive( + input_dim, context_dim, hidden_dims=None, activation="sigmoid", width=16 +): """ A helper function to create a :class:`~pyro.distributions.transforms.ConditionalNeuralAutoregressive` object @@ -256,5 +263,9 @@ def conditional_neural_autoregressive(input_dim, context_dim, hidden_dims=None, if hidden_dims is None: hidden_dims = [3 * input_dim + 1] - arn = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims, param_dims=[width] * 3) - return ConditionalNeuralAutoregressive(arn, hidden_units=width, activation=activation) + arn = ConditionalAutoRegressiveNN( + input_dim, context_dim, hidden_dims, param_dims=[width] * 3 + ) + return ConditionalNeuralAutoregressive( + arn, hidden_units=width, activation=activation + ) diff --git a/pyro/distributions/transforms/normalize.py b/pyro/distributions/transforms/normalize.py index 8769955b4e..612ea2ea8e 100644 --- a/pyro/distributions/transforms/normalize.py +++ b/pyro/distributions/transforms/normalize.py @@ -15,6 +15,7 @@ class Normalize(Transform): Safely project a vector onto the sphere wrt the ``p`` norm. This avoids the singularity at zero by mapping to the vector ``[1, 0, 0, ..., 0]``. """ + domain = constraints.real_vector codomain = constraints.sphere bijective = False diff --git a/pyro/distributions/transforms/ordered.py b/pyro/distributions/transforms/ordered.py index ff1de6f5b3..d6470e13d1 100644 --- a/pyro/distributions/transforms/ordered.py +++ b/pyro/distributions/transforms/ordered.py @@ -15,6 +15,7 @@ class OrderedTransform(Transform): of a given tensor via the transformation :math:`y_0 = x_0`, :math:`y_i = \\sum_{1 \\le j \\le i} \\exp(x_i)` """ + domain = constraints.real_vector codomain = constraints.ordered_vector bijective = True diff --git a/pyro/distributions/transforms/permute.py b/pyro/distributions/transforms/permute.py index 97330e5213..748b567443 100644 --- a/pyro/distributions/transforms/permute.py +++ b/pyro/distributions/transforms/permute.py @@ -66,9 +66,9 @@ def codomain(self): @lazy_property def inv_permutation(self): result = torch.empty_like(self.permutation, dtype=torch.long) - result[self.permutation] = torch.arange(self.permutation.size(0), - dtype=torch.long, - device=self.permutation.device) + result[self.permutation] = torch.arange( + self.permutation.size(0), dtype=torch.long, device=self.permutation.device + ) return result def _call(self, x): @@ -101,7 +101,9 @@ def log_abs_det_jacobian(self, x, y): determinant is -1 or +1), and so returning a vector of zeros works. 
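# Editor's aside (illustrative sketch, not part of this patch): as the Permute
# docstring above notes, a permutation is volume preserving (|det J| = 1), so
# log_abs_det_jacobian is all zeros. Hypothetical usage:
import torch
from pyro.distributions.transforms import Permute

perm = torch.tensor([2, 0, 1])
t = Permute(perm)
x = torch.randn(5, 3)
y = t(x)
assert torch.equal(y, x[..., perm])
assert torch.equal(t.inv(y), x)
assert t.log_abs_det_jacobian(x, y).eq(0).all()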
""" - return torch.zeros(x.size()[:-self.event_dim], dtype=x.dtype, layout=x.layout, device=x.device) + return torch.zeros( + x.size()[: -self.event_dim], dtype=x.dtype, layout=x.layout, device=x.device + ) def with_cache(self, cache_size=1): if self._cache_size == cache_size: @@ -127,7 +129,11 @@ def permute(input_dim, permutation=None, dim=-1): """ if dim < -1 or not isinstance(input_dim, int): if len(input_dim) != -dim: - raise ValueError('event shape {} must have same length as event_dim {}'.format(input_dim, -dim)) + raise ValueError( + "event shape {} must have same length as event_dim {}".format( + input_dim, -dim + ) + ) input_dim = input_dim[dim] if permutation is None: diff --git a/pyro/distributions/transforms/planar.py b/pyro/distributions/transforms/planar.py index c706e6680a..0aab945e2c 100644 --- a/pyro/distributions/transforms/planar.py +++ b/pyro/distributions/transforms/planar.py @@ -46,13 +46,21 @@ def _call(self, x): # x ~ (batch_size, dim_size, 1) # w ~ (batch_size, 1, dim_size) # bias ~ (batch_size, 1) - act = torch.tanh(torch.matmul(w.unsqueeze(-2), x.unsqueeze(-1)).squeeze(-1) + bias) + act = torch.tanh( + torch.matmul(w.unsqueeze(-2), x.unsqueeze(-1)).squeeze(-1) + bias + ) u_hat = self.u_hat(u, w) y = x + u_hat * act - psi_z = (1. - act.pow(2)) * w + psi_z = (1.0 - act.pow(2)) * w self._cached_logDetJ = torch.log( - torch.abs(1 + torch.matmul(psi_z.unsqueeze(-2), u_hat.unsqueeze(-1)).squeeze(-1).squeeze(-1))) + torch.abs( + 1 + + torch.matmul(psi_z.unsqueeze(-2), u_hat.unsqueeze(-1)) + .squeeze(-1) + .squeeze(-1) + ) + ) return y @@ -66,7 +74,9 @@ def _inverse(self, y): cached on the forward call) """ - raise KeyError("ConditionedPlanar object expected to find key in intermediates cache but didn't") + raise KeyError( + "ConditionedPlanar object expected to find key in intermediates cache but didn't" + ) def log_abs_det_jacobian(self, x, y): """ @@ -127,9 +137,21 @@ class Planar(ConditionedPlanar, TransformModule): def __init__(self, input_dim): super().__init__(self._params) - self.bias = nn.Parameter(torch.Tensor(1,)) - self.u = nn.Parameter(torch.Tensor(input_dim,)) - self.w = nn.Parameter(torch.Tensor(input_dim,)) + self.bias = nn.Parameter( + torch.Tensor( + 1, + ) + ) + self.u = nn.Parameter( + torch.Tensor( + input_dim, + ) + ) + self.w = nn.Parameter( + torch.Tensor( + input_dim, + ) + ) self.input_dim = input_dim self.reset_parameters() @@ -138,7 +160,7 @@ def _params(self): return self.bias, self.u, self.w def reset_parameters(self): - stdv = 1. 
/ math.sqrt(self.u.size(0)) + stdv = 1.0 / math.sqrt(self.u.size(0)) self.w.data.uniform_(-stdv, stdv) self.u.data.uniform_(-stdv, stdv) self.bias.data.zero_() diff --git a/pyro/distributions/transforms/polynomial.py b/pyro/distributions/transforms/polynomial.py index e9dc4e951f..ba765cc2cf 100644 --- a/pyro/distributions/transforms/polynomial.py +++ b/pyro/distributions/transforms/polynomial.py @@ -87,18 +87,20 @@ def __init__(self, autoregressive_nn, input_dim, count_degree, count_sum): # Vector of powers of input dimension powers = torch.arange(1, count_degree + 2, dtype=torch.get_default_dtype()) - self.register_buffer('powers', powers) + self.register_buffer("powers", powers) # Build mask of constants - mask = self.powers + torch.arange(count_degree + 1).unsqueeze(-1).type_as(powers) + mask = self.powers + torch.arange(count_degree + 1).unsqueeze(-1).type_as( + powers + ) power_mask = mask mask = mask.reciprocal() - self.register_buffer('power_mask', power_mask) - self.register_buffer('mask', mask) + self.register_buffer("power_mask", power_mask) + self.register_buffer("mask", mask) def reset_parameters(self): - stdv = 1. / math.sqrt(self.c.size(0)) + stdv = 1.0 / math.sqrt(self.c.size(0)) self.c.data.uniform_(-stdv, stdv) def _call(self, x): @@ -124,12 +126,16 @@ def _call(self, x): # Eq (8) from the paper, expanding the squared term and integrating # NOTE: The view_as is necessary because the batch dimensions were collapsed previously - y = self.c + (coefs * x_pow_matrix * self.mask.unsqueeze(-1)).sum((1, 2, 3)).view_as(x) + y = self.c + (coefs * x_pow_matrix * self.mask.unsqueeze(-1)).sum( + (1, 2, 3) + ).view_as(x) # log(|det(J)|) is calculated by the fundamental theorem of calculus, i.e. remove the constant # term and the integral from eq (8) (the equation for this isn't given in the paper) x_pow_matrix = x_view.pow(self.power_mask.unsqueeze(-1) - 1).unsqueeze(-4) - self._cached_logDetJ = torch.log((coefs * x_pow_matrix).sum((1, 2, 3)).view_as(x) + 1e-8).sum(-1) + self._cached_logDetJ = torch.log( + (coefs * x_pow_matrix).sum((1, 2, 3)).view_as(x) + 1e-8 + ).sum(-1) return y @@ -144,7 +150,9 @@ def _inverse(self, y): cached on the forward call) """ - raise KeyError("Polynomial object expected to find key in intermediates cache but didn't") + raise KeyError( + "Polynomial object expected to find key in intermediates cache but didn't" + ) def log_abs_det_jacobian(self, x, y): """ @@ -176,5 +184,9 @@ def polynomial(input_dim, hidden_dims=None): count_sum = 3 if hidden_dims is None: hidden_dims = [input_dim * 10] - arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=[(count_degree + 1) * count_sum]) - return Polynomial(arn, input_dim=input_dim, count_degree=count_degree, count_sum=count_sum) + arn = AutoRegressiveNN( + input_dim, hidden_dims, param_dims=[(count_degree + 1) * count_sum] + ) + return Polynomial( + arn, input_dim=input_dim, count_degree=count_degree, count_sum=count_sum + ) diff --git a/pyro/distributions/transforms/radial.py b/pyro/distributions/transforms/radial.py index 05ba830fbd..e44c39ea9c 100644 --- a/pyro/distributions/transforms/radial.py +++ b/pyro/distributions/transforms/radial.py @@ -42,7 +42,9 @@ def _call(self, x): :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from the base distribution (or the output of a previous transform) """ - x0, alpha_prime, beta_prime = self._params() if callable(self._params) else self._params + x0, alpha_prime, beta_prime = ( + self._params() if callable(self._params) else self._params + ) # Ensure 
invertibility using approach in appendix A.2 alpha = F.softplus(alpha_prime) @@ -52,11 +54,13 @@ def _call(self, x): diff = x - x0 r = diff.norm(dim=-1, keepdim=True) h = (alpha + r).reciprocal() - h_prime = - (h ** 2) + h_prime = -(h ** 2) beta_h = beta * h - self._cached_logDetJ = ((x0.size(-1) - 1) * torch.log1p(beta_h) + - torch.log1p(beta_h + beta * h_prime * r)).sum(-1) + self._cached_logDetJ = ( + (x0.size(-1) - 1) * torch.log1p(beta_h) + + torch.log1p(beta_h + beta * h_prime * r) + ).sum(-1) return x + beta_h * diff def _inverse(self, y): @@ -69,7 +73,9 @@ def _inverse(self, y): cached on the forward call) """ - raise KeyError("ConditionedRadial object expected to find key in intermediates cache but didn't") + raise KeyError( + "ConditionedRadial object expected to find key in intermediates cache but didn't" + ) def log_abs_det_jacobian(self, x, y): """ @@ -128,9 +134,21 @@ class Radial(ConditionedRadial, TransformModule): def __init__(self, input_dim): super().__init__(self._params) - self.x0 = nn.Parameter(torch.Tensor(input_dim,)) - self.alpha_prime = nn.Parameter(torch.Tensor(1,)) - self.beta_prime = nn.Parameter(torch.Tensor(1,)) + self.x0 = nn.Parameter( + torch.Tensor( + input_dim, + ) + ) + self.alpha_prime = nn.Parameter( + torch.Tensor( + 1, + ) + ) + self.beta_prime = nn.Parameter( + torch.Tensor( + 1, + ) + ) self.input_dim = input_dim self.reset_parameters() @@ -138,7 +156,7 @@ def _params(self): return self.x0, self.alpha_prime, self.beta_prime def reset_parameters(self): - stdv = 1. / math.sqrt(self.x0.size(0)) + stdv = 1.0 / math.sqrt(self.x0.size(0)) self.alpha_prime.data.uniform_(-stdv, stdv) self.beta_prime.data.uniform_(-stdv, stdv) self.x0.data.uniform_(-stdv, stdv) diff --git a/pyro/distributions/transforms/softplus.py b/pyro/distributions/transforms/softplus.py index b58a889612..857398a42c 100644 --- a/pyro/distributions/transforms/softplus.py +++ b/pyro/distributions/transforms/softplus.py @@ -39,6 +39,7 @@ class SoftplusLowerCholeskyTransform(Transform): nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization. 
""" + domain = constraints.independent(constraints.real, 2) codomain = constraints.lower_cholesky @@ -55,6 +56,6 @@ def _inverse(self, y): __all__ = [ - 'SoftplusTransform', - 'SoftplusLowerCholeskyTransform', + "SoftplusTransform", + "SoftplusLowerCholeskyTransform", ] diff --git a/pyro/distributions/transforms/spline.py b/pyro/distributions/transforms/spline.py index c9cf2b09d1..ba7d240a1b 100644 --- a/pyro/distributions/transforms/spline.py +++ b/pyro/distributions/transforms/spline.py @@ -31,10 +31,7 @@ def _searchsorted(sorted_sequence, values): TODO: Replace with torch.searchsorted once it is released """ - return torch.sum( - values[..., None] >= sorted_sequence, - dim=-1 - ) - 1 + return torch.sum(values[..., None] >= sorted_sequence, dim=-1) - 1 def _select_bins(x, idx): @@ -69,7 +66,7 @@ def _calculate_knots(lengths, lower, upper): knots = torch.cumsum(lengths, dim=-1) # Pad left of last dimension with 1 zero to compensate for dim lost to cumsum - knots = F.pad(knots, pad=(1, 0), mode='constant', value=0.0) + knots = F.pad(knots, pad=(1, 0), mode="constant", value=0.0) # Translate [0,1] knot points to [-B, B] knots = (upper - lower) * knots + lower @@ -83,18 +80,20 @@ def _calculate_knots(lengths, lower, upper): return lengths, knots -def _monotonic_rational_spline(inputs, - widths, - heights, - derivatives, - lambdas=None, - inverse=False, - bound=3., - min_bin_width=1e-3, - min_bin_height=1e-3, - min_derivative=1e-3, - min_lambda=0.025, - eps=1e-6): +def _monotonic_rational_spline( + inputs, + widths, + heights, + derivatives, + lambdas=None, + inverse=False, + bound=3.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, + min_lambda=0.025, + eps=1e-6, +): """ Calculating a monotonic rational spline (linear or quadratic) or its inverse, plus the log(abs(detJ)) required for normalizing flows. @@ -110,9 +109,9 @@ def _monotonic_rational_spline(inputs, num_bins = widths.shape[-1] if min_bin_width * num_bins > 1.0: - raise ValueError('Minimal bin width too large for the number of bins') + raise ValueError("Minimal bin width too large for the number of bins") if min_bin_height * num_bins > 1.0: - raise ValueError('Minimal bin height too large for the number of bins') + raise ValueError("Minimal bin height too large for the number of bins") # inputs, inside_interval_mask, outside_interval_mask ~ (batch_dim, input_dim) left, right = -bound, bound @@ -140,11 +139,15 @@ def _monotonic_rational_spline(inputs, # Pad left and right derivatives with fixed values at first and last knots # These are 1 since the function is the identity outside the bounding box and the derivative is continuous # NOTE: Not sure why this is 1.0 - min_derivative rather than 1.0. 
I've copied this from original implementation - derivatives = F.pad(derivatives, pad=(1, 1), mode='constant', value=1.0 - min_derivative) + derivatives = F.pad( + derivatives, pad=(1, 1), mode="constant", value=1.0 - min_derivative + ) # Get the index of the bin that each input is in # bin_idx ~ (batch_dim, input_dim, 1) - bin_idx = _searchsorted(cumheights + eps if inverse else cumwidths + eps, inputs).unsqueeze(-1) + bin_idx = _searchsorted( + cumheights + eps if inverse else cumwidths + eps, inputs + ).unsqueeze(-1) # Select the value for the relevant bin for the variables used in the main calculation input_widths = _select_bins(widths, bin_idx) @@ -171,60 +174,84 @@ def _monotonic_rational_spline(inputs, # The weight, w_c, at the division point of each bin # Recall that each bin is divided into two parts so we have enough d.o.f. to fit spline - wc = (input_lambdas * wa * input_derivatives + (1 - input_lambdas) - * wb * input_derivatives_plus_one) / input_delta + wc = ( + input_lambdas * wa * input_derivatives + + (1 - input_lambdas) * wb * input_derivatives_plus_one + ) / input_delta # Calculate y coords of bins ya = input_cumheights yb = input_heights + input_cumheights - yc = ((1.0 - input_lambdas) * wa * ya + input_lambdas * wb * yb) / \ - ((1.0 - input_lambdas) * wa + input_lambdas * wb) + yc = ((1.0 - input_lambdas) * wa * ya + input_lambdas * wb * yb) / ( + (1.0 - input_lambdas) * wa + input_lambdas * wb + ) if inverse: - numerator = (input_lambdas * wa * (ya - inputs)) * (inputs <= yc).float() \ - + ((wc - input_lambdas * wb) * inputs + input_lambdas * wb * yb - wc * yc) * (inputs > yc).float() - - denominator = ((wc - wa) * inputs + wa * ya - wc * yc) * (inputs <= yc).float()\ - + ((wc - wb) * inputs + wb * yb - wc * yc) * (inputs > yc).float() + numerator = (input_lambdas * wa * (ya - inputs)) * ( + inputs <= yc + ).float() + ( + (wc - input_lambdas * wb) * inputs + input_lambdas * wb * yb - wc * yc + ) * ( + inputs > yc + ).float() + + denominator = ((wc - wa) * inputs + wa * ya - wc * yc) * ( + inputs <= yc + ).float() + ((wc - wb) * inputs + wb * yb - wc * yc) * (inputs > yc).float() theta = numerator / denominator outputs = theta * input_widths + input_cumwidths - derivative_numerator = (wa * wc * input_lambdas * (yc - ya) * (inputs <= yc).float() - + wb * wc * (1 - input_lambdas) * (yb - yc) * (inputs > yc).float()) * input_widths + derivative_numerator = ( + wa * wc * input_lambdas * (yc - ya) * (inputs <= yc).float() + + wb * wc * (1 - input_lambdas) * (yb - yc) * (inputs > yc).float() + ) * input_widths - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log( + torch.abs(denominator) + ) else: theta = (inputs - input_cumwidths) / input_widths - numerator = (wa * ya * (input_lambdas - theta) + wc * yc * theta) * (theta <= input_lambdas).float()\ - + (wc * yc * (1 - theta) + wb * yb * (theta - input_lambdas)) * (theta > input_lambdas).float() + numerator = (wa * ya * (input_lambdas - theta) + wc * yc * theta) * ( + theta <= input_lambdas + ).float() + (wc * yc * (1 - theta) + wb * yb * (theta - input_lambdas)) * ( + theta > input_lambdas + ).float() - denominator = (wa * (input_lambdas - theta) + wc * theta) * (theta <= input_lambdas).float()\ - + (wc * (1 - theta) + wb * (theta - input_lambdas)) * (theta > input_lambdas).float() + denominator = (wa * (input_lambdas - theta) + wc * theta) * ( + theta <= input_lambdas + ).float() + (wc * (1 - theta) + wb * (theta - input_lambdas)) 
* ( + theta > input_lambdas + ).float() outputs = numerator / denominator - derivative_numerator = (wa * wc * input_lambdas * (yc - ya) * (theta <= input_lambdas).float() + - wb * wc * (1 - input_lambdas) * (yb - yc) * (theta > input_lambdas).float()) \ - / input_widths + derivative_numerator = ( + wa * wc * input_lambdas * (yc - ya) * (theta <= input_lambdas).float() + + wb + * wc + * (1 - input_lambdas) + * (yb - yc) + * (theta > input_lambdas).float() + ) / input_widths - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log( + torch.abs(denominator) + ) # Calculate monotonic *quadratic* rational spline else: if inverse: - a = (((inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta) - + input_heights * (input_delta - input_derivatives))) - b = (input_heights * input_derivatives - - (inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta)) - c = - input_delta * (inputs - input_cumheights) + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c assert (discriminant >= 0).all() @@ -233,26 +260,35 @@ def _monotonic_rational_spline(inputs, outputs = root * input_widths + input_cumwidths theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2)) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) logabsdet = -(torch.log(derivative_numerator) - 2 * torch.log(denominator)) else: theta = (inputs - input_cumwidths) / input_widths theta_one_minus_theta = theta * (1 - theta) - numerator = input_heights * (input_delta * theta.pow(2) - + input_derivatives * theta_one_minus_theta) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) outputs = input_cumheights + numerator / denominator - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2)) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) # Apply the identity function outside the bounding box @@ -267,11 +303,12 @@ class ConditionedSpline(Transform): Helper class to manage learnable 
splines. One could imagine this as a standard layer in PyTorch... """ + domain = constraints.real codomain = constraints.real bijective = True - def __init__(self, params, bound=3.0, order='linear'): + def __init__(self, params, bound=3.0, order="linear"): super().__init__(cache_size=1) self._params = params @@ -310,7 +347,9 @@ def log_abs_det_jacobian(self, x, y): def spline_op(self, x, **kwargs): w, h, d, l = self._params() if callable(self._params) else self._params - y, log_detJ = _monotonic_rational_spline(x, w, h, d, l, bound=self.bound, **kwargs) + y, log_detJ = _monotonic_rational_spline( + x, w, h, d, l, bound=self.bound, **kwargs + ) return y, log_detJ @@ -368,7 +407,7 @@ class Spline(ConditionedSpline, TransformModule): codomain = constraints.real bijective = True - def __init__(self, input_dim, count_bins=8, bound=3., order='linear'): + def __init__(self, input_dim, count_bins=8, bound=3.0, order="linear"): super(Spline, self).__init__(self._params) self.input_dim = input_dim @@ -376,24 +415,34 @@ def __init__(self, input_dim, count_bins=8, bound=3., order='linear'): self.bound = bound self.order = order - self.unnormalized_widths = nn.Parameter(torch.randn(self.input_dim, self.count_bins)) - self.unnormalized_heights = nn.Parameter(torch.randn(self.input_dim, self.count_bins)) - self.unnormalized_derivatives = nn.Parameter(torch.randn(self.input_dim, self.count_bins - 1)) + self.unnormalized_widths = nn.Parameter( + torch.randn(self.input_dim, self.count_bins) + ) + self.unnormalized_heights = nn.Parameter( + torch.randn(self.input_dim, self.count_bins) + ) + self.unnormalized_derivatives = nn.Parameter( + torch.randn(self.input_dim, self.count_bins - 1) + ) # Rational linear splines have additional lambda parameters if self.order == "linear": - self.unnormalized_lambdas = nn.Parameter(torch.rand(self.input_dim, self.count_bins)) + self.unnormalized_lambdas = nn.Parameter( + torch.rand(self.input_dim, self.count_bins) + ) elif self.order != "quadratic": raise ValueError( "Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format( - self.order)) + self.order + ) + ) def _params(self): # widths, unnormalized_widths ~ (input_dim, num_bins) w = F.softmax(self.unnormalized_widths, dim=-1) h = F.softmax(self.unnormalized_heights, dim=-1) d = F.softplus(self.unnormalized_derivatives) - if self.order == 'linear': + if self.order == "linear": l = torch.sigmoid(self.unnormalized_lambdas) else: l = None @@ -467,7 +516,7 @@ class ConditionalSpline(ConditionalTransformModule): codomain = constraints.real bijective = True - def __init__(self, nn, input_dim, count_bins, bound=3.0, order='linear'): + def __init__(self, nn, input_dim, count_bins, bound=3.0, order="linear"): super().__init__() self.nn = nn @@ -493,7 +542,9 @@ def _params(self, context): else: raise ValueError( "Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format( - self.order)) + self.order + ) + ) # AutoRegressiveNN and DenseNN return different shapes... 
if w.shape[-1] == self.input_dim: @@ -532,7 +583,9 @@ def spline(input_dim, **kwargs): return Spline(input_dim, **kwargs) -def conditional_spline(input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, order='linear'): +def conditional_spline( + input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear" +): """ A helper function to create a :class:`~pyro.distributions.transforms.ConditionalSpline` object that takes care @@ -558,22 +611,33 @@ def conditional_spline(input_dim, context_dim, hidden_dims=None, count_bins=8, b if hidden_dims is None: hidden_dims = [input_dim * 10, input_dim * 10] - if order == 'linear': - nn = DenseNN(context_dim, - hidden_dims, - param_dims=[input_dim * count_bins, - input_dim * count_bins, - input_dim * (count_bins - 1), - input_dim * count_bins]) - elif order == 'quadratic': - nn = DenseNN(context_dim, - hidden_dims, - param_dims=[input_dim * count_bins, - input_dim * count_bins, - input_dim * (count_bins - 1)]) + if order == "linear": + nn = DenseNN( + context_dim, + hidden_dims, + param_dims=[ + input_dim * count_bins, + input_dim * count_bins, + input_dim * (count_bins - 1), + input_dim * count_bins, + ], + ) + elif order == "quadratic": + nn = DenseNN( + context_dim, + hidden_dims, + param_dims=[ + input_dim * count_bins, + input_dim * count_bins, + input_dim * (count_bins - 1), + ], + ) else: - raise ValueError("Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format( - order)) + raise ValueError( + "Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format( + order + ) + ) return ConditionalSpline(nn, input_dim, count_bins, bound=bound, order=order) diff --git a/pyro/distributions/transforms/spline_autoregressive.py b/pyro/distributions/transforms/spline_autoregressive.py index db72d803f9..48b450bff1 100644 --- a/pyro/distributions/transforms/spline_autoregressive.py +++ b/pyro/distributions/transforms/spline_autoregressive.py @@ -76,16 +76,13 @@ class SplineAutoregressive(TransformModule): autoregressive = True def __init__( - self, - input_dim, - autoregressive_nn, - count_bins=8, - bound=3., - order='linear' + self, input_dim, autoregressive_nn, count_bins=8, bound=3.0, order="linear" ): super(SplineAutoregressive, self).__init__(cache_size=1) self.arn = autoregressive_nn - self.spline = ConditionalSpline(autoregressive_nn, input_dim, count_bins, bound, order) + self.spline = ConditionalSpline( + autoregressive_nn, input_dim, count_bins, bound, order + ) def _call(self, x): """ @@ -220,7 +217,9 @@ def condition(self, context): return SplineAutoregressive(self.input_dim, cond_nn, **self.kwargs) -def spline_autoregressive(input_dim, hidden_dims=None, count_bins=8, bound=3.0, order='linear'): +def spline_autoregressive( + input_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear" +): r""" A helper function to create an :class:`~pyro.distributions.transforms.SplineAutoregressive` object that takes @@ -247,11 +246,14 @@ def spline_autoregressive(input_dim, hidden_dims=None, count_bins=8, bound=3.0, param_dims = [count_bins, count_bins, count_bins - 1, count_bins] arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims) - return SplineAutoregressive(input_dim, arn, count_bins=count_bins, bound=bound, order=order) + return SplineAutoregressive( + input_dim, arn, count_bins=count_bins, bound=bound, order=order + ) -def conditional_spline_autoregressive(input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, - 
order='linear'): +def conditional_spline_autoregressive( + input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear" +): r""" A helper function to create a :class:`~pyro.distributions.transforms.ConditionalSplineAutoregressive` object @@ -279,5 +281,9 @@ def conditional_spline_autoregressive(input_dim, context_dim, hidden_dims=None, hidden_dims = [input_dim * 10, input_dim * 10] param_dims = [count_bins, count_bins, count_bins - 1, count_bins] - arn = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims, param_dims=param_dims) - return ConditionalSplineAutoregressive(input_dim, arn, count_bins=count_bins, bound=bound, order=order) + arn = ConditionalAutoRegressiveNN( + input_dim, context_dim, hidden_dims, param_dims=param_dims + ) + return ConditionalSplineAutoregressive( + input_dim, arn, count_bins=count_bins, bound=bound, order=order + ) diff --git a/pyro/distributions/transforms/spline_coupling.py b/pyro/distributions/transforms/spline_coupling.py index a83d726c46..c7ee114c88 100644 --- a/pyro/distributions/transforms/spline_coupling.py +++ b/pyro/distributions/transforms/spline_coupling.py @@ -78,13 +78,24 @@ class SplineCoupling(TransformModule): codomain = constraints.real_vector bijective = True - def __init__(self, input_dim, split_dim, hypernet, count_bins=8, bound=3., order='linear', identity=False): + def __init__( + self, + input_dim, + split_dim, + hypernet, + count_bins=8, + bound=3.0, + order="linear", + identity=False, + ): super(SplineCoupling, self).__init__(cache_size=1) # One part of the input is (optionally) put through an element-wise spline and the other part through a # conditional one that inputs the first part. self.lower_spline = Spline(split_dim, count_bins, bound, order) - self.upper_spline = ConditionalSpline(hypernet, input_dim - split_dim, count_bins, bound, order) + self.upper_spline = ConditionalSpline( + hypernet, input_dim - split_dim, count_bins, bound, order + ) self.split_dim = split_dim self.identity = identity @@ -97,7 +108,7 @@ def _call(self, x): :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from the base distribution (or the output of a previous transform) """ - x1, x2 = x[..., :self.split_dim], x[..., self.split_dim:] + x1, x2 = x[..., : self.split_dim], x[..., self.split_dim :] if not self.identity: y1 = self.lower_spline(x1) @@ -123,7 +134,7 @@ def _inverse(self, y): Inverts y => x. Uses a previously cached inverse if available, otherwise performs the inversion afresh. 
""" - y1, y2 = y[..., :self.split_dim], y[..., self.split_dim:] + y1, y2 = y[..., : self.split_dim], y[..., self.split_dim :] if not self.identity: x1 = self.lower_spline._inv_call(y1) @@ -154,7 +165,9 @@ def log_abs_det_jacobian(self, x, y): return self._cache_log_detJ.sum(-1) -def spline_coupling(input_dim, split_dim=None, hidden_dims=None, count_bins=8, bound=3.0): +def spline_coupling( + input_dim, split_dim=None, hidden_dims=None, count_bins=8, bound=3.0 +): """ A helper function to create a :class:`~pyro.distributions.transforms.SplineCoupling` object for consistency @@ -171,11 +184,15 @@ def spline_coupling(input_dim, split_dim=None, hidden_dims=None, count_bins=8, b if hidden_dims is None: hidden_dims = [input_dim * 10, input_dim * 10] - nn = DenseNN(split_dim, - hidden_dims, - param_dims=[(input_dim - split_dim) * count_bins, - (input_dim - split_dim) * count_bins, - (input_dim - split_dim) * (count_bins - 1), - (input_dim - split_dim) * count_bins]) + nn = DenseNN( + split_dim, + hidden_dims, + param_dims=[ + (input_dim - split_dim) * count_bins, + (input_dim - split_dim) * count_bins, + (input_dim - split_dim) * (count_bins - 1), + (input_dim - split_dim) * count_bins, + ], + ) return SplineCoupling(input_dim, split_dim, nn, count_bins, bound) diff --git a/pyro/distributions/transforms/sylvester.py b/pyro/distributions/transforms/sylvester.py index c9bde3e427..0e63a3470a 100644 --- a/pyro/distributions/transforms/sylvester.py +++ b/pyro/distributions/transforms/sylvester.py @@ -69,7 +69,7 @@ def __init__(self, input_dim, count_transforms=1): # Register masks and indices triangular_mask = torch.triu(torch.ones(input_dim, input_dim), diagonal=1) - self.register_buffer('triangular_mask', triangular_mask) + self.register_buffer("triangular_mask", triangular_mask) self._cached_logDetJ = None self.tanh = nn.Tanh() @@ -77,7 +77,7 @@ def __init__(self, input_dim, count_transforms=1): # Derivative of hyperbolic tan def dtanh_dx(self, x): - return 1. - self.tanh(x).pow(2) + return 1.0 - self.tanh(x).pow(2) # Construct upper diagonal R matrix def R(self): @@ -90,11 +90,14 @@ def S(self): # Construct orthonomal matrix using Householder flow def Q(self, x): u = self.u() - partial_Q = torch.eye(self.input_dim, dtype=x.dtype, layout=x.layout, - device=x.device) - 2. * torch.ger(u[0], u[0]) + partial_Q = torch.eye( + self.input_dim, dtype=x.dtype, layout=x.layout, device=x.device + ) - 2.0 * torch.ger(u[0], u[0]) for idx in range(1, self.u_unnormed.size(-2)): - partial_Q = torch.matmul(partial_Q, torch.eye(self.input_dim) - 2. 
* torch.ger(u[idx], u[idx])) + partial_Q = torch.matmul( + partial_Q, torch.eye(self.input_dim) - 2.0 * torch.ger(u[idx], u[idx]) + ) return partial_Q @@ -122,7 +125,9 @@ def _call(self, x): preactivation = torch.matmul(x, B) + self.b y = x + torch.matmul(self.tanh(preactivation), A) - self._cached_logDetJ = torch.log1p(self.dtanh_dx(preactivation) * R.diagonal() * S.diagonal() + 1e-8).sum(-1) + self._cached_logDetJ = torch.log1p( + self.dtanh_dx(preactivation) * R.diagonal() * S.diagonal() + 1e-8 + ).sum(-1) return y def _inverse(self, y): @@ -135,7 +140,9 @@ def _inverse(self, y): cached on the forward call) """ - raise KeyError("Sylvester object expected to find key in intermediates cache but didn't") + raise KeyError( + "Sylvester object expected to find key in intermediates cache but didn't" + ) def log_abs_det_jacobian(self, x, y): """ diff --git a/pyro/distributions/unit.py b/pyro/distributions/unit.py index 4e3608a6e2..455d5d24d8 100644 --- a/pyro/distributions/unit.py +++ b/pyro/distributions/unit.py @@ -16,7 +16,8 @@ class Unit(TorchDistribution): This is used for :func:`pyro.factor` statements. """ - arg_constraints = {'log_factor': constraints.real} + + arg_constraints = {"log_factor": constraints.real} support = constraints.real def __init__(self, log_factor, validate_args=None): diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index bb11328a8a..3bcc9c2d1a 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -31,26 +31,26 @@ def decorator(destin_class): # if not destin_class.__doc__: # destin_class.__doc__ = source_class.__doc__ for name in dir(destin_class): - if name.startswith('_'): + if name.startswith("_"): continue destin_attr = getattr(destin_class, name) - destin_attr = getattr(destin_attr, '__func__', destin_attr) + destin_attr = getattr(destin_attr, "__func__", destin_attr) source_attr = getattr(source_class, name, None) - source_doc = getattr(source_attr, '__doc__', None) - if source_doc and not getattr(destin_attr, '__doc__', None): - if full_text or source_doc.startswith('See '): + source_doc = getattr(source_attr, "__doc__", None) + if source_doc and not getattr(destin_attr, "__doc__", None): + if full_text or source_doc.startswith("See "): destin_doc = source_doc else: - destin_doc = 'See :meth:`{}.{}.{}`'.format( - source_class.__module__, source_class.__name__, name) + destin_doc = "See :meth:`{}.{}.{}`".format( + source_class.__module__, source_class.__name__, name + ) if isinstance(destin_attr, property): # Set docs for object properties. # Since __doc__ is read-only, we need to reset the property # with the updated doc. 
- updated_property = property(destin_attr.fget, - destin_attr.fset, - destin_attr.fdel, - destin_doc) + updated_property = property( + destin_attr.fget, destin_attr.fset, destin_attr.fdel, destin_doc + ) setattr(destin_class, name, updated_property) else: destin_attr.__doc__ = destin_doc @@ -75,11 +75,13 @@ def __init__(self): def _callback(self, result): print(result) """ + def weak_fn(weakself, *args, **kwargs): self = weakself() if self is None: - raise AttributeError("self was garbage collected when calling self.{}" - .format(fn.__name__)) + raise AttributeError( + "self was garbage collected when calling self.{}".format(fn.__name__) + ) return fn(self, *args, **kwargs) @property @@ -89,8 +91,12 @@ def weak_binder(self): @weak_binder.setter def weak_binder(self, new): - if not (isinstance(new, functools.partial) and new.func is weak_fn and - len(new.args) == 1 and new.args[0] is weakref.ref(self)): + if not ( + isinstance(new, functools.partial) + and new.func is weak_fn + and len(new.args) == 1 + and new.args[0] is weakref.ref(self) + ): raise AttributeError("cannot overwrite weakmethod {}".format(fn.__name__)) return weak_binder @@ -127,11 +133,7 @@ def detach(obj): torch_jit_script_if_tracing = getattr( torch.jit, "script_if_tracing", - getattr( - torch.jit, - "_script_if_tracing", - torch.jit.script - ), + getattr(torch.jit, "_script_if_tracing", torch.jit.script), ) @@ -172,7 +174,7 @@ def broadcast_shape(*shapes, **kwargs): :rtype: tuple :raises: ValueError """ - strict = kwargs.pop('strict', False) + strict = kwargs.pop("strict", False) reversed_shape = [] for shape in shapes: for i, size in enumerate(reversed(shape)): @@ -181,8 +183,11 @@ def broadcast_shape(*shapes, **kwargs): elif reversed_shape[i] == 1 and not strict: reversed_shape[i] = size elif reversed_shape[i] != size and (size != 1 or strict): - raise ValueError('shape mismatch: objects cannot be broadcast to a single shape: {}'.format( - ' vs '.join(map(str, shapes)))) + raise ValueError( + "shape mismatch: objects cannot be broadcast to a single shape: {}".format( + " vs ".join(map(str, shapes)) + ) + ) return tuple(reversed(reversed_shape)) @@ -284,7 +289,7 @@ def eye_like(value, m, n=None): if n is None: n = m eye = torch.zeros(m, n, dtype=value.dtype, device=value.device) - eye.view(-1)[:min(m, n) * n:n + 1] = 1 + eye.view(-1)[: min(m, n) * n : n + 1] = 1 return eye diff --git a/pyro/distributions/von_mises_3d.py b/pyro/distributions/von_mises_3d.py index acfa1f6519..d695080741 100644 --- a/pyro/distributions/von_mises_3d.py +++ b/pyro/distributions/von_mises_3d.py @@ -28,13 +28,17 @@ class VonMises3D(TorchDistribution): vector. The direction of this vector is the location, and its magnitude is the concentration. 
""" - arg_constraints = {'concentration': constraints.real} + + arg_constraints = {"concentration": constraints.real} support = constraints.sphere def __init__(self, concentration, validate_args=None): if concentration.dim() < 1 or concentration.shape[-1] != 3: - raise ValueError('Expected concentration to have rightmost dim 3, actual shape = {}'.format( - concentration.shape)) + raise ValueError( + "Expected concentration to have rightmost dim 3, actual shape = {}".format( + concentration.shape + ) + ) self.concentration = concentration batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] super().__init__(batch_shape, event_shape, validate_args=validate_args) @@ -42,10 +46,13 @@ def __init__(self, concentration, validate_args=None): def log_prob(self, value): if self._validate_args: if value.dim() < 1 or value.shape[-1] != 3: - raise ValueError('Expected value to have rightmost dim 3, actual shape = {}'.format( - value.shape)) + raise ValueError( + "Expected value to have rightmost dim 3, actual shape = {}".format( + value.shape + ) + ) if not (torch.abs(value.norm(2, -1) - 1) < 1e-6).all(): - raise ValueError('direction vectors are not normalized') + raise ValueError("direction vectors are not normalized") scale = self.concentration.norm(2, -1) log_normalizer = scale.log() - scale.sinh().log() - math.log(4 * math.pi) return (self.concentration * value).sum(-1) + log_normalizer @@ -54,6 +61,6 @@ def expand(self, batch_shape): try: return super().expand(batch_shape) except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') + validate_args = self.__dict__.get("_validate_args") concentration = self.concentration.expand(torch.Size(batch_shape) + (3,)) return type(self)(concentration, validate_args=validate_args) diff --git a/pyro/distributions/zero_inflated.py b/pyro/distributions/zero_inflated.py index fa06964439..5a0cae0e4b 100644 --- a/pyro/distributions/zero_inflated.py +++ b/pyro/distributions/zero_inflated.py @@ -26,12 +26,17 @@ class ZeroInflatedDistribution(TorchDistribution): :param torch.Tensor gate: probability of extra zeros given via a Bernoulli distribution. :param torch.Tensor gate_logits: logits of extra zeros given via a Bernoulli distribution. """ - arg_constraints = {"gate": constraints.unit_interval, - "gate_logits": constraints.real} + + arg_constraints = { + "gate": constraints.unit_interval, + "gate_logits": constraints.real, + } def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None): if (gate is None) == (gate_logits is None): - raise ValueError("Either `gate` or `gate_logits` must be specified, but not both.") + raise ValueError( + "Either `gate` or `gate_logits` must be specified, but not both." 
+ ) if gate is not None: batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape) self.gate = gate.expand(batch_shape) @@ -39,9 +44,10 @@ def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape) self.gate_logits = gate_logits.expand(batch_shape) if base_dist.event_shape: - raise ValueError("ZeroInflatedDistribution expected empty " - "base_dist.event_shape but got {}" - .format(base_dist.event_shape)) + raise ValueError( + "ZeroInflatedDistribution expected empty " + "base_dist.event_shape but got {}".format(base_dist.event_shape) + ) self.base_dist = base_dist.expand(batch_shape) event_shape = torch.Size() @@ -64,7 +70,7 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) - if 'gate' in self.__dict__: + if "gate" in self.__dict__: gate, value = broadcast_all(self.gate, value) log_prob = (-gate).log1p() + self.base_dist.log_prob(value) log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob) @@ -98,10 +104,16 @@ def variance(self): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(type(self), _instance) batch_shape = torch.Size(batch_shape) - gate = self.gate.expand(batch_shape) if 'gate' in self.__dict__ else None - gate_logits = self.gate_logits.expand(batch_shape) if 'gate_logits' in self.__dict__ else None + gate = self.gate.expand(batch_shape) if "gate" in self.__dict__ else None + gate_logits = ( + self.gate_logits.expand(batch_shape) + if "gate_logits" in self.__dict__ + else None + ) base_dist = self.base_dist.expand(batch_shape) - ZeroInflatedDistribution.__init__(new, base_dist, gate=gate, gate_logits=gate_logits, validate_args=False) + ZeroInflatedDistribution.__init__( + new, base_dist, gate=gate, gate_logits=gate_logits, validate_args=False + ) new._validate_args = self._validate_args return new @@ -114,9 +126,12 @@ class ZeroInflatedPoisson(ZeroInflatedDistribution): :param torch.Tensor gate: probability of extra zeros. :param torch.Tensor gate_logits: logits of extra zeros. """ - arg_constraints = {"rate": constraints.positive, - "gate": constraints.unit_interval, - "gate_logits": constraints.real} + + arg_constraints = { + "rate": constraints.positive, + "gate": constraints.unit_interval, + "gate_logits": constraints.real, + } support = constraints.nonnegative_integer def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None): @@ -143,14 +158,26 @@ class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution): :param torch.Tensor gate: probability of extra zeros. :param torch.Tensor gate_logits: logits of extra zeros. 
""" - arg_constraints = {"total_count": constraints.greater_than_eq(0), - "probs": constraints.half_open_interval(0., 1.), - "logits": constraints.real, - "gate": constraints.unit_interval, - "gate_logits": constraints.real} + + arg_constraints = { + "total_count": constraints.greater_than_eq(0), + "probs": constraints.half_open_interval(0.0, 1.0), + "logits": constraints.real, + "gate": constraints.unit_interval, + "gate_logits": constraints.real, + } support = constraints.nonnegative_integer - def __init__(self, total_count, *, probs=None, logits=None, gate=None, gate_logits=None, validate_args=None): + def __init__( + self, + total_count, + *, + probs=None, + logits=None, + gate=None, + gate_logits=None, + validate_args=None + ): base_dist = NegativeBinomial( total_count=total_count, probs=probs, diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py index 8891c9cc64..2f48f16836 100644 --- a/pyro/infer/abstract_infer.py +++ b/pyro/infer/abstract_infer.py @@ -31,8 +31,9 @@ class EmpiricalMarginal(Empirical): """ def __init__(self, trace_posterior, sites=None, validate_args=None): - assert isinstance(trace_posterior, TracePosterior), \ - "trace_dist must be trace posterior distribution object" + assert isinstance( + trace_posterior, TracePosterior + ), "trace_dist must be trace posterior distribution object" if sites is None: sites = "_RETURN" self._num_chains = 1 @@ -53,13 +54,19 @@ def _get_samples_and_weights(self): for i in range(num_chains): samples = torch.stack(self._samples_buffer[i], dim=0) samples_by_chain.append(samples) - weights_dtype = samples.dtype if samples.dtype.is_floating_point else torch.float32 - weights = torch.as_tensor(self._weights_buffer[i], device=samples.device, dtype=weights_dtype) + weights_dtype = ( + samples.dtype if samples.dtype.is_floating_point else torch.float32 + ) + weights = torch.as_tensor( + self._weights_buffer[i], device=samples.device, dtype=weights_dtype + ) weights_by_chain.append(weights) if len(samples_by_chain) == 1: return samples_by_chain[0], weights_by_chain[0] else: - return torch.stack(samples_by_chain, dim=0), torch.stack(weights_by_chain, dim=0) + return torch.stack(samples_by_chain, dim=0), torch.stack( + weights_by_chain, dim=0 + ) def _add_sample(self, value, log_weight=None, chain_id=0): """ @@ -79,7 +86,11 @@ def _add_sample(self, value, log_weight=None, chain_id=0): # Apply default weight of 1.0. 
if log_weight is None: log_weight = 0.0 - if self._validate_args and not isinstance(log_weight, numbers.Number) and log_weight.dim() > 0: + if ( + self._validate_args + and not isinstance(log_weight, numbers.Number) + and log_weight.dim() > 0 + ): raise ValueError("``weight.dim() > 0``, but weight should be a scalar.") # Append to the buffer list @@ -89,11 +100,16 @@ def _add_sample(self, value, log_weight=None, chain_id=0): def _populate_traces(self, trace_posterior, sites): assert isinstance(sites, (list, str)) - for tr, log_weight, chain_id in zip(trace_posterior.exec_traces, - trace_posterior.log_weights, - trace_posterior.chain_ids): - value = tr.nodes[sites]["value"] if isinstance(sites, str) else \ - torch.stack([tr.nodes[site]["value"] for site in sites], 0) + for tr, log_weight, chain_id in zip( + trace_posterior.exec_traces, + trace_posterior.log_weights, + trace_posterior.chain_ids, + ): + value = ( + tr.nodes[sites]["value"] + if isinstance(sites, str) + else torch.stack([tr.nodes[site]["value"] for site in sites], 0) + ) self._add_sample(value, log_weight=log_weight, chain_id=chain_id) @@ -108,9 +124,11 @@ class Marginals: :param list sites: optional list of sites for which we need to generate the marginal distribution. """ + def __init__(self, trace_posterior, sites=None, validate_args=None): - assert isinstance(trace_posterior, TracePosterior), \ - "trace_dist must be trace posterior distribution object" + assert isinstance( + trace_posterior, TracePosterior + ), "trace_dist must be trace posterior distribution object" if sites is None: sites = ["_RETURN"] elif isinstance(sites, str): @@ -124,8 +142,10 @@ def __init__(self, trace_posterior, sites=None, validate_args=None): self._populate_traces(trace_posterior, validate_args) def _populate_traces(self, trace_posterior, validate): - self._marginals = {site: EmpiricalMarginal(trace_posterior, site, validate) - for site in self.sites} + self._marginals = { + site: EmpiricalMarginal(trace_posterior, site, validate) + for site in self.sites + } def support(self, flatten=False): """ @@ -137,8 +157,12 @@ def support(self, flatten=False): :returns: a dict with keys are sites' names and values are sites' supports. :rtype: :class:`OrderedDict` """ - support = OrderedDict([(site, value.enumerate_support()) - for site, value in self._marginals.items()]) + support = OrderedDict( + [ + (site, value.enumerate_support()) + for site, value in self._marginals.items() + ] + ) if self._trace_posterior.num_chains > 1 and flatten: for site, samples in support.items(): shape = samples.size() @@ -164,6 +188,7 @@ class TracePosterior(object, metaclass=ABCMeta): This is designed to be used by other utility classes like `EmpiricalMarginal`, that need access to the collected execution traces. """ + def __init__(self, num_chains=1): self.num_chains = num_chains self._reset() @@ -172,7 +197,9 @@ def _reset(self): self.log_weights = [] self.exec_traces = [] self.chain_ids = [] # chain id corresponding to the sample - self._idx_by_chain = [[] for _ in range(self.num_chains)] # indexes of samples by chain id + self._idx_by_chain = [ + [] for _ in range(self.num_chains) + ] # indexes of samples by chain id self._categorical = None def marginal(self, sites=None): @@ -201,7 +228,10 @@ def __call__(self, *args, **kwargs): # we get the index from ``idxs_by_chain`` instead of sampling from # the marginal directly. 
random_idx = self._categorical.sample().item() - chain_idx, sample_idx = random_idx % self.num_chains, random_idx // self.num_chains + chain_idx, sample_idx = ( + random_idx % self.num_chains, + random_idx // self.num_chains, + ) sample_idx = self._idx_by_chain[chain_idx][sample_idx] trace = self.exec_traces[sample_idx].copy() for name in trace.observation_nodes: @@ -256,19 +286,27 @@ def information_criterion(self, pointwise=False): for trace in self.exec_traces: obs_nodes = trace.observation_nodes if len(obs_nodes) > 1: - raise ValueError("Infomation criterion calculation only works for models " - "with one observation node.") + raise ValueError( + "Infomation criterion calculation only works for models " + "with one observation node." + ) if obs_node is None: obs_node = obs_nodes[0] elif obs_node != obs_nodes[0]: - raise ValueError("Observation node has been changed, expected {} but got {}" - .format(obs_node, obs_nodes[0])) + raise ValueError( + "Observation node has been changed, expected {} but got {}".format( + obs_node, obs_nodes[0] + ) + ) - log_likelihoods.append(trace.nodes[obs_node]["fn"] - .log_prob(trace.nodes[obs_node]["value"])) + log_likelihoods.append( + trace.nodes[obs_node]["fn"].log_prob(trace.nodes[obs_node]["value"]) + ) ll = torch.stack(log_likelihoods, dim=0) - waic_value, p_waic = waic(ll, torch.tensor(self.log_weights, device=ll.device), pointwise) + waic_value, p_waic = waic( + ll, torch.tensor(self.log_weights, device=ll.device), pointwise + ) return OrderedDict([("waic", waic_value), ("p_waic", p_waic)]) @@ -288,15 +326,18 @@ class TracePredictive(TracePosterior): :param int num_samples: number of samples to generate. :param keep_sites: The sites which should be sampled from posterior distribution (default: all) """ + def __init__(self, model, posterior, num_samples, keep_sites=None): self.model = model self.posterior = posterior self.num_samples = num_samples self.keep_sites = keep_sites super().__init__() - warnings.warn('The `TracePredictive` class is deprecated and will be removed ' - 'in a future release. Use the `pyro.infer.Predictive` class instead.', - FutureWarning) + warnings.warn( + "The `TracePredictive` class is deprecated and will be removed " + "in a future release. 
Use the `pyro.infer.Predictive` class instead.", + FutureWarning, + ) def _traces(self, *args, **kwargs): if not self.posterior.exec_traces: @@ -306,8 +347,10 @@ def _traces(self, *args, **kwargs): model_trace = self.posterior().copy() self._remove_dropped_nodes(model_trace) self._adjust_to_data(model_trace, data_trace) - resampled_trace = poutine.trace(poutine.replay(self.model, model_trace)).get_trace(*args, **kwargs) - yield (resampled_trace, 0., 0) + resampled_trace = poutine.trace( + poutine.replay(self.model, model_trace) + ).get_trace(*args, **kwargs) + yield (resampled_trace, 0.0, 0) def _remove_dropped_nodes(self, trace): if self.keep_sites is None: @@ -336,10 +379,15 @@ def _adjust_to_data(self, trace, data_trace): assert ocis.name == cis.name assert not site_is_subsample(site) batch_dim = cis.dim - site["fn"].event_dim - subsampled_idxs[cis.name] = subsampled_idxs.get(cis.name, - torch.randint(0, ocis.size, (cis.size,), - device=site["value"].device)) - site["value"] = site["value"].index_select(batch_dim, subsampled_idxs[cis.name]) + subsampled_idxs[cis.name] = subsampled_idxs.get( + cis.name, + torch.randint( + 0, ocis.size, (cis.size,), device=site["value"].device + ), + ) + site["value"] = site["value"].index_select( + batch_dim, subsampled_idxs[cis.name] + ) def marginal(self, sites=None): """ diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index a868f1413a..42da0c7030 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -29,26 +29,26 @@ from pyro.infer.autoguide.utils import mean_field_entropy __all__ = [ - 'AutoCallable', - 'AutoContinuous', - 'AutoDelta', - 'AutoDiagonalNormal', - 'AutoDiscreteParallel', - 'AutoGuide', - 'AutoGuideList', - 'AutoIAFNormal', - 'AutoLaplaceApproximation', - 'AutoLowRankMultivariateNormal', - 'AutoMultivariateNormal', - 'AutoNormal', - 'AutoNormalizingFlow', - 'AutoStructured', - 'init_to_feasible', - 'init_to_generated', - 'init_to_mean', - 'init_to_median', - 'init_to_sample', - 'init_to_uniform', - 'init_to_value', - 'mean_field_entropy', + "AutoCallable", + "AutoContinuous", + "AutoDelta", + "AutoDiagonalNormal", + "AutoDiscreteParallel", + "AutoGuide", + "AutoGuideList", + "AutoIAFNormal", + "AutoLaplaceApproximation", + "AutoLowRankMultivariateNormal", + "AutoMultivariateNormal", + "AutoNormal", + "AutoNormalizingFlow", + "AutoStructured", + "init_to_feasible", + "init_to_generated", + "init_to_mean", + "init_to_median", + "init_to_sample", + "init_to_uniform", + "init_to_value", + "mean_field_entropy", ] diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 5765a2bb36..cdde09f65d 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -64,7 +64,7 @@ def _getattr(obj, attr): lpart, _, rpart = key.rpartition(".") # Recursive getattr while setting any prefix attributes to PyroModule if lpart: - obj = functools.reduce(_getattr, [obj] + lpart.split('.')) + obj = functools.reduce(_getattr, [obj] + lpart.split(".")) setattr(obj, rpart, val) @@ -148,23 +148,29 @@ def _create_plates(self, *args, **kwargs): plates = self.create_plates(*args, **kwargs) if isinstance(plates, pyro.plate): plates = [plates] - assert all(isinstance(p, pyro.plate) for p in plates), \ - "create_plates() returned a non-plate" + assert all( + isinstance(p, pyro.plate) for p in plates + ), "create_plates() returned a non-plate" self.plates = {p.name: p for p in plates} for name, frame in sorted(self._prototype_frames.items()): if name not in 
self.plates: full_size = getattr(frame, "full_size", frame.size) - self.plates[name] = pyro.plate(name, full_size, dim=frame.dim, - subsample_size=frame.size) + self.plates[name] = pyro.plate( + name, full_size, dim=frame.dim, subsample_size=frame.size + ) else: - assert self.create_plates is None, "Cannot pass create_plates() to non-master guide" + assert ( + self.create_plates is None + ), "Cannot pass create_plates() to non-master guide" self.plates = self.master().plates return self.plates def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure model = poutine.block(self.model, prototype_hide_fn) - self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs) + self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( + *args, **kwargs + ) if self.master is not None: self.master()._check_prototype(self.prototype_trace) @@ -174,7 +180,9 @@ def _setup_prototype(self, *args, **kwargs): if frame.vectorized: self._prototype_frames[frame.name] = frame else: - raise NotImplementedError("AutoGuide does not support sequential pyro.plate") + raise NotImplementedError( + "AutoGuide does not support sequential pyro.plate" + ) def median(self, *args, **kwargs): """ @@ -226,12 +234,17 @@ def append(self, part): if not isinstance(part, AutoGuide): part = AutoCallable(self.model, part) if part.master is not None: - raise RuntimeError("The module `{}` is already added.".format(self._pyro_name)) + raise RuntimeError( + "The module `{}` is already added.".format(self._pyro_name) + ) setattr(self, str(len(self)), part) def add(self, part): """Deprecated alias for :meth:`append`.""" - warnings.warn("The method `.add` has been deprecated in favor of `.append`.", DeprecationWarning) + warnings.warn( + "The method `.add` has been deprecated in favor of `.append`.", + DeprecationWarning, + ) self.append(part) def forward(self, *args, **kwargs): @@ -342,8 +355,8 @@ def my_init_fn(site): or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. """ - def __init__(self, model, init_loc_fn=init_to_median, *, - create_plates=None): + + def __init__(self, model, init_loc_fn=init_to_median, *, create_plates=None): self.init_loc_fn = init_loc_fn model = InitMessenger(self.init_loc_fn)(model) super().__init__(model, create_plates=create_plates) @@ -389,8 +402,9 @@ def forward(self, *args, **kwargs): if frame.vectorized: stack.enter_context(plates[frame.name]) attr_get = operator.attrgetter(name) - result[name] = pyro.sample(name, dist.Delta(attr_get(self), - event_dim=site["fn"].event_dim)) + result[name] = pyro.sample( + name, dist.Delta(attr_get(self), event_dim=site["fn"].event_dim) + ) return result @torch.no_grad() @@ -438,10 +452,9 @@ class AutoNormal(AutoGuide): scale_constraint = constraints.softplus_positive - def __init__(self, model, *, - init_loc_fn=init_to_feasible, - init_scale=0.1, - create_plates=None): + def __init__( + self, model, *, init_loc_fn=init_to_feasible, init_scale=0.1, create_plates=None + ): self.init_loc_fn = init_loc_fn if not isinstance(init_scale, float) or not (init_scale > 0): @@ -462,7 +475,9 @@ def _setup_prototype(self, *args, **kwargs): for name, site in self.prototype_trace.iter_stochastic_nodes(): # Collect unconstrained event_dims, which may differ from constrained event_dims. 
with helpful_support_errors(site): - init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach() + init_loc = ( + biject_to(site["fn"].support).inv(site["value"].detach()).detach() + ) event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim() self._event_dims[name] = event_dim @@ -474,9 +489,14 @@ def _setup_prototype(self, *args, **kwargs): init_loc = periodic_repeat(init_loc, full_size, dim).contiguous() init_scale = torch.full_like(init_loc, self._init_scale) - _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim)) - _deep_setattr(self.scales, name, - PyroParam(init_scale, self.scale_constraint, event_dim)) + _deep_setattr( + self.locs, name, PyroParam(init_loc, constraints.real, event_dim) + ) + _deep_setattr( + self.scales, + name, + PyroParam(init_scale, self.scale_constraint, event_dim), + ) def _get_loc_and_scale(self, name): site_loc = _deep_getattr(self.locs, name) @@ -511,9 +531,10 @@ def forward(self, *args, **kwargs): unconstrained_latent = pyro.sample( name + "_unconstrained", dist.Normal( - site_loc, site_scale, + site_loc, + site_scale, ).to_event(self._event_dims[name]), - infer={"is_auxiliary": True} + infer={"is_auxiliary": True}, ) value = transform(unconstrained_latent) @@ -573,10 +594,16 @@ def quantiles(self, quantiles, *args, **kwargs): for name, site in self.prototype_trace.iter_stochastic_nodes(): site_loc, site_scale = self._get_loc_and_scale(name) - site_quantiles = torch.tensor(quantiles, dtype=site_loc.dtype, device=site_loc.device) + site_quantiles = torch.tensor( + quantiles, dtype=site_loc.dtype, device=site_loc.device + ) site_quantiles = site_quantiles.reshape((-1,) + (1,) * site_loc.dim()) - site_quantiles_values = dist.Normal(site_loc, site_scale).icdf(site_quantiles) - constrained_site_quantiles = biject_to(site["fn"].support)(site_quantiles_values) + site_quantiles_values = dist.Normal(site_loc, site_scale).icdf( + site_quantiles + ) + constrained_site_quantiles = biject_to(site["fn"].support)( + site_quantiles_values + ) results[name] = constrained_site_quantiles return results @@ -608,6 +635,7 @@ class AutoContinuous(AutoGuide): :param callable init_loc_fn: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. """ + def __init__(self, model, init_loc_fn=init_to_median): model = InitMessenger(init_loc_fn)(model) super().__init__(model) @@ -620,14 +648,22 @@ def _setup_prototype(self, *args, **kwargs): # Collect the shapes of unconstrained values. # These may differ from the shapes of constrained values. with helpful_support_errors(site): - self._unconstrained_shapes[name] = biject_to(site["fn"].support).inv(site["value"]).shape + self._unconstrained_shapes[name] = ( + biject_to(site["fn"].support).inv(site["value"]).shape + ) # Collect independence contexts. self._cond_indep_stacks[name] = site["cond_indep_stack"] - self.latent_dim = sum(_product(shape) for shape in self._unconstrained_shapes.values()) + self.latent_dim = sum( + _product(shape) for shape in self._unconstrained_shapes.values() + ) if self.latent_dim == 0: - raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(type(self).__name__)) + raise RuntimeError( + "{} found no latent variables; Use an empty guide instead".format( + type(self).__name__ + ) + ) def _init_loc(self): """ @@ -684,7 +720,9 @@ def sample_latent(self, *args, **kwargs): base ``model``. 
""" pos_dist = self.get_posterior(*args, **kwargs) - return pyro.sample("_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True}) + return pyro.sample( + "_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True} + ) def _unpack_latent(self, latent): """ @@ -692,16 +730,23 @@ def _unpack_latent(self, latent): (site, unconstrained_value) """ - batch_shape = latent.shape[:-1] # for plates outside of _setup_prototype, e.g. parallel particles + batch_shape = latent.shape[ + :-1 + ] # for plates outside of _setup_prototype, e.g. parallel particles pos = 0 for name, site in self.prototype_trace.iter_stochastic_nodes(): constrained_shape = site["value"].shape unconstrained_shape = self._unconstrained_shapes[name] size = _product(unconstrained_shape) - event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape) - unconstrained_shape = torch.broadcast_shapes(unconstrained_shape, - batch_shape + (1,) * event_dim) - unconstrained_value = latent[..., pos:pos + size].view(unconstrained_shape) + event_dim = ( + site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape) + ) + unconstrained_shape = torch.broadcast_shapes( + unconstrained_shape, batch_shape + (1,) * event_dim + ) + unconstrained_value = latent[..., pos : pos + size].view( + unconstrained_shape + ) yield site, unconstrained_value pos += size if not torch._C._get_tracing_state(): @@ -771,8 +816,10 @@ def median(self, *args, **kwargs): """ loc, _ = self._loc_scale(*args, **kwargs) loc = loc.detach() - return {site["name"]: biject_to(site["fn"].support)(unconstrained_value) - for site, unconstrained_value in self._unpack_latent(loc)} + return { + site["name"]: biject_to(site["fn"].support)(unconstrained_value) + for site, unconstrained_value in self._unpack_latent(loc) + } @torch.no_grad() def quantiles(self, quantiles, *args, **kwargs): @@ -787,12 +834,16 @@ def quantiles(self, quantiles, *args, **kwargs): :rtype: dict """ loc, scale = self._loc_scale(*args, **kwargs) - quantiles = torch.tensor(quantiles, dtype=loc.dtype, device=loc.device).unsqueeze(-1) + quantiles = torch.tensor( + quantiles, dtype=loc.dtype, device=loc.device + ).unsqueeze(-1) latents = dist.Normal(loc, scale).icdf(quantiles) result = {} for latent in latents: for site, unconstrained_value in self._unpack_latent(latent): - result.setdefault(site["name"], []).append(biject_to(site["fn"].support)(unconstrained_value)) + result.setdefault(site["name"], []).append( + biject_to(site["fn"].support)(unconstrained_value) + ) result = {k: torch.stack(v) for k, v in result.items()} return result @@ -831,11 +882,15 @@ def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) # Initialize guide params self.loc = nn.Parameter(self._init_loc()) - self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim) * self._init_scale, - self.scale_tril_constraint) + self.scale_tril = PyroParam( + eye_like(self.loc, self.latent_dim) * self._init_scale, + self.scale_tril_constraint, + ) def get_base_dist(self): - return dist.Normal(torch.zeros_like(self.loc), torch.zeros_like(self.loc)).to_event(1) + return dist.Normal( + torch.zeros_like(self.loc), torch.zeros_like(self.loc) + ).to_event(1) def get_transform(self, *args, **kwargs): return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=self.scale_tril) @@ -883,11 +938,15 @@ def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) # Initialize guide params self.loc = nn.Parameter(self._init_loc()) - 
self.scale = PyroParam(self.loc.new_full((self.latent_dim,), self._init_scale), - self.scale_constraint) + self.scale = PyroParam( + self.loc.new_full((self.latent_dim,), self._init_scale), + self.scale_constraint, + ) def get_base_dist(self): - return dist.Normal(torch.zeros_like(self.loc), torch.zeros_like(self.loc)).to_event(1) + return dist.Normal( + torch.zeros_like(self.loc), torch.zeros_like(self.loc) + ).to_event(1) def get_transform(self, *args, **kwargs): return dist.transforms.AffineTransform(self.loc, self.scale) @@ -947,9 +1006,13 @@ def _setup_prototype(self, *args, **kwargs): self.rank = int(round(self.latent_dim ** 0.5)) self.scale = PyroParam( self.loc.new_full((self.latent_dim,), 0.5 ** 0.5 * self._init_scale), - constraint=self.scale_constraint) + constraint=self.scale_constraint, + ) self.cov_factor = nn.Parameter( - self.loc.new_empty(self.latent_dim, self.rank).normal_(0, 1 / self.rank ** 0.5)) + self.loc.new_empty(self.latent_dim, self.rank).normal_( + 0, 1 / self.rank ** 0.5 + ) + ) def get_posterior(self, *args, **kwargs): """ @@ -991,7 +1054,7 @@ def __init__(self, model, init_transform_fn): super().__init__(model, init_loc_fn=init_to_feasible) self._init_transform_fn = init_transform_fn self.transform = None - self._prototype_tensor = torch.tensor(0.) + self._prototype_tensor = torch.tensor(0.0) def get_base_dist(self): loc = self._prototype_tensor.new_zeros(1) @@ -1040,16 +1103,30 @@ class AutoIAFNormal(AutoNormalizingFlow): :func:`~pyro.distributions.transforms.affine_autoregressive`. """ - def __init__(self, model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs): + def __init__( + self, + model, + hidden_dim=None, + init_loc_fn=None, + num_transforms=1, + **init_transform_kwargs, + ): if init_loc_fn: - warnings.warn("The `init_loc_fn` argument to AutoIAFNormal is not used in practice. " - "Please consider removing, as this may be removed in a future release.", - category=FutureWarning) - super().__init__(model, - init_transform_fn=functools.partial(iterated, num_transforms, - affine_autoregressive, - hidden_dims=hidden_dim, - **init_transform_kwargs)) + warnings.warn( + "The `init_loc_fn` argument to AutoIAFNormal is not used in practice. " + "Please consider removing, as this may be removed in a future release.", + category=FutureWarning, + ) + super().__init__( + model, + init_transform_fn=functools.partial( + iterated, + num_transforms, + affine_autoregressive, + hidden_dims=hidden_dim, + **init_transform_kwargs, + ), + ) class AutoLaplaceApproximation(AutoContinuous): @@ -1074,6 +1151,7 @@ class AutoLaplaceApproximation(AutoContinuous): :param callable init_loc_fn: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. """ + def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) # Initialize guide params @@ -1092,7 +1170,8 @@ def laplace_approximation(self, *args, **kwargs): """ guide_trace = poutine.trace(self).get_trace(*args, **kwargs) model_trace = poutine.trace( - poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs) + poutine.replay(self.model, trace=guide_trace) + ).get_trace(*args, **kwargs) loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum() H = hessian(loss, self.loc) @@ -1113,10 +1192,13 @@ class AutoDiscreteParallel(AutoGuide): A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model. 
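The laplace_approximation method of AutoLaplaceApproximation reformatted above is typically called after fitting the underlying point-estimate guide. A rough sketch under that assumption (model, data, and the training loop are placeholders, as in the earlier sketches):

    import pyro
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoLaplaceApproximation

    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
    for _ in range(100):
        svi.step(data)
    # Returns an AutoMultivariateNormal whose covariance comes from the Hessian at the mode.
    guide = delta_guide.laplace_approximation(data)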
""" + def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure model = poutine.block(config_enumerate(self.model), prototype_hide_fn) - self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs) + self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( + *args, **kwargs + ) if self.master is not None: self.master()._check_prototype(self.prototype_trace) @@ -1125,14 +1207,18 @@ def _setup_prototype(self, *args, **kwargs): self._prototype_frames = {} for name, site in self.prototype_trace.iter_stochastic_nodes(): if site["infer"].get("enumerate") != "parallel": - raise NotImplementedError('Expected sample site "{}" to be discrete and ' - 'configured for parallel enumeration'.format(name)) + raise NotImplementedError( + 'Expected sample site "{}" to be discrete and ' + "configured for parallel enumeration".format(name) + ) # collect discrete sample sites fn = site["fn"] Dist = type(fn) if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical): - params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])] + params = [ + ("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"]) + ] else: raise NotImplementedError("{} is not supported".format(Dist.__name__)) self._discrete_sites.append((site, Dist, params)) @@ -1143,13 +1229,18 @@ def _setup_prototype(self, *args, **kwargs): if frame.vectorized: self._prototype_frames[frame.name] = frame else: - raise NotImplementedError("AutoDiscreteParallel does not support sequential pyro.plate") + raise NotImplementedError( + "AutoDiscreteParallel does not support sequential pyro.plate" + ) # Initialize guide params for site, Dist, param_spec in self._discrete_sites: name = site["name"] for param_name, param_init, param_constraint in param_spec: - _deep_setattr(self, "{}_{}".format(name, param_name), - PyroParam(param_init, constraint=param_constraint)) + _deep_setattr( + self, + "{}_{}".format(name, param_name), + PyroParam(param_init, constraint=param_constraint), + ) def forward(self, *args, **kwargs): """ @@ -1180,7 +1271,9 @@ def forward(self, *args, **kwargs): with ExitStack() as stack: for frame in self._cond_indep_stacks[name]: stack.enter_context(plates[frame.name]) - result[name] = pyro.sample(name, discrete_dist, infer={"enumerate": "parallel"}) + result[name] = pyro.sample( + name, discrete_dist, infer={"enumerate": "parallel"} + ) return result @@ -1300,9 +1393,13 @@ def _setup_prototype(self, *args, **kwargs): numel = {} for name, site in self.prototype_trace.iter_stochastic_nodes(): with helpful_support_errors(site): - init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach() + init_loc = ( + biject_to(site["fn"].support).inv(site["value"].detach()).detach() + ) self._batch_shapes[name] = site["fn"].batch_shape - self._unconstrained_event_shapes[name] = init_loc.shape[len(site["fn"].batch_shape):] + self._unconstrained_event_shapes[name] = init_loc.shape[ + len(site["fn"].batch_shape) : + ] numel[name] = init_loc.numel() init_locs[name] = init_loc.reshape(-1) @@ -1323,12 +1420,16 @@ def _setup_prototype(self, *args, **kwargs): raise ValueError(f"Unsupported conditional type: {conditional}") if conditional in ("normal", "mvn"): init_scale = torch.full_like(init_loc, self._init_scale) - _deep_setattr(self.scales, name, - PyroParam(init_scale, self.scale_constraint)) + _deep_setattr( + self.scales, name, PyroParam(init_scale, self.scale_constraint) + ) if conditional == "mvn": init_scale_tril = eye_like(init_loc, 
init_loc.numel()) - _deep_setattr(self.scale_trils, name, - PyroParam(init_scale_tril, self.scale_tril_constraint)) + _deep_setattr( + self.scale_trils, + name, + PyroParam(init_scale_tril, self.scale_tril_constraint), + ) # Initialize dependencies on upstream variables. num_pending[name] = 0 @@ -1415,8 +1516,9 @@ def get_deltas(self, save_params=None): scale_tril = _deep_getattr(self.scale_trils, name) aux_value = aux_value @ scale_tril.T * scale if compute_density: - log_density = (-scale_tril.diagonal(dim1=-2, dim2=-1).log() - - scale.log()).expand_as(aux_value) + log_density = ( + -scale_tril.diagonal(dim1=-2, dim2=-1).log() - scale.log() + ).expand_as(aux_value) else: raise ValueError(f"Unsupported conditional type: {conditional}") diff --git a/pyro/infer/autoguide/initialization.py b/pyro/infer/autoguide/initialization.py index 9f5c34f14b..5654f34346 100644 --- a/pyro/infer/autoguide/initialization.py +++ b/pyro/infer/autoguide/initialization.py @@ -225,6 +225,7 @@ class InitMessenger(Messenger): :param callable init_fn: An initialization function. """ + def __init__(self, init_fn): self.init_fn = init_fn super().__init__() @@ -237,12 +238,16 @@ def _pyro_sample(self, msg): if is_validation_enabled() and msg["value"] is not None: if not isinstance(value, type(msg["value"])): raise ValueError( - "{} provided invalid type for site {}:\nexpected {}\nactual {}" - .format(self.init_fn, msg["name"], type(msg["value"]), type(value))) + "{} provided invalid type for site {}:\nexpected {}\nactual {}".format( + self.init_fn, msg["name"], type(msg["value"]), type(value) + ) + ) if value.shape != msg["value"].shape: raise ValueError( - "{} provided invalid shape for site {}:\nexpected {}\nactual {}" - .format(self.init_fn, msg["name"], msg["value"].shape, value.shape)) + "{} provided invalid shape for site {}:\nexpected {}\nactual {}".format( + self.init_fn, msg["name"], msg["value"].shape, value.shape + ) + ) msg["value"] = value def _pyro_get_init_messengers(self, msg): diff --git a/pyro/infer/autoguide/utils.py b/pyro/infer/autoguide/utils.py index 13abda8746..55b3f904e1 100644 --- a/pyro/infer/autoguide/utils.py +++ b/pyro/infer/autoguide/utils.py @@ -28,7 +28,7 @@ def mean_field_entropy(model, args, whitelist=None): sites are included. """ trace = poutine.trace(model).get_trace(*args) - entropy = 0. + entropy = 0.0 for name, site in trace.nodes.items(): if site["type"] == "sample": if not poutine.util.site_is_subsample(site): @@ -51,12 +51,14 @@ def helpful_support_errors(site): "https://pyro.ai/examples/enumeration.html . " "If you are already enumerating, take care to hide this site when " "constructing an autoguide, e.g. " - f"guide = AutoNormal(poutine.block(model, hide=['{name}'])).") + f"guide = AutoNormal(poutine.block(model, hide=['{name}']))." + ) if "sphere" in support_name: name = site["name"] raise ValueError( f"Continuous inference cannot handle spherical sample site '{name}'. " "Consider using ProjectedNormal distribution together with " "a reparameterizer, e.g. " - f"poutine.reparam(config={{'{name}': ProjectedNormalReparam()}}).") + f"poutine.reparam(config={{'{name}': ProjectedNormalReparam()}})." + ) raise e from None diff --git a/pyro/infer/csis.py b/pyro/infer/csis.py index 06db33a844..4717217970 100644 --- a/pyro/infer/csis.py +++ b/pyro/infer/csis.py @@ -36,13 +36,16 @@ class CSIS(Importance): :param validation_batch_size: Number of samples to use for calculating validation loss (will only be used if `.validation_loss` is called). 
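The error message in helpful_support_errors above suggests hiding enumerated discrete sites when constructing an autoguide. A small sketch of that pattern; the mixture model and the site name "assignment" are illustrative assumptions, not part of the patch:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.infer.autoguide import AutoNormal

    def model(data):
        weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
        with pyro.plate("data", len(data)):
            assignment = pyro.sample("assignment", dist.Categorical(weights),
                                     infer={"enumerate": "parallel"})
            pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

    # Hide the discrete site so the continuous autoguide never sees it.
    guide = AutoNormal(poutine.block(model, hide=["assignment"]))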
""" - def __init__(self, - model, - guide, - optim, - num_inference_samples=10, - training_batch_size=10, - validation_batch_size=20): + + def __init__( + self, + model, + guide, + optim, + num_inference_samples=10, + training_batch_size=10, + validation_batch_size=20, + ): super().__init__(model, guide, num_inference_samples) self.model = model self.guide = guide @@ -57,8 +60,10 @@ def set_validation_batch(self, *args, **kwargs): Arguments are passed directly to model. """ - self.validation_batch = [self._sample_from_joint(*args, **kwargs) - for _ in range(self.validation_batch_size)] + self.validation_batch = [ + self._sample_from_joint(*args, **kwargs) + for _ in range(self.validation_batch_size) + ] def step(self, *args, **kwargs): """ @@ -71,9 +76,11 @@ def step(self, *args, **kwargs): with poutine.trace(param_only=True) as param_capture: loss = self.loss_and_grads(True, None, *args, **kwargs) - params = set(site["value"].unconstrained() - for site in param_capture.trace.nodes.values() - if site["value"].grad is not None) + params = set( + site["value"].unconstrained() + for site in param_capture.trace.nodes.values() + if site["value"].grad is not None + ) self.optim(params) @@ -95,8 +102,10 @@ def loss_and_grads(self, grads, batch, *args, **kwargs): `args` and `kwargs` are passed to the model and guide. """ if batch is None: - batch = (self._sample_from_joint(*args, **kwargs) - for _ in range(self.training_batch_size)) + batch = ( + self._sample_from_joint(*args, **kwargs) + for _ in range(self.training_batch_size) + ) batch_size = self.training_batch_size else: batch_size = len(batch) @@ -109,13 +118,21 @@ def loss_and_grads(self, grads, batch, *args, **kwargs): particle_loss /= batch_size if grads: - guide_params = set(site["value"].unconstrained() - for site in particle_param_capture.trace.nodes.values()) - guide_grads = torch.autograd.grad(particle_loss, guide_params, allow_unused=True) + guide_params = set( + site["value"].unconstrained() + for site in particle_param_capture.trace.nodes.values() + ) + guide_grads = torch.autograd.grad( + particle_loss, guide_params, allow_unused=True + ) for guide_grad, guide_param in zip(guide_grads, guide_params): if guide_grad is None: continue - guide_param.grad = guide_grad if guide_param.grad is None else guide_param.grad + guide_grad + guide_param.grad = ( + guide_grad + if guide_param.grad is None + else guide_param.grad + guide_grad + ) loss += torch_item(particle_loss) @@ -154,14 +171,16 @@ def _get_matched_trace(self, model_trace, *args, **kwargs): `args` and `kwargs` are passed to the guide. 
""" kwargs["observations"] = {} - for node in itertools.chain(model_trace.stochastic_nodes, model_trace.observation_nodes): + for node in itertools.chain( + model_trace.stochastic_nodes, model_trace.observation_nodes + ): if "was_observed" in model_trace.nodes[node]["infer"]: model_trace.nodes[node]["is_observed"] = True kwargs["observations"][node] = model_trace.nodes[node]["value"] - guide_trace = poutine.trace(poutine.replay(self.guide, - model_trace) - ).get_trace(*args, **kwargs) + guide_trace = poutine.trace(poutine.replay(self.guide, model_trace)).get_trace( + *args, **kwargs + ) check_model_guide_match(model_trace, guide_trace) guide_trace = prune_subsample_sites(guide_trace) diff --git a/pyro/infer/discrete.py b/pyro/infer/discrete.py index 770d31043a..a89294abd2 100644 --- a/pyro/infer/discrete.py +++ b/pyro/infer/discrete.py @@ -38,8 +38,9 @@ def _pyro_sample(self, msg): msg["cond_indep_stack"] = self.trace.nodes[msg["name"]]["cond_indep_stack"] -def _sample_posterior(model, first_available_dim, temperature, strict_enumeration_warning, - *args, **kwargs): +def _sample_posterior( + model, first_available_dim, temperature, strict_enumeration_warning, *args, **kwargs +): # For internal use by infer_discrete. # Create an enumerated trace. @@ -49,12 +50,14 @@ def _sample_posterior(model, first_available_dim, temperature, strict_enumeratio enum_trace.compute_log_prob() enum_trace.pack_tensors() - return _sample_posterior_from_trace(model, enum_trace, temperature, - strict_enumeration_warning, *args, **kwargs) + return _sample_posterior_from_trace( + model, enum_trace, temperature, strict_enumeration_warning, *args, **kwargs + ) -def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumeration_warning, - *args, **kwargs): +def _sample_posterior_from_trace( + model, enum_trace, temperature, strict_enumeration_warning, *args, **kwargs +): plate_to_symbol = enum_trace.plate_to_symbol # Collect a set of query sample sites to which the backward algorithm will propagate. @@ -65,14 +68,17 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat enum_terms = OrderedDict() for node in enum_trace.nodes.values(): if node["type"] == "sample": - ordinal = frozenset(plate_to_symbol[f.name] - for f in node["cond_indep_stack"] - if f.vectorized and f.size > 1) + ordinal = frozenset( + plate_to_symbol[f.name] + for f in node["cond_indep_stack"] + if f.vectorized and f.size > 1 + ) # For sites that depend on an enumerated variable, we need to apply # the mask but not the scale when sampling. if "masked_log_prob" not in node["packed"]: node["packed"]["masked_log_prob"] = packed.scale_and_mask( - node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"]) + node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"] + ) log_prob = node["packed"]["masked_log_prob"] sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal) if sum_dims.isdisjoint(log_prob._pyro_dims): @@ -89,9 +95,11 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat require_backward(log_prob) if strict_enumeration_warning and not enum_terms: - warnings.warn('infer_discrete found no sample sites configured for enumeration. ' - 'If you want to enumerate sites, you need to @config_enumerate or set ' - 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}.') + warnings.warn( + "infer_discrete found no sample sites configured for enumeration. 
" + "If you want to enumerate sites, you need to @config_enumerate or set " + 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}.' + ) # We take special care to match the term ordering in # pyro.infer.traceenum_elbo._compute_model_factors() to allow @@ -106,7 +114,9 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat cache = getattr(enum_trace, "_sharing_cache", {}) ring = _make_ring(temperature, cache, dim_to_size) with shared_intermediates(cache): - log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring) # run forward algorithm + log_probs = contract_tensor_tree( + log_probs, sum_dims, ring=ring + ) # run forward algorithm query_to_ordinal = {} pending = object() # a constant value for pending queries for query in queries: @@ -117,7 +127,10 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat term._pyro_backward() # run backward algorithm # Note: this is quadratic in number of ordinals for query in queries: - if query not in query_to_ordinal and query._pyro_backward_result is not pending: + if ( + query not in query_to_ordinal + and query._pyro_backward_result is not pending + ): query_to_ordinal[query] = ordinal # Construct a collapsed trace by gathering and adjusting cond_indep_stack. @@ -138,18 +151,25 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat # Adjust the cond_indep_stack. ordinal = query_to_ordinal[log_prob] new_node["cond_indep_stack"] = tuple( - f for f in node["cond_indep_stack"] - if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal) + f + for f in node["cond_indep_stack"] + if not (f.vectorized and f.size > 1) + or plate_to_symbol[f.name] in ordinal + ) # Gather if node depended on an enumerated value. 
sample = log_prob._pyro_backward_result if sample is not None: - new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"]) + new_value = packed.pack( + node["value"], node["infer"]["_dim_to_symbol"] + ) for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims): if dim in new_value._pyro_dims: index._pyro_dims = sample._pyro_dims[1:] new_value = packed.gather(new_value, index, dim) - new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim) + new_node["value"] = packed.unpack( + new_value, enum_trace.symbol_to_dim + ) collapsed_trace.add_node(node["name"], **new_node) @@ -158,8 +178,9 @@ def _sample_posterior_from_trace(model, enum_trace, temperature, strict_enumerat return model(*args, **kwargs) -def infer_discrete(fn=None, first_available_dim=None, temperature=1, *, - strict_enumeration_warning=True): +def infer_discrete( + fn=None, first_available_dim=None, temperature=1, *, strict_enumeration_warning=True +): """ A poutine that samples discrete sites marked with ``site["infer"]["enumerate"] = "parallel"`` from the posterior, @@ -196,11 +217,18 @@ def viterbi_decoder(data, hidden_dim=10): """ assert first_available_dim < 0, first_available_dim if fn is None: # support use as a decorator - return functools.partial(infer_discrete, - first_available_dim=first_available_dim, - temperature=temperature) - return functools.partial(_sample_posterior, fn, first_available_dim, temperature, - strict_enumeration_warning) + return functools.partial( + infer_discrete, + first_available_dim=first_available_dim, + temperature=temperature, + ) + return functools.partial( + _sample_posterior, + fn, + first_available_dim, + temperature, + strict_enumeration_warning, + ) class TraceEnumSample_ELBO(TraceEnum_ELBO): @@ -224,9 +252,9 @@ class TraceEnumSample_ELBO(TraceEnum_ELBO): first_available_dim=-2)(*args, **kwargs) """ + def _get_trace(self, model, guide, args, kwargs): - model_trace, guide_trace = super()._get_trace( - model, guide, args, kwargs) + model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs) # Mark all sample sites with require_backward to gather enumerated # sites and adjust cond_indep_stack of all sample sites. @@ -245,6 +273,11 @@ def sample_saved(self): model, model_trace, guide_trace, args, kwargs = self._saved_state model = poutine.replay(model, guide_trace) temperature = 1 - return _sample_posterior_from_trace(model, model_trace, temperature, - self.strict_enumeration_warning, - *args, **kwargs) + return _sample_posterior_from_trace( + model, + model_trace, + temperature, + self.strict_enumeration_warning, + *args, + **kwargs + ) diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 0a2ba0e17a..6e7b45f17e 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -57,19 +57,23 @@ class ELBO(object, metaclass=ABCMeta): Rajesh Ranganath, Sean Gerrish, David M. 
Blei """ - def __init__(self, - num_particles=1, - max_plate_nesting=float('inf'), - max_iarange_nesting=None, # DEPRECATED - vectorize_particles=False, - strict_enumeration_warning=True, - ignore_jit_warnings=False, - jit_options=None, - retain_graph=None, - tail_adaptive_beta=-1.0): + def __init__( + self, + num_particles=1, + max_plate_nesting=float("inf"), + max_iarange_nesting=None, # DEPRECATED + vectorize_particles=False, + strict_enumeration_warning=True, + ignore_jit_warnings=False, + jit_options=None, + retain_graph=None, + tail_adaptive_beta=-1.0, + ): if max_iarange_nesting is not None: - warnings.warn("max_iarange_nesting is deprecated; use max_plate_nesting instead", - DeprecationWarning) + warnings.warn( + "max_iarange_nesting is deprecated; use max_plate_nesting instead", + DeprecationWarning, + ) max_plate_nesting = max_iarange_nesting self.max_plate_nesting = max_plate_nesting self.num_particles = num_particles @@ -92,13 +96,16 @@ def _guess_max_plate_nesting(self, model, guide, args, kwargs): with poutine.block(): guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) model_trace = poutine.trace( - poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs) + poutine.replay(model, trace=guide_trace) + ).get_trace(*args, **kwargs) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) - sites = [site - for trace in (model_trace, guide_trace) - for site in trace.nodes.values() - if site["type"] == "sample"] + sites = [ + site + for trace in (model_trace, guide_trace) + for site in trace.nodes.values() + if site["type"] == "sample" + ] # Validate shapes now, since shape constraints will be weaker once # max_plate_nesting is changed from float('inf') to some finite value. @@ -108,16 +115,18 @@ def _guess_max_plate_nesting(self, model, guide, args, kwargs): guide_trace.compute_log_prob() model_trace.compute_log_prob() for site in sites: - check_site_shape(site, max_plate_nesting=float('inf')) - - dims = [frame.dim - for site in sites - for frame in site["cond_indep_stack"] - if frame.vectorized] + check_site_shape(site, max_plate_nesting=float("inf")) + + dims = [ + frame.dim + for site in sites + for frame in site["cond_indep_stack"] + if frame.vectorized + ] self.max_plate_nesting = -min(dims) if dims else 0 if self.vectorize_particles and self.num_particles > 1: self.max_plate_nesting += 1 - logging.info('Guessed max_plate_nesting = {}'.format(self.max_plate_nesting)) + logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting)) def _vectorized_num_particles(self, fn): """ @@ -133,7 +142,11 @@ def _vectorized_num_particles(self, fn): def wrapped_fn(*args, **kwargs): if self.num_particles == 1: return fn(*args, **kwargs) - with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting): + with pyro.plate( + "num_particles_vectorized", + self.num_particles, + dim=-self.max_plate_nesting, + ): return fn(*args, **kwargs) return wrapped_fn @@ -144,9 +157,12 @@ def _get_vectorized_trace(self, model, guide, args, kwargs): ``num_particles``, and returns a single trace from the wrapped model and guide. 
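The ELBO options being reformatted here (num_particles, vectorize_particles, max_plate_nesting) are passed at construction time. For example, a multi-particle, vectorized ELBO might be configured as below; the model, guide, and numeric values are arbitrary assumptions:

    import pyro
    from pyro.infer import SVI, Trace_ELBO

    # Draw 16 particles per step, vectorized under one extra plate dimension.
    elbo = Trace_ELBO(num_particles=16, vectorize_particles=True, max_plate_nesting=1)
    svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), elbo)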
""" - return self._get_trace(self._vectorized_num_particles(model), - self._vectorized_num_particles(guide), - args, kwargs) + return self._get_trace( + self._vectorized_num_particles(model), + self._vectorized_num_particles(guide), + args, + kwargs, + ) @abstractmethod def _get_trace(self, model, guide, args, kwargs): @@ -162,7 +178,7 @@ def _get_traces(self, model, guide, args, kwargs): the result packaged as a trace generator. """ if self.vectorize_particles: - if self.max_plate_nesting == float('inf'): + if self.max_plate_nesting == float("inf"): self._guess_max_plate_nesting(model, guide, args, kwargs) yield self._get_vectorized_trace(model, guide, args, kwargs) else: diff --git a/pyro/infer/energy_distance.py b/pyro/infer/energy_distance.py index 2239676b87..535586a395 100644 --- a/pyro/infer/energy_distance.py +++ b/pyro/infer/energy_distance.py @@ -18,7 +18,7 @@ def _squared_error(x, y, scale, mask): diff = x - y - if getattr(scale, 'shape', ()) or getattr(mask, 'shape', ()): + if getattr(scale, "shape", ()) or getattr(mask, "shape", ()): error = torch.einsum("nbe,nbe->nb", diff, diff) return scale_and_mask(error, scale, mask).sum(-1) else: @@ -75,17 +75,18 @@ class EnergyDistance: :func:`pyro.plate` contexts. If omitted, this will guess a valid value by running the (model,guide) pair once. """ - def __init__(self, - beta=1., - prior_scale=0., - num_particles=2, - max_plate_nesting=float('inf')): + + def __init__( + self, beta=1.0, prior_scale=0.0, num_particles=2, max_plate_nesting=float("inf") + ): if not (isinstance(beta, (float, int)) and 0 < beta and beta < 2): raise ValueError("Expected beta in (0,2), actual {}".format(beta)) if not (isinstance(prior_scale, (float, int)) and prior_scale >= 0): raise ValueError("Expected prior_scale >= 0, actual {}".format(prior_scale)) if not (isinstance(num_particles, int) and num_particles >= 2): - raise ValueError("Expected num_particles >= 2, actual {}".format(num_particles)) + raise ValueError( + "Expected num_particles >= 2, actual {}".format(num_particles) + ) self.beta = beta self.prior_scale = prior_scale self.num_particles = num_particles @@ -102,8 +103,9 @@ def _get_traces(self, model, guide, args, kwargs): with validation_enabled(False): # Avoid calling .log_prob() when undefined. # TODO factor this out as a stand-alone helper. ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs) - vectorize = pyro.plate("num_particles_vectorized", self.num_particles, - dim=-self.max_plate_nesting) + vectorize = pyro.plate( + "num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting + ) # Trace the guide as in ELBO. 
with poutine.trace() as tr, vectorize: @@ -128,16 +130,22 @@ def _get_traces(self, model, guide, args, kwargs): if site["type"] == "sample": warn_if_nan(site["value"], site["name"]) if not getattr(site["fn"], "has_rsample", False): - raise ValueError("EnergyDistance requires fully reparametrized guides") + raise ValueError( + "EnergyDistance requires fully reparametrized guides" + ) for trace in model_trace.nodes.values(): if site["type"] == "sample": if site["is_observed"]: warn_if_nan(site["value"], site["name"]) if not getattr(site["fn"], "has_rsample", False): - raise ValueError("EnergyDistance requires reparametrized likelihoods") + raise ValueError( + "EnergyDistance requires reparametrized likelihoods" + ) if self.prior_scale > 0: - model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"]) + model_trace.compute_log_prob( + site_filter=lambda name, site: not site["is_observed"] + ) if is_validation_enabled(): for site in model_trace.nodes.values(): if site["type"] == "sample": @@ -168,7 +176,11 @@ def __call__(self, model, guide, *args, **kwargs): squared_error = [] # E[ (X - x)^2 ] squared_entropy = [] # E[ (X - X')^2 ] prototype = next(iter(data.values())) - pairs = prototype.new_ones(self.num_particles, self.num_particles).tril(-1).nonzero(as_tuple=False) + pairs = ( + prototype.new_ones(self.num_particles, self.num_particles) + .tril(-1) + .nonzero(as_tuple=False) + ) for name, obs in data.items(): sample = samples[name] scale = model_trace.nodes[name]["scale"] @@ -176,17 +188,21 @@ def __call__(self, model, guide, *args, **kwargs): # Flatten to subshapes of (num_particles, batch_size, event_size). event_dim = model_trace.nodes[name]["fn"].event_dim - batch_shape = obs.shape[:obs.dim() - event_dim] - event_shape = obs.shape[obs.dim() - event_dim:] - if getattr(scale, 'shape', ()): + batch_shape = obs.shape[: obs.dim() - event_dim] + event_shape = obs.shape[obs.dim() - event_dim :] + if getattr(scale, "shape", ()): scale = scale.expand(batch_shape).reshape(-1) - if getattr(mask, 'shape', ()): + if getattr(mask, "shape", ()): mask = mask.expand(batch_shape).reshape(-1) obs = obs.reshape(batch_shape.numel(), event_shape.numel()) - sample = sample.reshape(self.num_particles, batch_shape.numel(), event_shape.numel()) + sample = sample.reshape( + self.num_particles, batch_shape.numel(), event_shape.numel() + ) squared_error.append(_squared_error(sample, obs, scale, mask)) - squared_entropy.append(_squared_error(*sample[pairs].unbind(1), scale, mask)) + squared_entropy.append( + _squared_error(*sample[pairs].unbind(1), scale, mask) + ) squared_error = reduce(operator.add, squared_error) squared_entropy = reduce(operator.add, squared_entropy) diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 645c18b304..190762fb38 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -14,17 +14,23 @@ def iter_discrete_escape(trace, msg): - return ((msg["type"] == "sample") and - (not msg["is_observed"]) and - (msg["infer"].get("enumerate") == "sequential") and # only sequential - (msg["name"] not in trace)) + return ( + (msg["type"] == "sample") + and (not msg["is_observed"]) + and (msg["infer"].get("enumerate") == "sequential") + and (msg["name"] not in trace) # only sequential + ) def iter_discrete_extend(trace, site, **ignored): values = enumerate_site(site) enum_total = values.shape[0] - with ignore_jit_warnings(["Converting a tensor to a Python index", - ("Iterating over a tensor", RuntimeWarning)]): + with ignore_jit_warnings( + [ + "Converting a tensor to 
a Python index", + ("Iterating over a tensor", RuntimeWarning), + ] + ): values = iter(values) for i, value in enumerate(values): extended_site = site.copy() @@ -36,7 +42,9 @@ def iter_discrete_extend(trace, site, **ignored): yield extended_trace -def get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach=False): +def get_importance_trace( + graph_type, max_plate_nesting, model, guide, args, kwargs, detach=False +): """ Returns a single trace from the guide, which can optionally be detached, and the model that is run against it. @@ -44,8 +52,9 @@ def get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwar guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs) if detach: guide_trace.detach_() - model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), - graph_type=graph_type).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(model, trace=guide_trace), graph_type=graph_type + ).get_trace(*args, **kwargs) if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, max_plate_nesting) @@ -82,8 +91,11 @@ def iter_discrete_traces(graph_type, fn, *args, **kwargs): queue = LifoQueue() queue.put(Trace()) traced_fn = poutine.trace( - poutine.queue(fn, queue, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend), - graph_type=graph_type) + poutine.queue( + fn, queue, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend + ), + graph_type=graph_type, + ) while not queue.empty(): yield traced_fn.get_trace(*args, **kwargs) @@ -94,13 +106,17 @@ def _config_fn(default, expand, num_samples, tmc, site): if type(site["fn"]).__name__ == "_Subsample": return {} if num_samples is not None: - return {"enumerate": site["infer"].get("enumerate", default), - "num_samples": site["infer"].get("num_samples", num_samples), - "expand": site["infer"].get("expand", expand), - "tmc": site["infer"].get("tmc", tmc)} + return { + "enumerate": site["infer"].get("enumerate", default), + "num_samples": site["infer"].get("num_samples", num_samples), + "expand": site["infer"].get("expand", expand), + "tmc": site["infer"].get("tmc", tmc), + } if getattr(site["fn"], "has_enumerate_support", False): - return {"enumerate": site["infer"].get("enumerate", default), - "expand": site["infer"].get("expand", expand)} + return { + "enumerate": site["infer"].get("enumerate", default), + "expand": site["infer"].get("expand", expand), + } return {} @@ -108,7 +124,9 @@ def _config_enumerate(default, expand, num_samples, tmc): return partial(_config_fn, default, expand, num_samples, tmc) -def config_enumerate(guide=None, default="parallel", expand=False, num_samples=None, tmc="diagonal"): +def config_enumerate( + guide=None, default="parallel", expand=False, num_samples=None, tmc="diagonal" +): """ Configures enumeration for all relevant sites in a guide. This is mainly used in conjunction with :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`. @@ -153,23 +171,39 @@ def guide2(*args, **kwargs): :rtype: callable """ if default not in ["sequential", "parallel", "flat", None]: - raise ValueError("Invalid default value. Expected 'sequential', 'parallel', or None, but got {}".format( - repr(default))) + raise ValueError( + "Invalid default value. Expected 'sequential', 'parallel', or None, but got {}".format( + repr(default) + ) + ) if expand not in [True, False]: - raise ValueError("Invalid expand value. 
Expected True or False, but got {}".format(repr(expand))) + raise ValueError( + "Invalid expand value. Expected True or False, but got {}".format( + repr(expand) + ) + ) if num_samples is not None: if not (isinstance(num_samples, numbers.Number) and num_samples > 0): - raise ValueError("Invalid num_samples, expected None or positive integer, but got {}".format( - repr(num_samples))) + raise ValueError( + "Invalid num_samples, expected None or positive integer, but got {}".format( + repr(num_samples) + ) + ) if default == "sequential": - raise ValueError('Local sampling does not support "sequential" sampling; ' - 'use "parallel" sampling instead.') + raise ValueError( + 'Local sampling does not support "sequential" sampling; ' + 'use "parallel" sampling instead.' + ) if tmc == "full" and num_samples is not None and num_samples > 1: # tmc strategies validated elsewhere (within enum handler) expand = True # Support usage as a decorator: if guide is None: - return lambda guide: config_enumerate(guide, default=default, expand=expand, num_samples=num_samples, tmc=tmc) + return lambda guide: config_enumerate( + guide, default=default, expand=expand, num_samples=num_samples, tmc=tmc + ) - return poutine.infer_config(guide, config_fn=_config_enumerate(default, expand, num_samples, tmc)) + return poutine.infer_config( + guide, config_fn=_config_enumerate(default, expand, num_samples, tmc) + ) diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index f7f60b62d5..d7c25a843d 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -32,7 +32,9 @@ def __init__(self, model, guide=None, num_samples=None): super().__init__() if num_samples is None: num_samples = 10 - warnings.warn("num_samples not provided, defaulting to {}".format(num_samples)) + warnings.warn( + "num_samples not provided, defaulting to {}".format(num_samples) + ) if guide is None: # propose from the prior by making a guide from the model by hiding observes guide = poutine.block(model, hide_types=["observe"]) @@ -47,7 +49,8 @@ def _traces(self, *args, **kwargs): for i in range(self.num_samples): guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs) model_trace = poutine.trace( - poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs) + poutine.replay(self.model, trace=guide_trace) + ).get_trace(*args, **kwargs) log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum() yield (model_trace, log_weight) @@ -59,10 +62,12 @@ def get_log_normalizer(self): # ensure list is not empty if self.log_weights: log_w = torch.tensor(self.log_weights) - log_num_samples = torch.log(torch.tensor(self.num_samples * 1.)) + log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0)) return torch.logsumexp(log_w - log_num_samples, 0) else: - warnings.warn("The log_weights list is empty, can not compute normalizing constant estimate.") + warnings.warn( + "The log_weights list is empty, can not compute normalizing constant estimate." + ) def get_normalized_weights(self, log_scale=False): """ @@ -73,7 +78,9 @@ def get_normalized_weights(self, log_scale=False): log_w_norm = log_w - torch.logsumexp(log_w, 0) return log_w_norm if log_scale else torch.exp(log_w_norm) else: - warnings.warn("The log_weights list is empty. There is nothing to normalize.") + warnings.warn( + "The log_weights list is empty. There is nothing to normalize." 
+ ) def get_ESS(self): """ @@ -81,9 +88,11 @@ def get_ESS(self): """ if self.log_weights: log_w_norm = self.get_normalized_weights(log_scale=True) - ess = torch.exp(-torch.logsumexp(2*log_w_norm, 0)) + ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) else: - warnings.warn("The log_weights list is empty, effective sample size is zero.") + warnings.warn( + "The log_weights list is empty, effective sample size is zero." + ) ess = 0 return ess @@ -116,12 +125,16 @@ def vectorized_importance_weights(model, guide, *args, **kwargs): def vectorize(fn): def _fn(*args, **kwargs): - with pyro.plate("num_particles_vectorized", num_samples, dim=-max_plate_nesting): + with pyro.plate( + "num_particles_vectorized", num_samples, dim=-max_plate_nesting + ): return fn(*args, **kwargs) + return _fn model_trace, guide_trace = get_importance_trace( - "flat", max_plate_nesting, vectorize(model), vectorize(guide), args, kwargs) + "flat", max_plate_nesting, vectorize(model), vectorize(guide), args, kwargs + ) guide_trace.pack_tensors() model_trace.pack_tensors(guide_trace.plate_to_symbol) @@ -130,18 +143,22 @@ def _fn(*args, **kwargs): log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum() else: wd = guide_trace.plate_to_symbol["num_particles_vectorized"] - log_weights = 0. + log_weights = 0.0 for site in model_trace.nodes.values(): if site["type"] != "sample": continue - log_weights += torch.einsum(site["packed"]["log_prob"]._pyro_dims + "->" + wd, - [site["packed"]["log_prob"]]) + log_weights += torch.einsum( + site["packed"]["log_prob"]._pyro_dims + "->" + wd, + [site["packed"]["log_prob"]], + ) for site in guide_trace.nodes.values(): if site["type"] != "sample": continue - log_weights -= torch.einsum(site["packed"]["log_prob"]._pyro_dims + "->" + wd, - [site["packed"]["log_prob"]]) + log_weights -= torch.einsum( + site["packed"]["log_prob"]._pyro_dims + "->" + wd, + [site["packed"]["log_prob"]], + ) if normalized: log_weights = log_weights - torch.logsumexp(log_weights) @@ -190,28 +207,42 @@ def psis_diagnostic(model, guide, *args, **kwargs): :returns float: the PSIS diagnostic k """ - num_particles = kwargs.pop('num_particles', 1000) - max_simultaneous_particles = kwargs.pop('max_simultaneous_particles', num_particles) - max_plate_nesting = kwargs.pop('max_plate_nesting', 7) + num_particles = kwargs.pop("num_particles", 1000) + max_simultaneous_particles = kwargs.pop("max_simultaneous_particles", num_particles) + max_plate_nesting = kwargs.pop("max_plate_nesting", 7) if num_particles % max_simultaneous_particles != 0: - raise ValueError("num_particles must be divisible by max_simultaneous_particles.") + raise ValueError( + "num_particles must be divisible by max_simultaneous_particles." 
+ ) N = num_particles // max_simultaneous_particles - log_weights = [vectorized_importance_weights(model, guide, num_samples=max_simultaneous_particles, - max_plate_nesting=max_plate_nesting, - *args, **kwargs)[0] for _ in range(N)] + log_weights = [ + vectorized_importance_weights( + model, + guide, + num_samples=max_simultaneous_particles, + max_plate_nesting=max_plate_nesting, + *args, + **kwargs, + )[0] + for _ in range(N) + ] log_weights = torch.cat(log_weights) log_weights -= log_weights.max() log_weights = torch.sort(log_weights, descending=False)[0] - cutoff_index = - int(math.ceil(min(0.2 * num_particles, 3.0 * math.sqrt(num_particles)))) - 1 + cutoff_index = ( + -int(math.ceil(min(0.2 * num_particles, 3.0 * math.sqrt(num_particles)))) - 1 + ) lw_cutoff = max(math.log(1.0e-15), log_weights[cutoff_index]) lw_tail = log_weights[log_weights > lw_cutoff] if len(lw_tail) < 10: - warnings.warn("Not enough tail samples to compute PSIS diagnostic; increase num_particles.") - k = float('inf') + warnings.warn( + "Not enough tail samples to compute PSIS diagnostic; increase num_particles." + ) + k = float("inf") else: k, _ = fit_generalized_pareto(lw_tail.exp() - math.exp(lw_cutoff)) diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py index e7a2fe0676..6fe082cf3e 100644 --- a/pyro/infer/mcmc/adaptation.py +++ b/pyro/infer/mcmc/adaptation.py @@ -28,12 +28,14 @@ class WarmupAdapter: periodically updated when adaptation is engaged. """ - def __init__(self, - step_size=1, - adapt_step_size=False, - target_accept_prob=0.8, - adapt_mass_matrix=False, - dense_mass=False): + def __init__( + self, + step_size=1, + adapt_step_size=False, + target_accept_prob=0.8, + adapt_mass_matrix=False, + dense_mass=False, + ): self.adapt_step_size = adapt_step_size self.adapt_mass_matrix = adapt_mass_matrix self.target_accept_prob = target_accept_prob @@ -70,8 +72,12 @@ def _build_adaptation_schedule(self): start_buffer_size = self._adapt_start_buffer end_buffer_size = self._adapt_end_buffer init_window_size = self._adapt_initial_window - if (self._adapt_start_buffer + self._adapt_end_buffer - + self._adapt_initial_window > self._warmup_steps): + if ( + self._adapt_start_buffer + + self._adapt_end_buffer + + self._adapt_initial_window + > self._warmup_steps + ): start_buffer_size = int(0.15 * self._warmup_steps) end_buffer_size = int(0.1 * self._warmup_steps) init_window_size = self._warmup_steps - start_buffer_size - end_buffer_size @@ -88,9 +94,12 @@ def _build_adaptation_schedule(self): else: cur_window_size = end_window_start - cur_window_start next_window_start = cur_window_start + cur_window_size - adaptation_schedule.append(adapt_window(cur_window_start, next_window_start - 1)) - adaptation_schedule.append(adapt_window(end_window_start, - self._warmup_steps - 1)) + adaptation_schedule.append( + adapt_window(cur_window_start, next_window_start - 1) + ) + adaptation_schedule.append( + adapt_window(end_window_start, self._warmup_steps - 1) + ) return adaptation_schedule def reset_step_size_adaptation(self, z): @@ -115,8 +124,14 @@ def _end_adaptation(self): _, log_step_size_avg = self._step_size_adapt_scheme.get_state() self.step_size = math.exp(log_step_size_avg) - def configure(self, warmup_steps, initial_step_size=None, mass_matrix_shape=None, - find_reasonable_step_size_fn=None, options={}): + def configure( + self, + warmup_steps, + initial_step_size=None, + mass_matrix_shape=None, + find_reasonable_step_size_fn=None, + options={}, + ): r""" Model specific properties that are 
specified when the HMC kernel is setup. @@ -129,13 +144,19 @@ def configure(self, warmup_steps, initial_step_size=None, mass_matrix_shape=None tensor options. This is used to construct initial mass matrix in `mass_matrix_adapter`. """ self._warmup_steps = warmup_steps - self.step_size = initial_step_size if initial_step_size is not None else self._init_step_size + self.step_size = ( + initial_step_size if initial_step_size is not None else self._init_step_size + ) if find_reasonable_step_size_fn is not None: self._find_reasonable_step_size = find_reasonable_step_size_fn if mass_matrix_shape is None or self.step_size is None: - raise ValueError("Incomplete configuration - step size and inverse mass matrix " - "need to be initialized.") - self.mass_matrix_adapter.configure(mass_matrix_shape, self.adapt_mass_matrix, options=options) + raise ValueError( + "Incomplete configuration - step size and inverse mass matrix " + "need to be initialized." + ) + self.mass_matrix_adapter.configure( + mass_matrix_shape, self.adapt_mass_matrix, options=options + ) if not self._adaptation_disabled: self._adaptation_schedule = self._build_adaptation_schedule() self._current_window = 0 # starting window index @@ -155,8 +176,9 @@ def step(self, t, z, accept_prob, z_grad=None): return window = self._adaptation_schedule[self._current_window] num_windows = len(self._adaptation_schedule) - mass_matrix_adaptation_phase = self.adapt_mass_matrix and \ - (0 < self._current_window < num_windows - 1) + mass_matrix_adaptation_phase = self.adapt_mass_matrix and ( + 0 < self._current_window < num_windows - 1 + ) if self.adapt_step_size: self._update_step_size(accept_prob.item()) if mass_matrix_adaptation_phase: @@ -224,7 +246,8 @@ class BlockMassMatrix: :param float init_scale: initial scale to construct the initial mass matrix. """ - def __init__(self, init_scale=1.): + + def __init__(self, init_scale=1.0): # TODO: we might allow users specify the initial mass matrix in the constructor. 
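The warm-up adaptation options threaded through WarmupAdapter and the mass-matrix classes above are exposed to users via the HMC/NUTS kernels. A typical configuration might look like the following sketch (model and data are assumed):

    from pyro.infer import MCMC, NUTS

    kernel = NUTS(
        model,
        adapt_step_size=True,      # dual-averaging step-size adaptation
        adapt_mass_matrix=True,    # Welford estimate of the inverse mass matrix
        full_mass=False,           # diagonal mass matrix; True requests a dense one
        target_accept_prob=0.8,
    )
    mcmc = MCMC(kernel, num_samples=500, warmup_steps=500)
    mcmc.run(data)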
self._init_scale = init_scale self._adapt_scheme = {} @@ -269,8 +292,11 @@ def configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}): for site_names, shape in mass_matrix_shape.items(): self._mass_matrix_size[site_names] = shape[0] diagonal = len(shape) == 1 - inverse_mass_matrix[site_names] = torch.full(shape, self._init_scale, **options) \ - if diagonal else torch.eye(*shape, **options) * self._init_scale + inverse_mass_matrix[site_names] = ( + torch.full(shape, self._init_scale, **options) + if diagonal + else torch.eye(*shape, **options) * self._init_scale + ) if adapt_mass_matrix: adapt_scheme = WelfordCovariance(diagonal=diagonal) self._adapt_scheme[site_names] = adapt_scheme @@ -294,7 +320,9 @@ def end_adaptation(self): """ inverse_mass_matrix = {} for site_names, adapt_scheme in self._adapt_scheme.items(): - inverse_mass_matrix[site_names] = adapt_scheme.get_covariance(regularize=True) + inverse_mass_matrix[site_names] = adapt_scheme.get_covariance( + regularize=True + ) self.inverse_mass_matrix = inverse_mass_matrix def kinetic_grad(self, r): @@ -338,7 +366,9 @@ def scale(self, r_unscaled, r_prototype): pos = 0 for site_name in site_names: next_pos = pos + r_prototype[site_name].numel() - s[site_name] = r_flat[pos:next_pos].reshape(r_prototype[site_name].shape) + s[site_name] = r_flat[pos:next_pos].reshape( + r_prototype[site_name].shape + ) pos = next_pos return s @@ -353,7 +383,10 @@ def unscale(self, r): :returns: a dictionary maps site names to the corresponding tensor """ u = {} - for site_names, mass_matrix_sqrt_inverse in self._mass_matrix_sqrt_inverse.items(): + for ( + site_names, + mass_matrix_sqrt_inverse, + ) in self._mass_matrix_sqrt_inverse.items(): r_flat = torch.cat([r[site_name].reshape(-1) for site_name in site_names]) u[site_names] = _matvecmul(mass_matrix_sqrt_inverse, r_flat) return u @@ -369,7 +402,8 @@ class ArrowheadMassMatrix: :param float init_scale: initial scale to construct the initial mass matrix. """ - def __init__(self, init_scale=1.): + + def __init__(self, init_scale=1.0): self._init_scale = init_scale self._adapt_scheme = {} self._mass_matrix = {} @@ -477,14 +511,19 @@ def kinetic_grad(self, r): :returns: a dictionary maps site names to the corresponding gradient """ v = {} - for site_names, mass_matrix_sqrt_inverse in self._mass_matrix_sqrt_inverse.items(): + for ( + site_names, + mass_matrix_sqrt_inverse, + ) in self._mass_matrix_sqrt_inverse.items(): r_flat = torch.cat([r[site_name].reshape(-1) for site_name in site_names]) # NB: using inverse_mass_matrix as in BlockMassMatrix will cost # O(N^2 x head_size) operators and O(N^2) memory requirement; # here, we will leverage mass_matrix_sqrt_inverse to reduce the cost to # O(N x head_size^2) operators and O(N x head_size) memory requirement. 
r_unscaled = triu_matvecmul(mass_matrix_sqrt_inverse, r_flat) - v_flat = triu_matvecmul(mass_matrix_sqrt_inverse, r_unscaled, transpose=True) + v_flat = triu_matvecmul( + mass_matrix_sqrt_inverse, r_unscaled, transpose=True + ) # unpacking pos = 0 @@ -514,7 +553,9 @@ def scale(self, r_unscaled, r_prototype): pos = 0 for site_name in site_names: next_pos = pos + r_prototype[site_name].numel() - s[site_name] = r_flat[pos:next_pos].reshape(r_prototype[site_name].shape) + s[site_name] = r_flat[pos:next_pos].reshape( + r_prototype[site_name].shape + ) pos = next_pos return s @@ -529,7 +570,10 @@ def unscale(self, r): :returns: a dictionary maps site names to the corresponding tensor """ u = {} - for site_names, mass_matrix_sqrt_inverse in self._mass_matrix_sqrt_inverse.items(): + for ( + site_names, + mass_matrix_sqrt_inverse, + ) in self._mass_matrix_sqrt_inverse.items(): r_flat = torch.cat([r[site_name].reshape(-1) for site_name in site_names]) u[site_names] = triu_matvecmul(mass_matrix_sqrt_inverse, r_flat) return u diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index 6175ee2beb..2e6a00288b 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -42,15 +42,19 @@ from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, StreamingStats from pyro.util import optional -MAX_SEED = 2**32 - 1 +MAX_SEED = 2 ** 32 - 1 -def logger_thread(log_queue, warmup_steps, num_samples, num_chains, disable_progbar=False): +def logger_thread( + log_queue, warmup_steps, num_samples, num_chains, disable_progbar=False +): """ Logging thread that asynchronously consumes logging events from `log_queue`, and handles them appropriately. """ - progress_bars = ProgressBar(warmup_steps, num_samples, disable=disable_progbar, num_bars=num_chains) + progress_bars = ProgressBar( + warmup_steps, num_samples, disable=disable_progbar, num_bars=num_chains + ) logger = logging.getLogger(__name__) logger.propagate = False logger.addHandler(TqdmHandler()) @@ -69,7 +73,9 @@ def logger_thread(log_queue, warmup_steps, num_samples, num_chains, disable_prog pbar_pos = int(logger_id.split(":")[-1]) num_samples[pbar_pos] += 1 if num_samples[pbar_pos] == warmup_steps: - progress_bars.set_description("Sample [{}]".format(pbar_pos + 1), pos=pbar_pos) + progress_bars.set_description( + "Sample [{}]".format(pbar_pos + 1), pos=pbar_pos + ) diagnostics = json.loads(msg, object_pairs_hook=OrderedDict) progress_bars.set_postfix(diagnostics, pos=pbar_pos, refresh=False) progress_bars.update(pos=pbar_pos) @@ -80,8 +86,18 @@ def logger_thread(log_queue, warmup_steps, num_samples, num_chains, disable_prog class _Worker: - def __init__(self, chain_id, result_queue, log_queue, event, kernel, num_samples, - warmup_steps, initial_params=None, hook=None): + def __init__( + self, + chain_id, + result_queue, + log_queue, + event, + kernel, + num_samples, + warmup_steps, + initial_params=None, + hook=None, + ): self.chain_id = chain_id self.kernel = kernel if initial_params is not None: @@ -106,8 +122,15 @@ def run(self, *args, **kwargs): logging_hook = _add_logging_hook(logger, None, self.hook) try: - for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, logging_hook, - None, *args, **kwargs): + for sample in _gen_samples( + self.kernel, + self.warmup_steps, + self.num_samples, + logging_hook, + None, + *args, + **kwargs + ): self.result_queue.put_nowait((self.chain_id, sample)) self.event.wait() self.event.clear() @@ -125,10 +148,20 @@ def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, 
*args, **kwa yield {name: params[name].shape for name in save_params} for i in range(warmup_steps): params = kernel.sample(params) - hook(kernel, params, 'Warmup [{}]'.format(chain_id) if chain_id is not None else 'Warmup', i) + hook( + kernel, + params, + "Warmup [{}]".format(chain_id) if chain_id is not None else "Warmup", + i, + ) for i in range(num_samples): params = kernel.sample(params) - hook(kernel, params, 'Sample [{}]'.format(chain_id) if chain_id is not None else 'Sample', i) + hook( + kernel, + params, + "Sample [{}]".format(chain_id) if chain_id is not None else "Sample", + i, + ) flat = [params[name].reshape(-1) for name in save_params] yield (torch.cat if flat else torch.tensor)(flat) yield kernel.diagnostics() @@ -152,7 +185,16 @@ class _UnarySampler: Single process runner class optimized for the case chains are drawn sequentially. """ - def __init__(self, kernel, num_samples, warmup_steps, num_chains, disable_progbar, initial_params=None, hook=None): + def __init__( + self, + kernel, + num_samples, + warmup_steps, + num_chains, + disable_progbar, + initial_params=None, + hook=None, + ): self.kernel = kernel self.initial_params = initial_params self.warmup_steps = warmup_steps @@ -173,12 +215,20 @@ def run(self, *args, **kwargs): initial_params = {k: v[i] for k, v in self.initial_params.items()} self.kernel.initial_params = initial_params - progress_bar = ProgressBar(self.warmup_steps, self.num_samples, disable=self.disable_progbar) + progress_bar = ProgressBar( + self.warmup_steps, self.num_samples, disable=self.disable_progbar + ) logger = initialize_logger(logger, "", progress_bar) hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook) - for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging, - i if self.num_chains > 1 else None, - *args, **kwargs): + for sample in _gen_samples( + self.kernel, + self.warmup_steps, + self.num_samples, + hook_w_logging, + i if self.num_chains > 1 else None, + *args, + **kwargs + ): yield sample, i # sample, chain_id self.kernel.cleanup() progress_bar.close() @@ -190,8 +240,18 @@ class _MultiSampler: `torch.multiprocessing` module (itself a light wrapper over the python `multiprocessing` module) to spin up parallel workers. 
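The _MultiSampler above backs multi-chain runs; from the user's side this is driven entirely through the MCMC constructor. A sketch, with model and data assumed as before:

    from pyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(model),
        num_samples=1000,
        warmup_steps=1000,
        num_chains=4,         # > 1 selects the parallel sampler when enough CPUs are free
        mp_context="spawn",   # needed when CUDA tensors are involved
    )
    mcmc.run(data)
    samples = mcmc.get_samples(group_by_chain=True)
    print(mcmc.diagnostics())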
""" - def __init__(self, kernel, num_samples, warmup_steps, num_chains, mp_context, - disable_progbar, initial_params=None, hook=None): + + def __init__( + self, + kernel, + num_samples, + warmup_steps, + num_chains, + mp_context, + disable_progbar, + initial_params=None, + hook=None, + ): self.kernel = kernel self.warmup_steps = warmup_steps self.num_chains = num_chains @@ -202,13 +262,21 @@ def __init__(self, kernel, num_samples, warmup_steps, num_chains, mp_context, self.ctx = mp.get_context(mp_context) self.result_queue = self.ctx.Queue() self.log_queue = self.ctx.Queue() - self.logger = initialize_logger(logging.getLogger("pyro.infer.mcmc"), - "MAIN", log_queue=self.log_queue) + self.logger = initialize_logger( + logging.getLogger("pyro.infer.mcmc"), "MAIN", log_queue=self.log_queue + ) self.num_samples = num_samples self.initial_params = initial_params - self.log_thread = threading.Thread(target=logger_thread, - args=(self.log_queue, self.warmup_steps, self.num_samples, - self.num_chains, disable_progbar)) + self.log_thread = threading.Thread( + target=logger_thread, + args=( + self.log_queue, + self.warmup_steps, + self.num_samples, + self.num_chains, + disable_progbar, + ), + ) self.log_thread.daemon = True self.log_thread.start() self.events = [self.ctx.Event() for _ in range(num_chains)] @@ -216,12 +284,28 @@ def __init__(self, kernel, num_samples, warmup_steps, num_chains, mp_context, def init_workers(self, *args, **kwargs): self.workers = [] for i in range(self.num_chains): - init_params = {k: v[i] for k, v in self.initial_params.items()} if self.initial_params is not None else None - worker = _Worker(i, self.result_queue, self.log_queue, self.events[i], self.kernel, - self.num_samples, self.warmup_steps, initial_params=init_params, hook=self.hook) + init_params = ( + {k: v[i] for k, v in self.initial_params.items()} + if self.initial_params is not None + else None + ) + worker = _Worker( + i, + self.result_queue, + self.log_queue, + self.events[i], + self.kernel, + self.num_samples, + self.warmup_steps, + initial_params=init_params, + hook=self.hook, + ) worker.daemon = True - self.workers.append(self.ctx.Process(name=str(i), target=worker.run, - args=args, kwargs=kwargs)) + self.workers.append( + self.ctx.Process( + name=str(i), target=worker.run, args=args, kwargs=kwargs + ) + ) def terminate(self, terminate_workers=False): if self.log_thread.is_alive(): @@ -269,6 +353,7 @@ class AbstractMCMC(ABC): """ Base class for MCMC methods. """ + def __init__(self, kernel, num_chains, transforms): self.kernel = kernel self.num_chains = num_chains @@ -296,16 +381,23 @@ def _set_transforms(self, *args, **kwargs): self.transforms = {} def _validate_kernel(self, initial_params): - if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None: + if ( + isinstance(self.kernel, (HMC, NUTS)) + and self.kernel.potential_fn is not None + ): if initial_params is None: - raise ValueError("Must provide valid initial parameters to begin sampling" - " when using `potential_fn` in HMC/NUTS kernel.") + raise ValueError( + "Must provide valid initial parameters to begin sampling" + " when using `potential_fn` in HMC/NUTS kernel." + ) def _validate_initial_params(self, initial_params): for v in initial_params.values(): if v.shape[0] != self.num_chains: - raise ValueError("The leading dimension of tensors in `initial_params` " - "must match the number of chains.") + raise ValueError( + "The leading dimension of tensors in `initial_params` " + "must match the number of chains." 
+ ) class MCMC(AbstractMCMC): @@ -355,11 +447,25 @@ class MCMC(AbstractMCMC): save during sampling and diagnostics. This is useful in models with large nuisance variables. Defaults to None, saving all params. """ - def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None, - num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False, - disable_validation=True, transforms=None, save_params=None): + + def __init__( + self, + kernel, + num_samples, + warmup_steps=None, + initial_params=None, + num_chains=1, + hook_fn=None, + mp_context=None, + disable_progbar=False, + disable_validation=True, + transforms=None, + save_params=None, + ): super().__init__(kernel, num_chains, transforms) - self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan + self.warmup_steps = ( + num_samples if warmup_steps is None else warmup_steps + ) # Stan self.num_samples = num_samples self.disable_validation = disable_validation self._samples = None @@ -382,23 +488,43 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None, mp_context = "spawn" # verify num_chains is compatible with available CPU. - available_cpu = max(mp.cpu_count() - 1, 1) # reserving 1 for the main process. + available_cpu = max( + mp.cpu_count() - 1, 1 + ) # reserving 1 for the main process. if num_chains <= available_cpu: parallel = True else: - warnings.warn("num_chains={} is more than available_cpu={}. " - "Chains will be drawn sequentially." - .format(num_chains, available_cpu)) + warnings.warn( + "num_chains={} is more than available_cpu={}. " + "Chains will be drawn sequentially.".format( + num_chains, available_cpu + ) + ) else: if initial_params: initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()} self._diagnostics = [None] * num_chains if parallel: - self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context, - disable_progbar, initial_params=initial_params, hook=hook_fn) + self.sampler = _MultiSampler( + kernel, + num_samples, + self.warmup_steps, + num_chains, + mp_context, + disable_progbar, + initial_params=initial_params, + hook=hook_fn, + ) else: - self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar, - initial_params=initial_params, hook=hook_fn) + self.sampler = _UnarySampler( + kernel, + num_samples, + self.warmup_steps, + num_chains, + disable_progbar, + initial_params=initial_params, + hook=hook_fn, + ) @poutine.block def run(self, *args, **kwargs): @@ -425,8 +551,10 @@ def model(data): self._args, self._kwargs = args, kwargs num_samples = [0] * self.num_chains z_flat_acc = [[] for _ in range(self.num_chains)] - with optional(pyro.validation_enabled(not self.disable_validation), - self.disable_validation is not None): + with optional( + pyro.validation_enabled(not self.disable_validation), + self.disable_validation is not None, + ): # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer" # at https://github.com/pytorch/pytorch/issues/10375 # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which @@ -456,7 +584,8 @@ def model(data): shape = z_structure[k] next_pos = pos + shape.numel() z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape( - (self.num_chains, self.num_samples) + shape) + (self.num_chains, self.num_samples) + shape + ) pos = next_pos assert pos == z_flat_acc.shape[-1] @@ -490,8 +619,10 @@ def diagnostics(self): """ diag = diagnostics(self._samples) for diag_name in self._diagnostics[0]: - 
diag[diag_name] = {'chain {}'.format(i): self._diagnostics[i][diag_name] - for i in range(self.num_chains)} + diag[diag_name] = { + "chain {}".format(i): self._diagnostics[i][diag_name] + for i in range(self.num_chains) + } return diag def summary(self, prob=0.9): @@ -504,9 +635,17 @@ def summary(self, prob=0.9): :param float prob: the probability mass of samples within the credibility interval. """ print_summary(self._samples, prob=prob) - if 'divergences' in self._diagnostics[0]: - print("Number of divergences: {}".format( - sum([len(self._diagnostics[i]['divergences']) for i in range(self.num_chains)]))) + if "divergences" in self._diagnostics[0]: + print( + "Number of divergences: {}".format( + sum( + [ + len(self._diagnostics[i]["divergences"]) + for i in range(self.num_chains) + ] + ) + ) + ) class StreamingMCMC(AbstractMCMC): @@ -517,11 +656,25 @@ class StreamingMCMC(AbstractMCMC): For available streaming ops please see :mod:`~pyro.ops.streaming`. """ - def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None, - statistics=None, num_chains=1, hook_fn=None, disable_progbar=False, - disable_validation=True, transforms=None, save_params=None): + + def __init__( + self, + kernel, + num_samples, + warmup_steps=None, + initial_params=None, + statistics=None, + num_chains=1, + hook_fn=None, + disable_progbar=False, + disable_validation=True, + transforms=None, + save_params=None, + ): super().__init__(kernel, num_chains, transforms) - self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan + self.warmup_steps = ( + num_samples if warmup_steps is None else warmup_steps + ) # Stan self.num_samples = num_samples self.disable_validation = disable_validation self._samples = None @@ -541,8 +694,15 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None, if initial_params: initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()} self._diagnostics = [None] * num_chains - self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar, - initial_params=initial_params, hook=hook_fn) + self.sampler = _UnarySampler( + kernel, + num_samples, + self.warmup_steps, + num_chains, + disable_progbar, + initial_params=initial_params, + hook=hook_fn, + ) @poutine.block def run(self, *args, **kwargs): @@ -552,8 +712,10 @@ def run(self, *args, **kwargs): self._args, self._kwargs = args, kwargs num_samples = [0] * self.num_chains - with optional(pyro.validation_enabled(not self.disable_validation), - self.disable_validation is not None): + with optional( + pyro.validation_enabled(not self.disable_validation), + self.disable_validation is not None, + ): args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args] for x, chain_id in self.sampler.run(*args, **kwargs): if num_samples[chain_id] == 0: @@ -586,9 +748,12 @@ def run(self, *args, **kwargs): if name in self.transforms: z_acc[name] = self.transforms[name].inv(z) - self._statistics.update({ - (chain_id, name): transformed_sample for name, transformed_sample in z_acc.items() - }) + self._statistics.update( + { + (chain_id, name): transformed_sample + for name, transformed_sample in z_acc.items() + } + ) # terminate the sampler (shut down worker processes) self.sampler.terminate(True) @@ -620,6 +785,8 @@ def diagnostics(self): statistics = self._statistics.get() diag = diagnostics_from_stats(statistics, self.num_samples, self.num_chains) for diag_name in self._diagnostics[0]: - diag[diag_name] = {'chain {}'.format(i): 
self._diagnostics[i][diag_name] - for i in range(self.num_chains)} + diag[diag_name] = { + "chain {}".format(i): self._diagnostics[i][diag_name] + for i in range(self.num_chains) + } return diag diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 139d7f4a0d..34106f7690 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -91,22 +91,24 @@ class HMC(MCMCKernel): tensor([ 0.9819, 1.9258, 2.9737]) """ - def __init__(self, - model=None, - potential_fn=None, - step_size=1, - trajectory_length=None, - num_steps=None, - adapt_step_size=True, - adapt_mass_matrix=True, - full_mass=False, - transforms=None, - max_plate_nesting=None, - jit_compile=False, - jit_options=None, - ignore_jit_warnings=False, - target_accept_prob=0.8, - init_strategy=init_to_uniform): + def __init__( + self, + model=None, + potential_fn=None, + step_size=1, + trajectory_length=None, + num_steps=None, + adapt_step_size=True, + adapt_mass_matrix=True, + full_mass=False, + transforms=None, + max_plate_nesting=None, + jit_compile=False, + jit_options=None, + ignore_jit_warnings=False, + target_accept_prob=0.8, + init_strategy=init_to_uniform, + ): if not ((model is None) ^ (potential_fn is None)): raise ValueError("Only one of `model` or `potential_fn` must be specified.") # NB: deprecating args - model, transforms @@ -131,15 +133,17 @@ def __init__(self, self._direction_threshold = math.log(0.8) # from Stan self._max_sliced_energy = 1000 self._reset() - self._adapter = WarmupAdapter(step_size, - adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, - target_accept_prob=target_accept_prob, - dense_mass=full_mass) + self._adapter = WarmupAdapter( + step_size, + adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, + target_accept_prob=target_accept_prob, + dense_mass=full_mass, + ) super().__init__() def _kinetic_energy(self, r_unscaled): - energy = 0. + energy = 0.0 for site_names, value in r_unscaled.items(): energy = energy + value.dot(value) return 0.5 * energy @@ -147,7 +151,7 @@ def _kinetic_energy(self, r_unscaled): def _reset(self): self._t = 0 self._accept_cnt = 0 - self._mean_accept_prob = 0. 
+ self._mean_accept_prob = 0.0 self._divergences = [] self._prototype_trace = None self._initial_params = None @@ -169,7 +173,8 @@ def _find_reasonable_step_size(self, z): # contains transforms with cache_size > 0 (https://github.com/pyro-ppl/pyro/issues/2292) z = {k: v.clone() for k, v in z.items()} z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet( - z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size) + z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size + ) r_new_unscaled = self.mass_matrix_adapter.unscale(r_new) energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new delta_energy = energy_new - energy_current @@ -191,7 +196,12 @@ def _find_reasonable_step_size(self, z): r, r_unscaled = self._sample_r(name="r_presample_{}".format(t)) energy_current = self._kinetic_energy(r_unscaled) + potential_energy z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet( - z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size) + z, + r, + self.potential_fn, + self.mass_matrix_adapter.kinetic_grad, + step_size, + ) r_new_unscaled = self.mass_matrix_adapter.unscale(r_new) energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new delta_energy = energy_new - energy_current @@ -200,14 +210,19 @@ def _find_reasonable_step_size(self, z): def _sample_r(self, name): r_unscaled = {} - options = {"dtype": self._potential_energy_last.dtype, - "device": self._potential_energy_last.device} + options = { + "dtype": self._potential_energy_last.dtype, + "device": self._potential_energy_last.device, + } for site_names, size in self.mass_matrix_adapter.mass_matrix_size.items(): # we want to sample from Normal distribution using `sample` method rather than # `rsample` method because the former is a bit faster r_unscaled[site_names] = pyro.sample( "{}_{}".format(name, site_names), - NonreparameterizedNormal(torch.zeros(size, **options), torch.ones(size, **options))) + NonreparameterizedNormal( + torch.zeros(size, **options), torch.ones(size, **options) + ), + ) r = self.mass_matrix_adapter.scale(r_unscaled, r_prototype=self.initial_params) return r, r_unscaled @@ -272,10 +287,14 @@ def _initialize_adapter(self): for name in dense_sites: assert isinstance(name, str) and name in self.initial_params, msg dense_sites_set = set().union(*dense_sites_list) - diag_sites = tuple(sorted([name for name in self.initial_params - if name not in dense_sites_set])) - assert len(diag_sites) + sum([len(sites) for sites in dense_sites_list]) == len(self.initial_params), \ - "Site names specified in full_mass are duplicated." + diag_sites = tuple( + sorted( + [name for name in self.initial_params if name not in dense_sites_set] + ) + ) + assert len(diag_sites) + sum([len(sites) for sites in dense_sites_list]) == len( + self.initial_params + ), "Site names specified in full_mass are duplicated." 
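The kinetic energy computed above is simply half the squared norm of the unscaled momentum, summed over the mass-matrix blocks, because `r_unscaled` has already been mapped to a standard-normal scale. A standalone sketch of the same computation (the site names and shapes are made up):

import torch

def kinetic_energy(r_unscaled):
    # r_unscaled maps tuples of site names to flat momentum vectors drawn from
    # a standard normal, so the energy is 0.5 * sum of r.dot(r) over blocks.
    energy = 0.0
    for value in r_unscaled.values():
        energy = energy + value.dot(value)
    return 0.5 * energy

r = {("loc",): torch.randn(3), ("scale",): torch.randn(2)}
print(kinetic_energy(r))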
mass_matrix_shape = OrderedDict() for dense_sites in dense_sites_list: @@ -286,12 +305,16 @@ def _initialize_adapter(self): size = sum([self.initial_params[site].numel() for site in diag_sites]) mass_matrix_shape[diag_sites] = (size,) - options = {"dtype": self._potential_energy_last.dtype, - "device": self._potential_energy_last.device} - self._adapter.configure(self._warmup_steps, - mass_matrix_shape=mass_matrix_shape, - find_reasonable_step_size_fn=self._find_reasonable_step_size, - options=options) + options = { + "dtype": self._potential_energy_last.dtype, + "device": self._potential_energy_last.device, + } + self._adapter.configure( + self._warmup_steps, + mass_matrix_shape=mass_matrix_shape, + find_reasonable_step_size_fn=self._find_reasonable_step_size, + options=options, + ) if self._adapter.adapt_step_size: self._adapter.reset_step_size_adaptation(self._initial_params) @@ -335,7 +358,7 @@ def sample(self, params): # return early if no sample sites elif len(z) == 0: self._t += 1 - self._mean_accept_prob = 1. + self._mean_accept_prob = 1.0 if self._t > self._warmup_steps: self._accept_cnt += 1 return params @@ -346,21 +369,35 @@ def sample(self, params): # NaNs are expected during step size adaptation with optional(pyro.validation_enabled(False), self._t < self._warmup_steps): z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet( - z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, - self.step_size, self.num_steps, z_grads=z_grads) + z, + r, + self.potential_fn, + self.mass_matrix_adapter.kinetic_grad, + self.step_size, + self.num_steps, + z_grads=z_grads, + ) # apply Metropolis correction. r_new_unscaled = self.mass_matrix_adapter.unscale(r_new) - energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new + energy_proposal = ( + self._kinetic_energy(r_new_unscaled) + potential_energy_new + ) delta_energy = energy_proposal - energy_current # handle the NaN case which may be the case for a diverging trajectory # when using a large step size. - delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy + delta_energy = ( + scalar_like(delta_energy, float("inf")) + if torch_isnan(delta_energy) + else delta_energy + ) if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps: self._divergences.append(self._t - self._warmup_steps) - accept_prob = (-delta_energy).exp().clamp(max=1.) - rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.), - scalar_like(accept_prob, 1.))) + accept_prob = (-delta_energy).exp().clamp(max=1.0) + rand = pyro.sample( + "rand_t={}".format(self._t), + dist.Uniform(scalar_like(accept_prob, 0.0), scalar_like(accept_prob, 1.0)), + ) accepted = False if rand < accept_prob: accepted = True @@ -381,11 +418,15 @@ def sample(self, params): return z.copy() def logging(self): - return OrderedDict([ - ("step size", "{:.2e}".format(self.step_size)), - ("acc. prob", "{:.3f}".format(self._mean_accept_prob)) - ]) + return OrderedDict( + [ + ("step size", "{:.2e}".format(self.step_size)), + ("acc. 
prob", "{:.3f}".format(self._mean_accept_prob)), + ] + ) def diagnostics(self): - return {"divergences": self._divergences, - "acceptance rate": self._accept_cnt / (self._t - self._warmup_steps)} + return { + "divergences": self._divergences, + "acceptance rate": self._accept_cnt / (self._t - self._warmup_steps), + } diff --git a/pyro/infer/mcmc/logger.py b/pyro/infer/mcmc/logger.py index 1050ba787f..2082240439 100644 --- a/pyro/infer/mcmc/logger.py +++ b/pyro/infer/mcmc/logger.py @@ -55,19 +55,35 @@ class ProgressBar: If multiple bars are initialized, they need to be separately updated via the ``pos`` kwarg. """ - def __init__(self, warmup_steps, num_samples, min_width=80, max_width=120, - disable=False, num_bars=1): + + def __init__( + self, + warmup_steps, + num_samples, + min_width=80, + max_width=120, + disable=False, + num_bars=1, + ): total_steps = warmup_steps + num_samples # Disable progress bar in "CI" # (see https://github.com/travis-ci/travis-ci/issues/1337). disable = disable or "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ - bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]" + bar_format = ( + "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]" + ) pbar_cls = tqdm_nb if num_bars > 1 and ipython_env else tqdm self.progress_bars = [] for i in range(num_bars): description = "Warmup" if num_bars == 1 else "Warmup [{}]".format(i + 1) - pbar = pbar_cls(total=total_steps, desc=description, bar_format=bar_format, - position=i, file=sys.stderr, disable=disable) + pbar = pbar_cls( + total=total_steps, + desc=description, + bar_format=bar_format, + position=i, + file=sys.stderr, + disable=disable, + ) # Assume reasonable values when terminal width not available if getattr(pbar, "ncols", None) is not None: pbar.ncols = max(min_width, pbar.ncols) @@ -169,6 +185,7 @@ class TqdmHandler(logging.StreamHandler): Handler that synchronizes the log output with the :class:`~tqdm.tqdm` progress bar. """ + def emit(self, record): try: msg = self.format(record) @@ -191,6 +208,7 @@ class MCMCLoggingHandler(logging.Handler): :param progress_bar: If provided, diagnostic information is updated using the bar. """ + def __init__(self, log_handler, progress_bar=None): logging.Handler.__init__(self) self.log_handler = log_handler @@ -199,8 +217,9 @@ def __init__(self, log_handler, progress_bar=None): def emit(self, record): try: if self.progress_bar and record.msg_type == DIAGNOSTIC_MSG: - diagnostics = json.loads(record.getMessage(), - object_pairs_hook=OrderedDict) + diagnostics = json.loads( + record.getMessage(), object_pairs_hook=OrderedDict + ) self.progress_bar.set_postfix(diagnostics, refresh=False) self.progress_bar.update() else: @@ -216,6 +235,7 @@ class MetadataFilter(logging.Filter): Adds auxiliary information to log records, like `logger_id` and `msg_type`. 
""" + def __init__(self, logger_id): self.logger_id = logger_id super().__init__() @@ -247,8 +267,7 @@ def initialize_logger(logger, logger_id, progress_bar=None, log_queue=None): format = "%(levelname).1s \t %(message)s" handler = TqdmHandler() else: - raise ValueError("Logger cannot be initialized without a " - "valid handler.") + raise ValueError("Logger cannot be initialized without a " "valid handler.") handler.setFormatter(logging.Formatter(format)) logging_handler = MCMCLoggingHandler(handler, progress_bar) logging_handler.addFilter(MetadataFilter(logger_id)) diff --git a/pyro/infer/mcmc/mcmc_kernel.py b/pyro/infer/mcmc/mcmc_kernel.py index 750db9d88e..bdc5edeb7e 100644 --- a/pyro/infer/mcmc/mcmc_kernel.py +++ b/pyro/infer/mcmc/mcmc_kernel.py @@ -5,7 +5,6 @@ class MCMCKernel(object, metaclass=ABCMeta): - def setup(self, warmup_steps, *args, **kwargs): r""" Optional method to set up any state required at the start of the diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index 5016a2517b..0b1071b745 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -28,11 +28,28 @@ def _logaddexp(x, y): # weight is the number of valid points in case we use slice sampling # and is the log sum of (unnormalized) probabilites of valid points # when we use multinomial sampling -_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "r_left_unscaled", "z_left_grads", - "z_right", "r_right", "r_right_unscaled", "z_right_grads", - "z_proposal", "z_proposal_pe", "z_proposal_grads", - "r_sum", "weight", "turning", "diverging", - "sum_accept_probs", "num_proposals"]) +_TreeInfo = namedtuple( + "TreeInfo", + [ + "z_left", + "r_left", + "r_left_unscaled", + "z_left_grads", + "z_right", + "r_right", + "r_right_unscaled", + "z_right_grads", + "z_proposal", + "z_proposal_pe", + "z_proposal_grads", + "r_sum", + "weight", + "turning", + "diverging", + "sum_accept_probs", + "num_proposals", + ], +) class NUTS(HMC): @@ -117,35 +134,39 @@ class NUTS(HMC): tensor([ 0.9221, 1.9464, 2.9228]) """ - def __init__(self, - model=None, - potential_fn=None, - step_size=1, - adapt_step_size=True, - adapt_mass_matrix=True, - full_mass=False, - use_multinomial_sampling=True, - transforms=None, - max_plate_nesting=None, - jit_compile=False, - jit_options=None, - ignore_jit_warnings=False, - target_accept_prob=0.8, - max_tree_depth=10, - init_strategy=init_to_uniform): - super().__init__(model, - potential_fn, - step_size, - adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, - full_mass=full_mass, - transforms=transforms, - max_plate_nesting=max_plate_nesting, - jit_compile=jit_compile, - jit_options=jit_options, - ignore_jit_warnings=ignore_jit_warnings, - target_accept_prob=target_accept_prob, - init_strategy=init_strategy) + def __init__( + self, + model=None, + potential_fn=None, + step_size=1, + adapt_step_size=True, + adapt_mass_matrix=True, + full_mass=False, + use_multinomial_sampling=True, + transforms=None, + max_plate_nesting=None, + jit_compile=False, + jit_options=None, + ignore_jit_warnings=False, + target_accept_prob=0.8, + max_tree_depth=10, + init_strategy=init_to_uniform, + ): + super().__init__( + model, + potential_fn, + step_size, + adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, + full_mass=full_mass, + transforms=transforms, + max_plate_nesting=max_plate_nesting, + jit_compile=jit_compile, + jit_options=jit_options, + ignore_jit_warnings=ignore_jit_warnings, + target_accept_prob=target_accept_prob, + init_strategy=init_strategy, + ) 
self.use_multinomial_sampling = use_multinomial_sampling self._max_tree_depth = max_tree_depth # There are three conditions to stop doubling process: @@ -162,10 +183,12 @@ def __init__(self, def _is_turning(self, r_left_unscaled, r_right_unscaled, r_sum): # We follow the strategy in Section A.4.2 of [2] for this implementation. - left_angle = 0. - right_angle = 0. + left_angle = 0.0 + right_angle = 0.0 for site_names, value in r_sum.items(): - rho = value - (r_left_unscaled[site_names] + r_right_unscaled[site_names]) / 2 + rho = ( + value - (r_left_unscaled[site_names] + r_right_unscaled[site_names]) / 2 + ) left_angle += r_left_unscaled[site_names].dot(rho) right_angle += r_right_unscaled[site_names].dot(rho) @@ -174,13 +197,23 @@ def _is_turning(self, r_left_unscaled, r_right_unscaled, r_sum): def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): step_size = self.step_size if direction == 1 else -self.step_size z_new, r_new, z_grads, potential_energy = velocity_verlet( - z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size, z_grads=z_grads) + z, + r, + self.potential_fn, + self.mass_matrix_adapter.kinetic_grad, + step_size, + z_grads=z_grads, + ) r_new_unscaled = self.mass_matrix_adapter.unscale(r_new) energy_new = potential_energy + self._kinetic_energy(r_new_unscaled) # handle the NaN case - energy_new = scalar_like(energy_new, float("inf")) if torch_isnan(energy_new) else energy_new + energy_new = ( + scalar_like(energy_new, float("inf")) + if torch_isnan(energy_new) + else energy_new + ) sliced_energy = energy_new + log_slice - diverging = (sliced_energy > self._max_sliced_energy) + diverging = sliced_energy > self._max_sliced_energy delta_energy = energy_new - energy_current accept_prob = (-delta_energy).exp().clamp(max=1.0) @@ -191,19 +224,41 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): # we eliminate states which p(z, r) < u, or dE > 0. # Due to this elimination (and stop doubling conditions), # the weight of binary tree might not equal to 2^tree_depth. - tree_weight = scalar_like(sliced_energy, 1. if sliced_energy <= 0 else 0.) 
+ tree_weight = scalar_like(sliced_energy, 1.0 if sliced_energy <= 0 else 0.0) r_sum = r_new_unscaled - return _TreeInfo(z_new, r_new, r_new_unscaled, z_grads, z_new, r_new, r_new_unscaled, z_grads, - z_new, potential_energy, z_grads, r_sum, tree_weight, False, diverging, accept_prob, 1) - - def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current): + return _TreeInfo( + z_new, + r_new, + r_new_unscaled, + z_grads, + z_new, + r_new, + r_new_unscaled, + z_grads, + z_new, + potential_energy, + z_grads, + r_sum, + tree_weight, + False, + diverging, + accept_prob, + 1, + ) + + def _build_tree( + self, z, r, z_grads, log_slice, direction, tree_depth, energy_current + ): if tree_depth == 0: - return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current) + return self._build_basetree( + z, r, z_grads, log_slice, direction, energy_current + ) # build the first half of tree - half_tree = self._build_tree(z, r, z_grads, log_slice, - direction, tree_depth-1, energy_current) + half_tree = self._build_tree( + z, r, z_grads, log_slice, direction, tree_depth - 1, energy_current + ) z_proposal = half_tree.z_proposal z_proposal_pe = half_tree.z_proposal_pe z_proposal_grads = half_tree.z_proposal_grads @@ -223,8 +278,9 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu z = half_tree.z_left r = half_tree.r_left z_grads = half_tree.z_left_grads - other_half_tree = self._build_tree(z, r, z_grads, log_slice, - direction, tree_depth-1, energy_current) + other_half_tree = self._build_tree( + z, r, z_grads, log_slice, direction, tree_depth - 1, energy_current + ) if self.use_multinomial_sampling: tree_weight = _logaddexp(half_tree.weight, other_half_tree.weight) @@ -232,8 +288,10 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu tree_weight = half_tree.weight + other_half_tree.weight sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs num_proposals = half_tree.num_proposals + other_half_tree.num_proposals - r_sum = {site_names: half_tree.r_sum[site_names] + other_half_tree.r_sum[site_names] - for site_names in self.inverse_mass_matrix} + r_sum = { + site_names: half_tree.r_sum[site_names] + other_half_tree.r_sum[site_names] + for site_names in self.inverse_mass_matrix + } # The probability of that proposal belongs to which half of tree # is computed based on the weights of each half. @@ -243,10 +301,14 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu # For the special case that the weights of each half are both 0, # we choose the proposal from the first half # (any is fine, because the probability of picking it at the end is 0!). - other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0 - else scalar_like(tree_weight, 0.)) - is_other_half_tree = pyro.sample("is_other_half_tree", - dist.Bernoulli(probs=other_half_tree_prob)) + other_half_tree_prob = ( + other_half_tree.weight / tree_weight + if tree_weight > 0 + else scalar_like(tree_weight, 0.0) + ) + is_other_half_tree = pyro.sample( + "is_other_half_tree", dist.Bernoulli(probs=other_half_tree_prob) + ) if is_other_half_tree == 1: z_proposal = other_half_tree.z_proposal @@ -275,14 +337,32 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu # We already check if first half tree is turning. Now, we check # if the other half tree or full tree are turning. 
- turning = other_half_tree.turning or self._is_turning(r_left_unscaled, r_right_unscaled, r_sum) + turning = other_half_tree.turning or self._is_turning( + r_left_unscaled, r_right_unscaled, r_sum + ) # The divergence is checked by the second half tree (the first half is already checked). diverging = other_half_tree.diverging - return _TreeInfo(z_left, r_left, r_left_unscaled, z_left_grads, z_right, r_right, r_right_unscaled, - z_right_grads, z_proposal, z_proposal_pe, z_proposal_grads, r_sum, tree_weight, - turning, diverging, sum_accept_probs, num_proposals) + return _TreeInfo( + z_left, + r_left, + r_left_unscaled, + z_left_grads, + z_right, + r_right, + r_right_unscaled, + z_right_grads, + z_proposal, + z_proposal_pe, + z_proposal_grads, + r_sum, + tree_weight, + turning, + diverging, + sum_accept_probs, + num_proposals, + ) def sample(self, params): z, potential_energy, z_grads = self._fetch_from_cache() @@ -294,7 +374,7 @@ def sample(self, params): # return early if no sample sites elif len(z) == 0: self._t += 1 - self._mean_accept_prob = 1. + self._mean_accept_prob = 1.0 if self._t > self._warmup_steps: self._accept_cnt += 1 return z @@ -322,8 +402,10 @@ def sample(self, params): # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can # sample log_slice directly using `energy`, so as to avoid potential underflow or # overflow issues ([2]). - slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t), - dist.Exponential(scalar_like(energy_current, 1.))) + slice_exp_term = pyro.sample( + "slicevar_exp_t={}".format(self._t), + dist.Exponential(scalar_like(energy_current, 1.0)), + ) log_slice = -energy_current - slice_exp_term z_left = z_right = z @@ -332,9 +414,11 @@ def sample(self, params): z_left_grads = z_right_grads = z_grads accepted = False r_sum = r_unscaled - sum_accept_probs = 0. + sum_accept_probs = 0.0 num_proposals = 0 - tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.) + tree_weight = scalar_like( + energy_current, 0.0 if self.use_multinomial_sampling else 1.0 + ) # Temporarily disable distributions args checking as # NaNs are expected during step size adaptation. 
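With multinomial sampling the half-tree weights are kept in log space, so they are combined with a numerically stable log-add-exp rather than an ordinary sum. A minimal sketch of that operation (not necessarily the exact `_logaddexp` helper defined at the top of nuts.py):

import torch

def logaddexp(x, y):
    # Stable log(exp(x) + exp(y)): subtract the larger value before
    # exponentiating so neither term can overflow.
    m = torch.maximum(x, y)
    return m + torch.log(torch.exp(x - m) + torch.exp(y - m))

w = logaddexp(torch.tensor(-3.2), torch.tensor(-1.7))
print(w)  # matches torch.logaddexp(torch.tensor(-3.2), torch.tensor(-1.7))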
@@ -342,20 +426,38 @@ def sample(self, params): # doubling process, stop when turning or diverging tree_depth = 0 while tree_depth < self._max_tree_depth: - direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth), - dist.Bernoulli(probs=scalar_like(tree_weight, 0.5))) + direction = pyro.sample( + "direction_t={}_treedepth={}".format(self._t, tree_depth), + dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)), + ) direction = int(direction.item()) - if direction == 1: # go to the right, start from the right leaf of current tree - new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice, - direction, tree_depth, energy_current) + if ( + direction == 1 + ): # go to the right, start from the right leaf of current tree + new_tree = self._build_tree( + z_right, + r_right, + z_right_grads, + log_slice, + direction, + tree_depth, + energy_current, + ) # update leaf for the next doubling process z_right = new_tree.z_right r_right = new_tree.r_right r_right_unscaled = new_tree.r_right_unscaled z_right_grads = new_tree.z_right_grads else: # go the the left, start from the left leaf of current tree - new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice, - direction, tree_depth, energy_current) + new_tree = self._build_tree( + z_left, + r_left, + z_left_grads, + log_slice, + direction, + tree_depth, + energy_current, + ) z_left = new_tree.z_left r_left = new_tree.r_left r_left_unscaled = new_tree.r_left_unscaled @@ -379,18 +481,25 @@ def sample(self, params): new_tree_prob = (new_tree.weight - tree_weight).exp() else: new_tree_prob = new_tree.weight / tree_weight - rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth), - dist.Uniform(scalar_like(new_tree_prob, 0.), - scalar_like(new_tree_prob, 1.))) + rand = pyro.sample( + "rand_t={}_treedepth={}".format(self._t, tree_depth), + dist.Uniform( + scalar_like(new_tree_prob, 0.0), scalar_like(new_tree_prob, 1.0) + ), + ) if rand < new_tree_prob: accepted = True z = new_tree.z_proposal z_grads = new_tree.z_proposal_grads self._cache(z, new_tree.z_proposal_pe, z_grads) - r_sum = {site_names: r_sum[site_names] + new_tree.r_sum[site_names] - for site_names in r_unscaled} - if self._is_turning(r_left_unscaled, r_right_unscaled, r_sum): # stop doubling + r_sum = { + site_names: r_sum[site_names] + new_tree.r_sum[site_names] + for site_names in r_unscaled + } + if self._is_turning( + r_left_unscaled, r_right_unscaled, r_sum + ): # stop doubling break else: # update tree_weight if self.use_multinomial_sampling: diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py index ebe8350131..da1566fd7a 100644 --- a/pyro/infer/mcmc/util.py +++ b/pyro/infer/mcmc/util.py @@ -39,10 +39,8 @@ class TraceTreeEvaluator: :param int max_plate_nesting: Optional bound on max number of nested :func:`pyro.plate` contexts. """ - def __init__(self, - model_trace, - has_enumerable_sites=False, - max_plate_nesting=None): + + def __init__(self, model_trace, has_enumerable_sites=False, max_plate_nesting=None): self.has_enumerable_sites = has_enumerable_sites self.max_plate_nesting = max_plate_nesting # To be populated using the model trace once. 
@@ -57,13 +55,15 @@ def _parse_model_structure(self, model_trace): if not self.has_enumerable_sites: return if self.max_plate_nesting is None: - raise ValueError("Finite value required for `max_plate_nesting` when model " - "has discrete (enumerable) sites.") + raise ValueError( + "Finite value required for `max_plate_nesting` when model " + "has discrete (enumerable) sites." + ) self._compute_log_prob_terms(model_trace) # 1. Infer model structure - compute parent-child relationship. sorted_ordinals = sorted(self._log_probs.keys()) for i, child_node in enumerate(sorted_ordinals): - for j in range(i-1, -1, -1): + for j in range(i - 1, -1, -1): cur_node = sorted_ordinals[j] if cur_node < child_node: self._children[cur_node].append(child_node) @@ -79,8 +79,13 @@ def _populate_cache(self, ordinal, parent_ordinal, parent_enum_dims): """ log_prob_shape = self._log_prob_shapes[ordinal] plate_dims = sorted([frame.dim for frame in ordinal - parent_ordinal]) - enum_dims = set((i for i in range(-len(log_prob_shape), -self.max_plate_nesting) - if log_prob_shape[i] > 1)) + enum_dims = set( + ( + i + for i in range(-len(log_prob_shape), -self.max_plate_nesting) + if log_prob_shape[i] > 1 + ) + ) self._plate_dims[ordinal] = plate_dims self._enum_dims[ordinal] = set(enum_dims - parent_enum_dims) for c in self._children[ordinal]: @@ -93,9 +98,11 @@ def _compute_log_prob_terms(self, model_trace): """ model_trace.compute_log_prob() self._log_probs = defaultdict(list) - ordering = {name: frozenset(site["cond_indep_stack"]) - for name, site in model_trace.nodes.items() - if site["type"] == "sample"} + ordering = { + name: frozenset(site["cond_indep_stack"]) + for name, site in model_trace.nodes.items() + if site["type"] == "sample" + } # Collect log prob terms per independence context. for name, site in model_trace.nodes.items(): if site["type"] == "sample": @@ -104,9 +111,11 @@ def _compute_log_prob_terms(self, model_trace): self._log_probs[ordering[name]].append(site["log_prob"]) if not self._log_prob_shapes: for ordinal, log_prob in self._log_probs.items(): - self._log_prob_shapes[ordinal] = broadcast_shape(*(t.shape for t in self._log_probs[ordinal])) + self._log_prob_shapes[ordinal] = broadcast_shape( + *(t.shape for t in self._log_probs[ordinal]) + ) - def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.)): + def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.0)): """ Reduce the log prob terms for the given ordinal: - taking log_sum_exp of factors in enum dims (i.e. @@ -164,10 +173,8 @@ class TraceEinsumEvaluator: :param int max_plate_nesting: Optional bound on max number of nested :func:`pyro.plate` contexts. """ - def __init__(self, - model_trace, - has_enumerable_sites=False, - max_plate_nesting=None): + + def __init__(self, model_trace, has_enumerable_sites=False, max_plate_nesting=None): self.has_enumerable_sites = has_enumerable_sites self.max_plate_nesting = max_plate_nesting # To be populated using the model trace once. @@ -183,18 +190,24 @@ def _populate_cache(self, model_trace): if not self.has_enumerable_sites: return if self.max_plate_nesting is None: - raise ValueError("Finite value required for `max_plate_nesting` when model " - "has discrete (enumerable) sites.") + raise ValueError( + "Finite value required for `max_plate_nesting` when model " + "has discrete (enumerable) sites." 
+ ) model_trace.compute_log_prob() model_trace.pack_tensors() for name, site in model_trace.nodes.items(): if site["type"] == "sample" and not isinstance(site["fn"], _Subsample): if is_validation_enabled(): check_site_shape(site, self.max_plate_nesting) - self.ordering[name] = frozenset(model_trace.plate_to_symbol[f.name] - for f in site["cond_indep_stack"] - if f.vectorized) - self._enum_dims = set(model_trace.symbol_to_dim) - set(model_trace.plate_to_symbol.values()) + self.ordering[name] = frozenset( + model_trace.plate_to_symbol[f.name] + for f in site["cond_indep_stack"] + if f.vectorized + ) + self._enum_dims = set(model_trace.symbol_to_dim) - set( + model_trace.plate_to_symbol.values() + ) def _get_log_factors(self, model_trace): """ @@ -209,7 +222,9 @@ def _get_log_factors(self, model_trace): if site["type"] == "sample" and not isinstance(site["fn"], _Subsample): if is_validation_enabled(): check_site_shape(site, self.max_plate_nesting) - log_probs.setdefault(self.ordering[name], []).append(site["packed"]["log_prob"]) + log_probs.setdefault(self.ordering[name], []).append( + site["packed"]["log_prob"] + ) return log_probs def log_prob(self, model_trace): @@ -234,19 +249,22 @@ def _guess_max_plate_nesting(model, args, kwargs): """ with poutine.block(): model_trace = poutine.trace(model).get_trace(*args, **kwargs) - sites = [site for site in model_trace.nodes.values() - if site["type"] == "sample"] - - dims = [frame.dim - for site in sites - for frame in site["cond_indep_stack"] - if frame.vectorized] + sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"] + + dims = [ + frame.dim + for site in sites + for frame in site["cond_indep_stack"] + if frame.vectorized + ] max_plate_nesting = -min(dims) if dims else 0 return max_plate_nesting class _PEMaker: - def __init__(self, model, model_args, model_kwargs, trace_prob_evaluator, transforms): + def __init__( + self, model, model_args, model_kwargs, trace_prob_evaluator, transforms + ): self.model = model self.model_args = model_args self.model_kwargs = model_kwargs @@ -257,12 +275,14 @@ def __init__(self, model, model_args, model_kwargs, trace_prob_evaluator, transf def _potential_fn(self, params): params_constrained = {k: self.transforms[k].inv(v) for k, v in params.items()} cond_model = poutine.condition(self.model, params_constrained) - model_trace = poutine.trace(cond_model).get_trace(*self.model_args, - **self.model_kwargs) + model_trace = poutine.trace(cond_model).get_trace( + *self.model_args, **self.model_kwargs + ) log_joint = self.trace_prob_evaluator.log_prob(model_trace) for name, t in self.transforms.items(): log_joint = log_joint - torch.sum( - t.log_abs_det_jacobian(params_constrained[name], params[name])) + t.log_abs_det_jacobian(params_constrained[name], params[name]) + ) return -log_joint def _potential_fn_jit(self, skip_jit_warnings, jit_options, params): @@ -293,16 +313,27 @@ def _pe_jit(*zi): v.requires_grad_(True) return result - def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_options=None): + def get_potential_fn( + self, jit_compile=False, skip_jit_warnings=True, jit_options=None + ): if jit_compile: jit_options = {"check_trace": False} if jit_options is None else jit_options return partial(self._potential_fn_jit, skip_jit_warnings, jit_options) return self._potential_fn -def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn, - prototype_params, max_tries_initial_params=100, num_chains=1, - init_strategy=init_to_uniform, 
trace=None): +def _find_valid_initial_params( + model, + model_args, + model_kwargs, + transforms, + potential_fn, + prototype_params, + max_tries_initial_params=100, + num_chains=1, + init_strategy=init_to_uniform, + trace=None, +): params = prototype_params # For empty models, exit early @@ -319,7 +350,9 @@ def _find_valid_initial_params(model, model_args, model_kwargs, transforms, pote params = {k: transforms[k](v) for k, v in samples.items()} pe_grad, pe = potential_grad(potential_fn, params) - if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))): + if torch.isfinite(pe) and all( + map(torch.all, map(torch.isfinite, pe_grad.values())) + ): for k, v in params.items(): params_per_chain[k].append(v) num_found += 1 @@ -329,12 +362,24 @@ def _find_valid_initial_params(model, model_args, model_kwargs, transforms, pote else: return {k: torch.stack(v) for k, v in params_per_chain.items()} trace = None - raise ValueError("Model specification seems incorrect - cannot find valid initial params.") - - -def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None, - jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1, - init_strategy=init_to_uniform, initial_params=None): + raise ValueError( + "Model specification seems incorrect - cannot find valid initial params." + ) + + +def initialize_model( + model, + model_args=(), + model_kwargs={}, + transforms=None, + max_plate_nesting=None, + jit_compile=False, + jit_options=None, + skip_jit_warnings=False, + num_chains=1, + init_strategy=init_to_uniform, + initial_params=None, +): """ Given a Python callable with Pyro primitives, generates the following model-specific properties needed for inference using HMC/NUTS kernels: @@ -382,8 +427,9 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs) # Wrap model in `poutine.enum` to enumerate over discrete latent sites. # No-op if model does not have any discrete latents. - model = poutine.enum(config_enumerate(model), - first_available_dim=-1 - max_plate_nesting) + model = poutine.enum( + config_enumerate(model), first_available_dim=-1 - max_plate_nesting + ) prototype_model = poutine.trace(InitMessenger(init_strategy)(model)) model_trace = prototype_model.get_trace(*model_args, **model_kwargs) has_enumerable_sites = False @@ -392,7 +438,9 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max fn = node["fn"] if isinstance(fn, _Subsample): if fn.subsample_size is not None and fn.subsample_size < fn.size: - raise NotImplementedError("HMC/NUTS does not support model with subsample sites.") + raise NotImplementedError( + "HMC/NUTS does not support model with subsample sites." 
+ ) continue if node["fn"].has_enumerate_support: has_enumerable_sites = True @@ -404,22 +452,33 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max if automatic_transform_enabled: transforms[name] = biject_to(node["fn"].support).inv - trace_prob_evaluator = TraceEinsumEvaluator(model_trace, - has_enumerable_sites, - max_plate_nesting) + trace_prob_evaluator = TraceEinsumEvaluator( + model_trace, has_enumerable_sites, max_plate_nesting + ) - pe_maker = _PEMaker(model, model_args, model_kwargs, trace_prob_evaluator, transforms) + pe_maker = _PEMaker( + model, model_args, model_kwargs, trace_prob_evaluator, transforms + ) if initial_params is None: prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()} # Note that we deliberately do not exercise jit compilation here so as to # enable potential_fn to be picklable (a torch._C.Function cannot be pickled). # We pass model_trace merely for computational savings. - initial_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms, - pe_maker.get_potential_fn(), prototype_params, - num_chains=num_chains, init_strategy=init_strategy, - trace=model_trace) - potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options) + initial_params = _find_valid_initial_params( + model, + model_args, + model_kwargs, + transforms, + pe_maker.get_potential_fn(), + prototype_params, + num_chains=num_chains, + init_strategy=init_strategy, + trace=model_trace, + ) + potential_fn = pe_maker.get_potential_fn( + jit_compile, skip_jit_warnings, jit_options + ) return initial_params, potential_fn, transforms, model_trace @@ -430,14 +489,16 @@ def _safe(fn): :param fn: stats function from :mod:`pyro.ops.stats` module. """ + @functools.wraps(fn) def wrapped(sample, *args, **kwargs): try: val = fn(sample, *args, **kwargs) except Exception: warnings.warn(tb.format_exc()) - val = torch.full(sample.shape[2:], float("nan"), - dtype=sample.dtype, device=sample.device) + val = torch.full( + sample.shape[2:], float("nan"), dtype=sample.dtype, device=sample.device + ) return val return wrapped @@ -493,11 +554,19 @@ def summary(samples, prob=0.9, group_by_chain=True): hpdi = stats.hpdi(value_flat, prob=prob) n_eff = _safe(stats.effective_sample_size)(value) r_hat = stats.split_gelman_rubin(value) - hpd_lower = '{:.1f}%'.format(50 * (1 - prob)) - hpd_upper = '{:.1f}%'.format(50 * (1 + prob)) - summary_dict[name] = OrderedDict([("mean", mean), ("std", std), ("median", median), - (hpd_lower, hpdi[0]), (hpd_upper, hpdi[1]), - ("n_eff", n_eff), ("r_hat", r_hat)]) + hpd_lower = "{:.1f}%".format(50 * (1 - prob)) + hpd_upper = "{:.1f}%".format(50 * (1 + prob)) + summary_dict[name] = OrderedDict( + [ + ("mean", mean), + ("std", std), + ("median", median), + (hpd_lower, hpdi[0]), + (hpd_upper, hpdi[1]), + ("n_eff", n_eff), + ("r_hat", r_hat), + ] + ) return summary_dict @@ -519,41 +588,63 @@ def print_summary(samples, prob=0.9, group_by_chain=True): return summary_dict = summary(samples, prob, group_by_chain) - row_names = {k: k + '[' + ','.join(map(lambda x: str(x - 1), v.shape[2:])) + ']' - for k, v in samples.items()} + row_names = { + k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]" + for k, v in samples.items() + } max_len = max(max(map(lambda x: len(x), row_names.values())), 10) - name_format = '{:>' + str(max_len) + '}' - header_format = name_format + ' {:>9}' * 7 - columns = [''] + list(list(summary_dict.values())[0].keys()) + name_format = "{:>" + str(max_len) + "}" + 
header_format = name_format + " {:>9}" * 7 + columns = [""] + list(list(summary_dict.values())[0].keys()) print() print(header_format.format(*columns)) - row_format = name_format + ' {:>9.2f}' * 7 + row_format = name_format + " {:>9.2f}" * 7 for name, stats_dict in summary_dict.items(): shape = stats_dict["mean"].shape if len(shape) == 0: print(row_format.format(name, *stats_dict.values())) else: for idx in product(*map(range, shape)): - idx_str = '[{}]'.format(','.join(map(str, idx))) - print(row_format.format(name + idx_str, *[v[idx] for v in stats_dict.values()])) + idx_str = "[{}]".format(",".join(map(str, idx))) + print( + row_format.format( + name + idx_str, *[v[idx] for v in stats_dict.values()] + ) + ) print() -def _predictive_sequential(model, posterior_samples, model_args, model_kwargs, - num_samples, sample_sites, return_trace=False): +def _predictive_sequential( + model, + posterior_samples, + model_args, + model_kwargs, + num_samples, + sample_sites, + return_trace=False, +): collected = [] - samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)] + samples = [ + {k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples) + ] for i in range(num_samples): - trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(*model_args, **model_kwargs) + trace = poutine.trace(poutine.condition(model, samples[i])).get_trace( + *model_args, **model_kwargs + ) if return_trace: collected.append(trace) else: - collected.append({site: trace.nodes[site]['value'] for site in sample_sites}) + collected.append( + {site: trace.nodes[site]["value"] for site in sample_sites} + ) - return collected if return_trace else {site: torch.stack([s[site] for s in collected]) - for site in sample_sites} + return ( + collected + if return_trace + else {site: torch.stack([s[site] for s in collected]) for site in sample_sites} + ) def predictive(model, posterior_samples, *args, **kwargs): @@ -587,13 +678,15 @@ def predictive(model, posterior_samples, *args, **kwargs): :return: dict of samples from the predictive distribution, or a single vectorized `trace` (if `return_trace=True`). """ - warnings.warn('The `mcmc.predictive` function is deprecated and will be removed in ' - 'a future release. Use the `pyro.infer.Predictive` class instead.', - FutureWarning) - num_samples = kwargs.pop('num_samples', None) - return_sites = kwargs.pop('return_sites', None) - return_trace = kwargs.pop('return_trace', False) - parallel = kwargs.pop('parallel', False) + warnings.warn( + "The `mcmc.predictive` function is deprecated and will be removed in " + "a future release. Use the `pyro.infer.Predictive` class instead.", + FutureWarning, + ) + num_samples = kwargs.pop("num_samples", None) + return_sites = kwargs.pop("return_sites", None) + return_trace = kwargs.pop("return_trace", False) + parallel = kwargs.pop("parallel", False) max_plate_nesting = _guess_max_plate_nesting(model, args, kwargs) model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*args, **kwargs)) @@ -607,12 +700,20 @@ def predictive(model, posterior_samples, *args, **kwargs): num_samples = batch_size elif num_samples != batch_size: - warnings.warn("Sample's leading dimension size {} is different from the " - "provided {} num_samples argument. Defaulting to {}." - .format(batch_size, num_samples, batch_size), UserWarning) + warnings.warn( + "Sample's leading dimension size {} is different from the " + "provided {} num_samples argument. 
Defaulting to {}.".format( + batch_size, num_samples, batch_size + ), + UserWarning, + ) num_samples = batch_size - sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape) + sample = sample.reshape( + (num_samples,) + + (1,) * (max_plate_nesting - len(sample_shape)) + + sample_shape + ) reshaped_samples[name] = sample if num_samples is None: @@ -620,7 +721,7 @@ def predictive(model, posterior_samples, *args, **kwargs): return_site_shapes = {} for site in model_trace.stochastic_nodes + model_trace.observation_nodes: - site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape + site_shape = (num_samples,) + model_trace.nodes[site]["value"].shape if return_sites: if site in return_sites: return_site_shapes[site] = site_shape @@ -629,8 +730,15 @@ def predictive(model, posterior_samples, *args, **kwargs): return_site_shapes[site] = site_shape if not parallel: - return _predictive_sequential(model, posterior_samples, args, kwargs, num_samples, - return_site_shapes.keys(), return_trace) + return _predictive_sequential( + model, + posterior_samples, + args, + kwargs, + num_samples, + return_site_shapes.keys(), + return_trace, + ) def _vectorized_fn(fn): """ @@ -642,20 +750,23 @@ def _vectorized_fn(fn): """ def wrapped_fn(*args, **kwargs): - with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1): + with pyro.plate( + "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1 + ): return fn(*args, **kwargs) return wrapped_fn - trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\ - .get_trace(*args, **kwargs) + trace = poutine.trace( + poutine.condition(_vectorized_fn(model), reshaped_samples) + ).get_trace(*args, **kwargs) if return_trace: return trace predictions = {} for site, shape in return_site_shapes.items(): - value = trace.nodes[site]['value'] + value = trace.nodes[site]["value"] if value.numel() < reduce((lambda x, y: x * y), shape): predictions[site] = value.expand(shape) else: @@ -709,10 +820,10 @@ def diagnostics_from_stats(statistics, num_samples, num_chains): for (_, name), stat in statistics.items(): if name in mean_var_dict: mean, var = mean_var_dict[name] - mean.append(stat['mean']) - var.append(stat['variance']) - elif 'mean' in stat and 'variance' in stat: - mean_var_dict[name] = ([stat['mean']], [stat['variance']]) + mean.append(stat["mean"]) + var.append(stat["variance"]) + elif "mean" in stat and "variance" in stat: + mean_var_dict[name] = ([stat["mean"]], [stat["variance"]]) for name, (m, v) in mean_var_dict.items(): mean_var_dict[name] = (torch.stack(m), torch.stack(v)) diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index dbeee300a3..9d8b1c7f76 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -19,57 +19,92 @@ def _guess_max_plate_nesting(model, args, kwargs): """ with poutine.block(): model_trace = poutine.trace(model).get_trace(*args, **kwargs) - sites = [site for site in model_trace.nodes.values() - if site["type"] == "sample"] - - dims = [frame.dim - for site in sites - for frame in site["cond_indep_stack"] - if frame.vectorized] + sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"] + + dims = [ + frame.dim + for site in sites + for frame in site["cond_indep_stack"] + if frame.vectorized + ] max_plate_nesting = -min(dims) if dims else 0 return max_plate_nesting -def _predictive_sequential(model, posterior_samples, model_args, model_kwargs, - num_samples, return_site_shapes, 
return_trace=False): +def _predictive_sequential( + model, + posterior_samples, + model_args, + model_kwargs, + num_samples, + return_site_shapes, + return_trace=False, +): collected = [] - samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)] + samples = [ + {k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples) + ] for i in range(num_samples): - trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(*model_args, **model_kwargs) + trace = poutine.trace(poutine.condition(model, samples[i])).get_trace( + *model_args, **model_kwargs + ) if return_trace: collected.append(trace) else: - collected.append({site: trace.nodes[site]['value'] for site in return_site_shapes}) + collected.append( + {site: trace.nodes[site]["value"] for site in return_site_shapes} + ) if return_trace: return collected else: - return {site: torch.stack([s[site] for s in collected]).reshape(shape) - for site, shape in return_site_shapes.items()} - - -def _predictive(model, posterior_samples, num_samples, return_sites=(), - return_trace=False, parallel=False, model_args=(), model_kwargs={}): + return { + site: torch.stack([s[site] for s in collected]).reshape(shape) + for site, shape in return_site_shapes.items() + } + + +def _predictive( + model, + posterior_samples, + num_samples, + return_sites=(), + return_trace=False, + parallel=False, + model_args=(), + model_kwargs={}, +): model = torch.no_grad()(poutine.mask(model, mask=False)) max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs) - vectorize = pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1) - model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs)) + vectorize = pyro.plate( + "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1 + ) + model_trace = prune_subsample_sites( + poutine.trace(model).get_trace(*model_args, **model_kwargs) + ) reshaped_samples = {} for name, sample in posterior_samples.items(): sample_shape = sample.shape[1:] - sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape) + sample = sample.reshape( + (num_samples,) + + (1,) * (max_plate_nesting - len(sample_shape)) + + sample_shape + ) reshaped_samples[name] = sample if return_trace: - trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\ - .get_trace(*model_args, **model_kwargs) + trace = poutine.trace( + poutine.condition(vectorize(model), reshaped_samples) + ).get_trace(*model_args, **model_kwargs) return trace return_site_shapes = {} for site in model_trace.stochastic_nodes + model_trace.observation_nodes: append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape) - site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape + site_shape = ( + (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]["value"].shape + ) # non-empty return-sites if return_sites: if site in return_sites: @@ -83,21 +118,29 @@ def _predictive(model, posterior_samples, num_samples, return_sites=(), return_site_shapes[site] = site_shape # handle _RETURN site - if return_sites is not None and '_RETURN' in return_sites: - value = model_trace.nodes['_RETURN']['value'] + if return_sites is not None and "_RETURN" in return_sites: + value = model_trace.nodes["_RETURN"]["value"] shape = (num_samples,) + value.shape if torch.is_tensor(value) else None - return_site_shapes['_RETURN'] = shape + return_site_shapes["_RETURN"] 
= shape if not parallel: - return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, - return_site_shapes, return_trace=False) - - trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\ - .get_trace(*model_args, **model_kwargs) + return _predictive_sequential( + model, + posterior_samples, + model_args, + model_kwargs, + num_samples, + return_site_shapes, + return_trace=False, + ) + + trace = poutine.trace( + poutine.condition(vectorize(model), reshaped_samples) + ).get_trace(*model_args, **model_kwargs) predictions = {} for site, shape in return_site_shapes.items(): - value = trace.nodes[site]['value'] - if site == '_RETURN' and shape is None: + value = trace.nodes[site]["value"] + if site == "_RETURN" and shape is None: predictions[site] = value continue if value.numel() < reduce((lambda x, y: x * y), shape): @@ -133,12 +176,22 @@ class Predictive(torch.nn.Module): in an outermost `plate` messenger. Note that this requires that the model has all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`. """ - def __init__(self, model, posterior_samples=None, guide=None, num_samples=None, - return_sites=(), parallel=False): + + def __init__( + self, + model, + posterior_samples=None, + guide=None, + num_samples=None, + return_sites=(), + parallel=False, + ): super().__init__() if posterior_samples is None: if num_samples is None: - raise ValueError("Either posterior_samples or num_samples must be specified.") + raise ValueError( + "Either posterior_samples or num_samples must be specified." + ) posterior_samples = {} for name, sample in posterior_samples.items(): @@ -146,16 +199,24 @@ def __init__(self, model, posterior_samples=None, guide=None, num_samples=None, if num_samples is None: num_samples = batch_size elif num_samples != batch_size: - warnings.warn("Sample's leading dimension size {} is different from the " - "provided {} num_samples argument. Defaulting to {}." - .format(batch_size, num_samples, batch_size), UserWarning) + warnings.warn( + "Sample's leading dimension size {} is different from the " + "provided {} num_samples argument. Defaulting to {}.".format( + batch_size, num_samples, batch_size + ), + UserWarning, + ) num_samples = batch_size if num_samples is None: - raise ValueError("No sample sites in posterior samples to infer `num_samples`.") + raise ValueError( + "No sample sites in posterior samples to infer `num_samples`." + ) if guide is not None and posterior_samples: - raise ValueError("`posterior_samples` cannot be provided with the `guide` argument.") + raise ValueError( + "`posterior_samples` cannot be provided with the `guide` argument." + ) if return_sites is not None: assert isinstance(return_sites, (list, tuple, set)) @@ -200,14 +261,30 @@ def forward(self, *args, **kwargs): if self.guide is not None: # return all sites by default if a guide is provided. 
return_sites = None if not return_sites else return_sites - posterior_samples = _predictive(self.guide, posterior_samples, self.num_samples, return_sites=None, - parallel=self.parallel, model_args=args, model_kwargs=kwargs) - return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites, - parallel=self.parallel, model_args=args, model_kwargs=kwargs) + posterior_samples = _predictive( + self.guide, + posterior_samples, + self.num_samples, + return_sites=None, + parallel=self.parallel, + model_args=args, + model_kwargs=kwargs, + ) + return _predictive( + self.model, + posterior_samples, + self.num_samples, + return_sites=return_sites, + parallel=self.parallel, + model_args=args, + model_kwargs=kwargs, + ) def get_samples(self, *args, **kwargs): - warnings.warn("The method `.get_samples` has been deprecated in favor of `.forward`.", - DeprecationWarning) + warnings.warn( + "The method `.get_samples` has been deprecated in favor of `.forward`.", + DeprecationWarning, + ) return self.forward(*args, **kwargs) def get_vectorized_trace(self, *args, **kwargs): @@ -220,7 +297,19 @@ def get_vectorized_trace(self, *args, **kwargs): """ posterior_samples = self.posterior_samples if self.guide is not None: - posterior_samples = _predictive(self.guide, posterior_samples, self.num_samples, - parallel=self.parallel, model_args=args, model_kwargs=kwargs) - return _predictive(self.model, posterior_samples, self.num_samples, - return_trace=True, model_args=args, model_kwargs=kwargs) + posterior_samples = _predictive( + self.guide, + posterior_samples, + self.num_samples, + parallel=self.parallel, + model_args=args, + model_kwargs=kwargs, + ) + return _predictive( + self.model, + posterior_samples, + self.num_samples, + return_trace=True, + model_args=args, + model_kwargs=kwargs, + ) diff --git a/pyro/infer/renyi_elbo.py b/pyro/infer/renyi_elbo.py index 3aac4854ba..349f7c43d4 100644 --- a/pyro/infer/renyi_elbo.py +++ b/pyro/infer/renyi_elbo.py @@ -49,26 +49,34 @@ class RenyiELBO(ELBO): Yuri Burda, Roger Grosse, Ruslan Salakhutdinov """ - def __init__(self, - alpha=0, - num_particles=2, - max_plate_nesting=float('inf'), - max_iarange_nesting=None, # DEPRECATED - vectorize_particles=False, - strict_enumeration_warning=True): + def __init__( + self, + alpha=0, + num_particles=2, + max_plate_nesting=float("inf"), + max_iarange_nesting=None, # DEPRECATED + vectorize_particles=False, + strict_enumeration_warning=True, + ): if max_iarange_nesting is not None: - warnings.warn("max_iarange_nesting is deprecated; use max_plate_nesting instead", - DeprecationWarning) + warnings.warn( + "max_iarange_nesting is deprecated; use max_plate_nesting instead", + DeprecationWarning, + ) max_plate_nesting = max_iarange_nesting if alpha == 1: - raise ValueError("The order alpha should not be equal to 1. Please use Trace_ELBO class" - "for the case alpha = 1.") + raise ValueError( + "The order alpha should not be equal to 1. Please use Trace_ELBO class" + "for the case alpha = 1." 
+ ) self.alpha = alpha - super().__init__(num_particles=num_particles, - max_plate_nesting=max_plate_nesting, - vectorize_particles=vectorize_particles, - strict_enumeration_warning=strict_enumeration_warning) + super().__init__( + num_particles=num_particles, + max_plate_nesting=max_plate_nesting, + vectorize_particles=vectorize_particles, + strict_enumeration_warning=strict_enumeration_warning, + ) def _get_trace(self, model, guide, args, kwargs): """ @@ -76,7 +84,8 @@ def _get_trace(self, model, guide, args, kwargs): against it. """ model_trace, guide_trace = get_importance_trace( - "flat", self.max_plate_nesting, model, guide, args, kwargs) + "flat", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_if_enumerated(guide_trace) return model_trace, guide_trace @@ -94,7 +103,7 @@ def loss(self, model, guide, *args, **kwargs): # grab a vectorized trace from the generator for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - elbo_particle = 0. + elbo_particle = 0.0 sum_dims = get_dependent_plate_dims(model_trace.nodes.values()) # compute elbo @@ -116,9 +125,11 @@ def loss(self, model, guide, *args, **kwargs): else: elbo_particles = torch.stack(elbo_particles) - log_weights = (1. - self.alpha) * elbo_particles - log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(self.num_particles) - elbo = log_mean_weight.sum().item() / (1. - self.alpha) + log_weights = (1.0 - self.alpha) * elbo_particles + log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log( + self.num_particles + ) + elbo = log_mean_weight.sum().item() / (1.0 - self.alpha) loss = -elbo warn_if_nan(loss, "loss") @@ -165,8 +176,10 @@ def loss_and_grads(self, model, guide, *args, **kwargs): raise NotImplementedError if not is_identically_zero(score_function_term): - surrogate_elbo_particle = (surrogate_elbo_particle + - (self.alpha / (1. - self.alpha)) * log_prob_sum) + surrogate_elbo_particle = ( + surrogate_elbo_particle + + (self.alpha / (1.0 - self.alpha)) * log_prob_sum + ) if is_identically_zero(elbo_particle): if tensor_holder is not None: @@ -184,7 +197,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): surrogate_elbo_particles.append(surrogate_elbo_particle) if tensor_holder is None: - return 0. + return 0.0 if is_vectorized: elbo_particles = elbo_particles[0] @@ -193,18 +206,26 @@ def loss_and_grads(self, model, guide, *args, **kwargs): elbo_particles = torch.stack(elbo_particles) surrogate_elbo_particles = torch.stack(surrogate_elbo_particles) - log_weights = (1. - self.alpha) * elbo_particles - log_mean_weight = torch.logsumexp(log_weights, dim=0, keepdim=True) - math.log(self.num_particles) - elbo = log_mean_weight.sum().item() / (1. 
- self.alpha) + log_weights = (1.0 - self.alpha) * elbo_particles + log_mean_weight = torch.logsumexp(log_weights, dim=0, keepdim=True) - math.log( + self.num_particles + ) + elbo = log_mean_weight.sum().item() / (1.0 - self.alpha) # collect parameters to train from model and guide - trainable_params = any(site["type"] == "param" - for trace in (model_trace, guide_trace) - for site in trace.nodes.values()) - - if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False): + trainable_params = any( + site["type"] == "param" + for trace in (model_trace, guide_trace) + for site in trace.nodes.values() + ) + + if trainable_params and getattr( + surrogate_elbo_particles, "requires_grad", False + ): normalized_weights = (log_weights - log_mean_weight).exp() - surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles + surrogate_elbo = ( + normalized_weights * surrogate_elbo_particles + ).sum() / self.num_particles surrogate_loss = -surrogate_elbo surrogate_loss.backward() loss = -elbo diff --git a/pyro/infer/reparam/conjugate.py b/pyro/infer/reparam/conjugate.py index 8de8d1c6cf..478850affe 100644 --- a/pyro/infer/reparam/conjugate.py +++ b/pyro/infer/reparam/conjugate.py @@ -47,6 +47,7 @@ def reparam_guide(): implementation. :type guide: ~pyro.distributions.Distribution or callable """ + def __init__(self, guide): self.guide = guide @@ -69,8 +70,10 @@ def apply(self, msg): if not fn.has_rsample: # Note supporting non-reparameterized sites would require more delicate # handling of traced sites than the crude _do_not_trace flag below. - raise NotImplementedError("ConjugateReparam inference supports only reparameterized " - "distributions, but got {}".format(type(fn))) + raise NotImplementedError( + "ConjugateReparam inference supports only reparameterized " + "distributions, but got {}".format(type(fn)) + ) value = pyro.sample( f"{name}_updated", fn, diff --git a/pyro/infer/reparam/discrete_cosine.py b/pyro/infer/reparam/discrete_cosine.py index 1cdbd68a48..60d67d1445 100644 --- a/pyro/infer/reparam/discrete_cosine.py +++ b/pyro/infer/reparam/discrete_cosine.py @@ -34,8 +34,9 @@ class DiscreteCosineReparam(UnitJacobianReparam): batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False. """ - def __init__(self, dim=-1, smooth=0., *, - experimental_allow_batch=False): + + def __init__(self, dim=-1, smooth=0.0, *, experimental_allow_batch=False): transform = DiscreteCosineTransform(dim=dim, smooth=smooth, cache_size=1) - super().__init__(transform, suffix="dct", - experimental_allow_batch=experimental_allow_batch) + super().__init__( + transform, suffix="dct", experimental_allow_batch=experimental_allow_batch + ) diff --git a/pyro/infer/reparam/haar.py b/pyro/infer/reparam/haar.py index 182c2c519c..9ff05760de 100644 --- a/pyro/infer/reparam/haar.py +++ b/pyro/infer/reparam/haar.py @@ -28,8 +28,9 @@ class HaarReparam(UnitJacobianReparam): batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False. 
""" - def __init__(self, dim=-1, flip=False, *, - experimental_allow_batch=False): + + def __init__(self, dim=-1, flip=False, *, experimental_allow_batch=False): transform = HaarTransform(dim=dim, flip=flip, cache_size=1) - super().__init__(transform, suffix="haar", - experimental_allow_batch=experimental_allow_batch) + super().__init__( + transform, suffix="haar", experimental_allow_batch=experimental_allow_batch + ) diff --git a/pyro/infer/reparam/hmm.py b/pyro/infer/reparam/hmm.py index 12aab10a0c..740b40ea3c 100644 --- a/pyro/infer/reparam/hmm.py +++ b/pyro/infer/reparam/hmm.py @@ -52,6 +52,7 @@ class LinearHMMReparam(Reparam): :param obs: Optional reparameterizer for the observation distribution. :type obs: ~pyro.infer.reparam.reparam.Reparam """ + def __init__(self, init=None, trans=None, obs=None): assert init is None or isinstance(init, Reparam) assert trans is None or isinstance(trans, Reparam) @@ -69,20 +70,24 @@ def apply(self, msg): fn, event_dim = self._unwrap(fn) assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM)) if fn.duration is None: - raise ValueError("LinearHMMReparam requires duration to be specified " - "on targeted LinearHMM distributions") + raise ValueError( + "LinearHMMReparam requires duration to be specified " + "on targeted LinearHMM distributions" + ) # Unwrap IndependentHMM. if isinstance(fn, dist.IndependentHMM): indep_value = None if value is not None: indep_value = value.transpose(-1, -2).unsqueeze(-1) - msg = self.apply({ - "name": name, - "fn": fn.base_dist.to_event(1), - "value": indep_value, - "is_observed": is_observed, - }) + msg = self.apply( + { + "name": name, + "fn": fn.base_dist.to_event(1), + "value": indep_value, + "is_observed": is_observed, + } + ) hmm = msg["fn"] hmm = dist.IndependentHMM(hmm.to_event(-1)) if msg["value"] is not indep_value: @@ -92,12 +97,14 @@ def apply(self, msg): # Reparameterize the initial distribution as conditionally Gaussian. init_dist = fn.initial_dist if self.init is not None: - msg = self.init.apply({ - "name": f"{name}_init", - "fn": self._wrap(init_dist, event_dim - 1), - "value": None, - "is_observed": False, - }) + msg = self.init.apply( + { + "name": f"{name}_init", + "fn": self._wrap(init_dist, event_dim - 1), + "value": None, + "is_observed": False, + } + ) init_dist = msg["fn"] init_dist = init_dist.to_event(1 - init_dist.event_dim) @@ -108,12 +115,14 @@ def apply(self, msg): trans_dist = trans_dist.expand( trans_dist.batch_shape[:-1] + (fn.duration,) ) - msg = self.trans.apply({ - "name": f"{name}_trans", - "fn": self._wrap(trans_dist, event_dim), - "value": None, - "is_observed": False, - }) + msg = self.trans.apply( + { + "name": f"{name}_trans", + "fn": self._wrap(trans_dist, event_dim), + "value": None, + "is_observed": False, + } + ) trans_dist = msg["fn"] trans_dist = trans_dist.to_event(1 - trans_dist.event_dim) @@ -122,20 +131,28 @@ def apply(self, msg): if self.obs is not None: if obs_dist.batch_shape[-1] != fn.duration: obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] + (fn.duration,)) - msg = self.obs.apply({ - "name": f"{name}_obs", - "fn": self._wrap(obs_dist, event_dim), - "value": value, - "is_observed": is_observed, - }) + msg = self.obs.apply( + { + "name": f"{name}_obs", + "fn": self._wrap(obs_dist, event_dim), + "value": value, + "is_observed": is_observed, + } + ) obs_dist = msg["fn"] obs_dist = obs_dist.to_event(1 - obs_dist.event_dim) value = msg["value"] is_observed = msg["is_observed"] # Reparameterize the entire HMM as conditionally Gaussian. 
- hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist, - fn.observation_matrix, obs_dist, duration=fn.duration) + hmm = dist.GaussianHMM( + init_dist, + fn.transition_matrix, + trans_dist, + fn.observation_matrix, + obs_dist, + duration=fn.duration, + ) hmm = self._wrap(hmm, event_dim) # Apply any observation transforms. diff --git a/pyro/infer/reparam/loc_scale.py b/pyro/infer/reparam/loc_scale.py index d9dcb81e4d..232a146517 100644 --- a/pyro/infer/reparam/loc_scale.py +++ b/pyro/infer/reparam/loc_scale.py @@ -31,6 +31,7 @@ class LocScaleReparam(Reparam): all params in a distributions ``.arg_constraints`` will be copied. :type shape_params: tuple or list """ + def __init__(self, centered=None, shape_params=None): assert centered is None or isinstance(centered, (float, torch.Tensor)) if shape_params is not None: @@ -66,9 +67,11 @@ def apply(self, msg): ) params = {key: getattr(fn, key) for key in self.shape_params} if centered is None: - centered = pyro.param("{}_centered".format(name), - lambda: fn.loc.new_full(event_shape, 0.5), - constraint=constraints.unit_interval) + centered = pyro.param( + "{}_centered".format(name), + lambda: fn.loc.new_full(event_shape, 0.5), + constraint=constraints.unit_interval, + ) params["loc"] = fn.loc * centered params["scale"] = fn.scale ** centered decentered_fn = type(fn)(**params) diff --git a/pyro/infer/reparam/neutra.py b/pyro/infer/reparam/neutra.py index 4e0cffb8d9..e3df28a62d 100644 --- a/pyro/infer/reparam/neutra.py +++ b/pyro/infer/reparam/neutra.py @@ -44,10 +44,14 @@ class NeuTraReparam(Reparam): :param ~pyro.infer.autoguide.AutoContinuous guide: A trained guide. """ + def __init__(self, guide): if not isinstance(guide, AutoContinuous): - raise TypeError("NeuTraReparam expected an AutoContinuous guide, but got {}" - .format(type(guide))) + raise TypeError( + "NeuTraReparam expected an AutoContinuous guide, but got {}".format( + type(guide) + ) + ) self.guide = guide self.transform = None self.x_unconstrained = {} @@ -74,29 +78,36 @@ def apply(self, msg): ) log_density = 0.0 - compute_density = (poutine.get_mask() is not False) + compute_density = poutine.get_mask() is not False if name not in self.x_unconstrained: # On first sample site. # Sample a shared latent. try: self.transform = self.guide.get_transform() except (NotImplementedError, TypeError) as e: - raise ValueError("NeuTraReparam only supports guides that implement " - "`get_transform` method that does not depend on the " - "model's `*args, **kwargs`") from e + raise ValueError( + "NeuTraReparam only supports guides that implement " + "`get_transform` method that does not depend on the " + "model's `*args, **kwargs`" + ) from e with ExitStack() as stack: for plate in self.guide.plates.values(): stack.enter_context(block_plate(dim=plate.dim, strict=False)) - z_unconstrained = pyro.sample(f"{name}_shared_latent", - self.guide.get_base_dist().mask(False)) + z_unconstrained = pyro.sample( + f"{name}_shared_latent", self.guide.get_base_dist().mask(False) + ) # Differentiably transform. 
x_unconstrained = self.transform(z_unconstrained) if compute_density: - log_density = self.transform.log_abs_det_jacobian(z_unconstrained, x_unconstrained) + log_density = self.transform.log_abs_det_jacobian( + z_unconstrained, x_unconstrained + ) self.x_unconstrained = { site["name"]: (site, unconstrained_value) - for site, unconstrained_value in self.guide._unpack_latent(x_unconstrained) + for site, unconstrained_value in self.guide._unpack_latent( + x_unconstrained + ) } # Extract a single site's value from the shared latent. diff --git a/pyro/infer/reparam/reparam.py b/pyro/infer/reparam/reparam.py index 411e1d7262..16794222ba 100644 --- a/pyro/infer/reparam/reparam.py +++ b/pyro/infer/reparam/reparam.py @@ -14,6 +14,7 @@ def TypedDict(*args, **kwargs): return dict + ReparamMessage = TypedDict( "ReparamMessage", name=str, diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py index 6e65c244bd..d5a389bc0e 100644 --- a/pyro/infer/reparam/split.py +++ b/pyro/infer/reparam/split.py @@ -29,6 +29,7 @@ class SplitReparam(Reparam): :type: list(int) :param int dim: Dimension along which to split. Defaults to -1. """ + def __init__(self, sections, dim): assert isinstance(dim, int) and dim < 0 assert isinstance(sections, list) @@ -51,7 +52,7 @@ def apply(self, msg): # Draw independent parts. dim = fn.event_dim - self.event_dim left_shape = fn.event_shape[:dim] - right_shape = fn.event_shape[1 + dim:] + right_shape = fn.event_shape[1 + dim :] for i, size in enumerate(self.sections): event_shape = left_shape + (size,) + right_shape value_split[i] = pyro.sample( diff --git a/pyro/infer/reparam/stable.py b/pyro/infer/reparam/stable.py index 827586d6a4..e0cb4cd490 100644 --- a/pyro/infer/reparam/stable.py +++ b/pyro/infer/reparam/stable.py @@ -36,6 +36,7 @@ class LatentStableReparam(Reparam): Stable Distributions: Models for Heavy Tailed Data. http://fs2.american.edu/jpnolan/www/stable/chap1.pdf """ + def apply(self, msg): name = msg["name"] fn = msg["fn"] @@ -54,10 +55,13 @@ def apply(self, msg): proto = fn.stability half_pi = proto.new_tensor(math.pi / 2) one = proto.new_ones(proto.shape) - u = pyro.sample("{}_uniform".format(name), - self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim)) - e = pyro.sample("{}_exponential".format(name), - self._wrap(dist.Exponential(one), event_dim)) + u = pyro.sample( + "{}_uniform".format(name), + self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim), + ) + e = pyro.sample( + "{}_exponential".format(name), self._wrap(dist.Exponential(one), event_dim) + ) # Differentiably transform. 
x = _standard_stable(fn.stability, fn.skew, u, e, coords="S0") @@ -89,6 +93,7 @@ class SymmetricStableReparam(Reparam): "Option Pricing with Levy-Stable Processes" https://pdfs.semanticscholar.org/4d66/c91b136b2a38117dd16c2693679f5341c616.pdf """ + def apply(self, msg): name = msg["name"] fn = msg["fn"] @@ -107,10 +112,13 @@ def apply(self, msg): proto = fn.stability half_pi = proto.new_tensor(math.pi / 2) one = proto.new_ones(proto.shape) - u = pyro.sample("{}_uniform".format(name), - self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim)) - e = pyro.sample("{}_exponential".format(name), - self._wrap(dist.Exponential(one), event_dim)) + u = pyro.sample( + "{}_uniform".format(name), + self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim), + ) + e = pyro.sample( + "{}_exponential".format(name), self._wrap(dist.Exponential(one), event_dim) + ) # Differentiably transform to scale drawn from a totally-skewed stable variable. a = fn.stability @@ -183,14 +191,22 @@ def apply(self, msg): proto = fn.stability half_pi = proto.new_tensor(math.pi / 2) one = proto.new_ones(proto.shape) - zu = pyro.sample("{}_z_uniform".format(name), - self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim)) - ze = pyro.sample("{}_z_exponential".format(name), - self._wrap(dist.Exponential(one), event_dim)) - tu = pyro.sample("{}_t_uniform".format(name), - self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim)) - te = pyro.sample("{}_t_exponential".format(name), - self._wrap(dist.Exponential(one), event_dim)) + zu = pyro.sample( + "{}_z_uniform".format(name), + self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim), + ) + ze = pyro.sample( + "{}_z_exponential".format(name), + self._wrap(dist.Exponential(one), event_dim), + ) + tu = pyro.sample( + "{}_t_uniform".format(name), + self._wrap(dist.Uniform(-half_pi, half_pi).expand(proto.shape), event_dim), + ) + te = pyro.sample( + "{}_t_exponential".format(name), + self._wrap(dist.Exponential(one), event_dim), + ) # Differentiably transform. a = fn.stability diff --git a/pyro/infer/reparam/studentt.py b/pyro/infer/reparam/studentt.py index ee9d27789a..01751827c0 100644 --- a/pyro/infer/reparam/studentt.py +++ b/pyro/infer/reparam/studentt.py @@ -21,6 +21,7 @@ class StudentTReparam(Reparam): an auxiliary :class:`~pyro.distributions.Gamma` variable conditioned on which the result is :class:`~pyro.distributions.Normal` . """ + def apply(self, msg): name = msg["name"] fn = msg["fn"] @@ -32,8 +33,9 @@ def apply(self, msg): # Draw a sample that depends only on df. half_df = fn.df * 0.5 - gamma = pyro.sample("{}_gamma".format(name), - self._wrap(dist.Gamma(half_df, half_df), event_dim)) + gamma = pyro.sample( + "{}_gamma".format(name), self._wrap(dist.Gamma(half_df, half_df), event_dim) + ) # Construct a scaled Normal. loc = fn.loc diff --git a/pyro/infer/reparam/unit_jacobian.py b/pyro/infer/reparam/unit_jacobian.py index 7ac6c1a554..a45037c58f 100644 --- a/pyro/infer/reparam/unit_jacobian.py +++ b/pyro/infer/reparam/unit_jacobian.py @@ -25,8 +25,10 @@ class UnitJacobianReparam(Reparam): batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False. 
""" - def __init__(self, transform, suffix="transformed", *, - experimental_allow_batch=False): + + def __init__( + self, transform, suffix="transformed", *, experimental_allow_batch=False + ): self.transform = transform.with_cache() self.suffix = suffix self.experimental_allow_batch = experimental_allow_batch @@ -46,30 +48,34 @@ def apply(self, msg): raise ValueError( "Cannot transform along batch dimension; try either" "converting a batch dimension to an event dimension, or " - "setting experimental_allow_batch=True.") + "setting experimental_allow_batch=True." + ) # Reshape and mute plates using block_plate. from pyro.contrib.forecast.util import ( reshape_batch, reshape_transform_batch, ) + old_shape = fn.batch_shape new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:] fn = reshape_batch(fn, new_shape).to_event(shift) - transform = reshape_transform_batch(transform, - old_shape + fn.event_shape, - new_shape + fn.event_shape) + transform = reshape_transform_batch( + transform, old_shape + fn.event_shape, new_shape + fn.event_shape + ) if value is not None: value = value.reshape( - value.shape[:-shift - event_dim] + (1,) * shift - + value.shape[-shift - event_dim:] + value.shape[: -shift - event_dim] + + (1,) * shift + + value.shape[-shift - event_dim :] ) for dim in range(-shift, 0): stack.enter_context(block_plate(dim=dim, strict=False)) # Differentiably invert transform. - transform = ComposeTransform([biject_to(fn.support).inv.with_cache(), - self.transform]) + transform = ComposeTransform( + [biject_to(fn.support).inv.with_cache(), self.transform] + ) value_trans = None if value is not None: value_trans = transform(value) @@ -87,8 +93,8 @@ def apply(self, msg): value = transform.inv(value_trans) if shift: value = value.reshape( - value.shape[:-2 * shift - event_dim] - + value.shape[-shift - event_dim:] + value.shape[: -2 * shift - event_dim] + + value.shape[-shift - event_dim :] ) # Simulate a pyro.deterministic() site. diff --git a/pyro/infer/rws.py b/pyro/infer/rws.py index 0646c1ef32..6ec7f28c08 100644 --- a/pyro/infer/rws.py +++ b/pyro/infer/rws.py @@ -74,35 +74,42 @@ class ReweightedWakeSleep(ELBO): Tuan Anh Le, Adam R. Kosiorek, N. 
Siddharth, Yee Whye Teh, Frank Wood """ - def __init__(self, - num_particles=2, - insomnia=1., - model_has_params=True, - num_sleep_particles=None, - vectorize_particles=True, - max_plate_nesting=float('inf'), - strict_enumeration_warning=True): + def __init__( + self, + num_particles=2, + insomnia=1.0, + model_has_params=True, + num_sleep_particles=None, + vectorize_particles=True, + max_plate_nesting=float("inf"), + strict_enumeration_warning=True, + ): # force K > 1 otherwise SNIS not possible - assert(num_particles > 1), \ - "Reweighted Wake Sleep needs to be run with more than one particle" - - super().__init__(num_particles=num_particles, - max_plate_nesting=max_plate_nesting, - vectorize_particles=vectorize_particles, - strict_enumeration_warning=strict_enumeration_warning) + assert ( + num_particles > 1 + ), "Reweighted Wake Sleep needs to be run with more than one particle" + + super().__init__( + num_particles=num_particles, + max_plate_nesting=max_plate_nesting, + vectorize_particles=vectorize_particles, + strict_enumeration_warning=strict_enumeration_warning, + ) self.insomnia = insomnia self.model_has_params = model_has_params - self.num_sleep_particles = num_particles if num_sleep_particles is None else num_sleep_particles + self.num_sleep_particles = ( + num_particles if num_sleep_particles is None else num_sleep_particles + ) - assert(insomnia >= 0 and insomnia <= 1), \ - "insomnia should be in [0, 1]" + assert insomnia >= 0 and insomnia <= 1, "insomnia should be in [0, 1]" def _get_trace(self, model, guide, args, kwargs): """ Returns a single trace from the guide, and the model that is run against it. """ - model_trace, guide_trace = get_importance_trace("flat", self.max_plate_nesting, - model, guide, args, kwargs, detach=True) + model_trace, guide_trace = get_importance_trace( + "flat", self.max_plate_nesting, model, guide, args, kwargs, detach=True + ) if is_validation_enabled(): check_if_enumerated(guide_trace) return model_trace, guide_trace @@ -117,20 +124,24 @@ def _loss(self, model, guide, args, kwargs): Performs backward as appropriate on both, over the specified number of particles. """ - wake_theta_loss = torch.tensor(100.) - if self.model_has_params or self.insomnia > 0.: + wake_theta_loss = torch.tensor(100.0) + if self.model_has_params or self.insomnia > 0.0: # compute quantities for wake theta and wake phi log_joints = [] log_qs = [] - for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - log_joint = 0. - log_q = 0. 
+ for model_trace, guide_trace in self._get_traces( + model, guide, args, kwargs + ): + log_joint = 0.0 + log_q = 0.0 for _, site in model_trace.nodes.items(): if site["type"] == "sample": if self.vectorize_particles: - log_p_site = site["log_prob"].reshape(self.num_particles, -1).sum(-1) + log_p_site = ( + site["log_prob"].reshape(self.num_particles, -1).sum(-1) + ) else: log_p_site = site["log_prob_sum"] log_joint = log_joint + log_p_site @@ -138,7 +149,9 @@ def _loss(self, model, guide, args, kwargs): for _, site in guide_trace.nodes.items(): if site["type"] == "sample": if self.vectorize_particles: - log_q_site = site["log_prob"].reshape(self.num_particles, -1).sum(-1) + log_q_site = ( + site["log_prob"].reshape(self.num_particles, -1).sum(-1) + ) else: log_q_site = site["log_prob_sum"] log_q = log_q + log_q_site @@ -146,7 +159,9 @@ def _loss(self, model, guide, args, kwargs): log_joints.append(log_joint) log_qs.append(log_q) - log_joints = log_joints[0] if self.vectorize_particles else torch.stack(log_joints) + log_joints = ( + log_joints[0] if self.vectorize_particles else torch.stack(log_joints) + ) log_qs = log_qs[0] if self.vectorize_particles else torch.stack(log_qs) log_weights = log_joints - log_qs.detach() @@ -165,10 +180,10 @@ def _loss(self, model, guide, args, kwargs): # compute sleep phi loss _model = pyro.poutine.uncondition(model) _guide = guide - _log_q = 0. + _log_q = 0.0 if self.vectorize_particles: - if self.max_plate_nesting == float('inf'): + if self.max_plate_nesting == float("inf"): self._guess_max_plate_nesting(_model, _guide, args, kwargs) _model = self._vectorized_num_sleep_particles(_model) _guide = self._vectorized_num_sleep_particles(guide) @@ -176,16 +191,22 @@ def _loss(self, model, guide, args, kwargs): for _ in range(1 if self.vectorize_particles else self.num_sleep_particles): _model_trace = poutine.trace(_model).get_trace(*args, **kwargs) _model_trace.detach_() - _guide_trace = self._get_matched_trace(_model_trace, _guide, args, kwargs) + _guide_trace = self._get_matched_trace( + _model_trace, _guide, args, kwargs + ) _log_q += _guide_trace.log_prob_sum() sleep_phi_loss = -_log_q / self.num_sleep_particles warn_if_nan(sleep_phi_loss, "sleep phi loss") # compute phi loss - phi_loss = sleep_phi_loss if self.insomnia == 0 \ - else wake_phi_loss if self.insomnia == 1 \ - else self.insomnia * wake_phi_loss + (1. - self.insomnia) * sleep_phi_loss + phi_loss = ( + sleep_phi_loss + if self.insomnia == 0 + else wake_phi_loss + if self.insomnia == 1 + else self.insomnia * wake_phi_loss + (1.0 - self.insomnia) * sleep_phi_loss + ) return wake_theta_loss, phi_loss @@ -220,10 +241,15 @@ def _vectorized_num_sleep_particles(self, fn): """ Copy of `_vectorised_num_particles` that uses `num_sleep_particles`. 
""" + def wrapped_fn(*args, **kwargs): if self.num_sleep_particles == 1: return fn(*args, **kwargs) - with pyro.plate("num_sleep_particles_vectorized", self.num_sleep_particles, dim=-self.max_plate_nesting): + with pyro.plate( + "num_sleep_particles_vectorized", + self.num_sleep_particles, + dim=-self.max_plate_nesting, + ): return fn(*args, **kwargs) return wrapped_fn @@ -236,7 +262,9 @@ def _get_matched_trace(model_trace, guide, args, kwargs): model_trace.nodes[node]["is_observed"] = True kwargs["observations"][node] = model_trace.nodes[node]["value"] - guide_trace = poutine.trace(poutine.replay(guide, model_trace)).get_trace(*args, **kwargs) + guide_trace = poutine.trace(poutine.replay(guide, model_trace)).get_trace( + *args, **kwargs + ) check_model_guide_match(model_trace, guide_trace) guide_trace = prune_subsample_sites(guide_trace) diff --git a/pyro/infer/smcfilter.py b/pyro/infer/smcfilter.py index 22d2360748..9bb301375b 100644 --- a/pyro/infer/smcfilter.py +++ b/pyro/infer/smcfilter.py @@ -18,6 +18,7 @@ class SMCFailed(ValueError): Exception raised when :class:`SMCFilter` fails to find any hypothesis with nonzero probability. """ + pass @@ -50,9 +51,11 @@ class SMCFilter: when to importance resample: resampling occurs when ``ess < ess_threshold * num_particles``. """ + # TODO: Add window kwarg that defaults to float("inf") - def __init__(self, model, guide, num_particles, max_plate_nesting, *, - ess_threshold=0.5): + def __init__( + self, model, guide, num_particles, max_plate_nesting, *, ess_threshold=0.5 + ): assert 0 < ess_threshold <= 1 self.model = model self.guide = guide @@ -69,10 +72,14 @@ def init(self, *args, **kwargs): Perform any initialization for sequential importance resampling. Any args or kwargs are passed to the model and guide """ - self.particle_plate = pyro.plate("particles", self.num_particles, dim=-1-self.max_plate_nesting) + self.particle_plate = pyro.plate( + "particles", self.num_particles, dim=-1 - self.max_plate_nesting + ) with poutine.block(), self.particle_plate: with self.state._lock(): - guide_trace = poutine.trace(self.guide.init).get_trace(self.state, *args, **kwargs) + guide_trace = poutine.trace(self.guide.init).get_trace( + self.state, *args, **kwargs + ) model = poutine.replay(self.model.init, guide_trace) model_trace = poutine.trace(model).get_trace(self.state, *args, **kwargs) @@ -87,7 +94,9 @@ def step(self, *args, **kwargs): """ with poutine.block(), self.particle_plate: with self.state._lock(): - guide_trace = poutine.trace(self.guide.step).get_trace(self.state, *args, **kwargs) + guide_trace = poutine.trace(self.guide.step).get_trace( + self.state, *args, **kwargs + ) model = poutine.replay(self.model.step, guide_trace) model_trace = poutine.trace(model).get_trace(self.state, *args, **kwargs) @@ -100,8 +109,10 @@ def get_empirical(self): :rtype: a dictionary with keys which are latent variables and values which are :class:`~pyro.distributions.Empirical` objects. 
""" - return {key: dist.Empirical(value, self.state._log_weights) - for key, value in self.state.items()} + return { + key: dist.Empirical(value, self.state._log_weights) + for key, value in self.state.items() + } @torch.no_grad() def _update_weights(self, model_trace, guide_trace): @@ -120,16 +131,20 @@ def _update_weights(self, model_trace, guide_trace): log_q = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p - log_q if not (self.state._log_weights.max() > -math.inf): - raise SMCFailed("Failed to find feasible hypothesis after site {}" - .format(name)) + raise SMCFailed( + "Failed to find feasible hypothesis after site {}".format(name) + ) for site in model_trace.nodes.values(): if site["type"] == "sample" and site["is_observed"]: log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p if not (self.state._log_weights.max() > -math.inf): - raise SMCFailed("Failed to find feasible hypothesis after site {}" - .format(site["name"])) + raise SMCFailed( + "Failed to find feasible hypothesis after site {}".format( + site["name"] + ) + ) self.state._log_weights -= self.state._log_weights.max() @@ -172,6 +187,7 @@ class SMCState(dict): :param int num_particles: """ + def __init__(self, num_particles): assert isinstance(num_particles, int) and num_particles > 0 super().__init__() @@ -192,14 +208,20 @@ def __setitem__(self, key, value): raise RuntimeError("Guide cannot write to SMCState") if is_validation_enabled(): if not isinstance(value, torch.Tensor): - raise TypeError("Only Tensors can be stored in an SMCState, but got {}" - .format(type(value).__name__)) + raise TypeError( + "Only Tensors can be stored in an SMCState, but got {}".format( + type(value).__name__ + ) + ) if value.dim() == 0 or value.size(0) != self._num_particles: - raise ValueError("Expected leading dim of size {} but got shape {}" - .format(self._num_particles, value.shape)) + raise ValueError( + "Expected leading dim of size {} but got shape {}".format( + self._num_particles, value.shape + ) + ) super().__setitem__(key, value) def _resample(self, index): for key, value in self.items(): self[key] = value[index].contiguous() - self._log_weights.fill_(0.) + self._log_weights.fill_(0.0) diff --git a/pyro/infer/svgd.py b/pyro/infer/svgd.py index 9e722a745f..7eceb3581a 100644 --- a/pyro/infer/svgd.py +++ b/pyro/infer/svgd.py @@ -18,8 +18,11 @@ def vectorize(fn, num_particles, max_plate_nesting): def _fn(*args, **kwargs): - with pyro.plate("num_particles_vectorized", num_particles, dim=-max_plate_nesting - 1): + with pyro.plate( + "num_particles_vectorized", num_particles, dim=-max_plate_nesting - 1 + ): return fn(*args, **kwargs) + return _fn @@ -28,6 +31,7 @@ class _SVGDGuide(AutoContinuous): This modification of :class:`AutoContinuous` is used internally in the :class:`SVGD` inference algorithm. """ + def __init__(self, model): super().__init__(model, init_loc_fn=init_to_sample) @@ -70,6 +74,7 @@ class RBFSteinKernel(SteinKernel): [1] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm," Qiang Liu, Dilin Wang """ + def __init__(self, bandwidth_factor=None): """ :param float bandwidth_factor: Optional factor by which to scale the bandwidth @@ -135,6 +140,7 @@ class IMQSteinKernel(SteinKernel): [1] "Stein Points," Wilson Ye Chen, Lester Mackey, Jackson Gorham, Francois-Xavier Briol, Chris. J. Oates. 
[2] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm," Qiang Liu, Dilin Wang """ + def __init__(self, alpha=0.5, beta=-0.5, bandwidth_factor=None): """ :param float alpha: Kernel hyperparameter, defaults to 0.5. @@ -224,13 +230,21 @@ class SVGD: [2] "Kernelized Complete Conditional Stein Discrepancy," Raghav Singhal, Saad Lahlou, Rajesh Ranganath """ - def __init__(self, model, kernel, optim, num_particles, max_plate_nesting, mode="univariate"): + + def __init__( + self, model, kernel, optim, num_particles, max_plate_nesting, mode="univariate" + ): assert callable(model) assert isinstance(kernel, SteinKernel), "Must provide a valid SteinKernel" - assert isinstance(optim, pyro.optim.PyroOptim), "Must provide a valid Pyro optimizer" + assert isinstance( + optim, pyro.optim.PyroOptim + ), "Must provide a valid Pyro optimizer" assert num_particles > 1, "Must use at least two particles" assert max_plate_nesting >= 0 - assert mode in ['univariate', 'multivariate'], "mode must be one of (univariate, multivariate)" + assert mode in [ + "univariate", + "multivariate", + ], "mode must be one of (univariate, multivariate)" self.model = vectorize(model, num_particles, max_plate_nesting) self.kernel = kernel @@ -247,8 +261,12 @@ def get_named_particles(self): Create a dictionary mapping name to vectorized value, of the form ``{name: tensor}``. The leading dimension of each tensor corresponds to particles, i.e. this creates a struct of arrays. """ - return {site["name"]: biject_to(site["fn"].support)(unconstrained_value) - for site, unconstrained_value in self.guide._unpack_latent(pyro.param("svgd_particles"))} + return { + site["name"]: biject_to(site["fn"].support)(unconstrained_value) + for site, unconstrained_value in self.guide._unpack_latent( + pyro.param("svgd_particles") + ) + } @torch.no_grad() def step(self, *args, **kwargs): @@ -279,20 +297,32 @@ def step(self, *args, **kwargs): repulsive_grad = torch.einsum("nm,nm...->n...", kernel, kernel_grad) elif self.mode == "univariate": kernel = log_kernel.exp() - assert kernel.shape == (self.num_particles, self.num_particles, reshaped_particles.size(-1)) - attractive_grad = torch.einsum("nmd,md->nd", kernel, reshaped_particles_grad) + assert kernel.shape == ( + self.num_particles, + self.num_particles, + reshaped_particles.size(-1), + ) + attractive_grad = torch.einsum( + "nmd,md->nd", kernel, reshaped_particles_grad + ) repulsive_grad = torch.einsum("nmd,nmd->nd", kernel, kernel_grad) # combine the attractive and repulsive terms in the SVGD gradient assert attractive_grad.shape == repulsive_grad.shape - particles.grad = (attractive_grad + repulsive_grad).reshape(particles.shape) / self.num_particles + particles.grad = (attractive_grad + repulsive_grad).reshape( + particles.shape + ) / self.num_particles # compute per-parameter mean squared gradients - squared_gradients = {site["name"]: value.mean().item() - for site, value in self.guide._unpack_latent(particles.grad.pow(2.0))} + squared_gradients = { + site["name"]: value.mean().item() + for site, value in self.guide._unpack_latent(particles.grad.pow(2.0)) + } # torch.optim objects gets instantiated for any params that haven't been seen yet - params = set(site["value"].unconstrained() for site in param_capture.trace.nodes.values()) + params = set( + site["value"].unconstrained() for site in param_capture.trace.nodes.values() + ) self.optim(params) # zero gradients diff --git a/pyro/infer/svi.py b/pyro/infer/svi.py index b4058ef819..067b0ca852 100644 --- 
a/pyro/infer/svi.py +++ b/pyro/infer/svi.py @@ -34,23 +34,32 @@ class SVI(TracePosterior): commonly used loss is ``loss=Trace_ELBO()``. See the tutorial `SVI Part I `_ for a discussion. """ - def __init__(self, - model, - guide, - optim, - loss, - loss_and_grads=None, - num_samples=0, - num_steps=0, - **kwargs): + + def __init__( + self, + model, + guide, + optim, + loss, + loss_and_grads=None, + num_samples=0, + num_steps=0, + **kwargs + ): if num_steps: - warnings.warn('The `num_steps` argument to SVI is deprecated and will be removed in ' - 'a future release. Use `SVI.step` directly to control the ' - 'number of iterations.', FutureWarning) + warnings.warn( + "The `num_steps` argument to SVI is deprecated and will be removed in " + "a future release. Use `SVI.step` directly to control the " + "number of iterations.", + FutureWarning, + ) if num_samples: - warnings.warn('The `num_samples` argument to SVI is deprecated and will be removed in ' - 'a future release. Use `pyro.infer.Predictive` class to draw ' - 'samples from the posterior.', FutureWarning) + warnings.warn( + "The `num_samples` argument to SVI is deprecated and will be removed in " + "a future release. Use `pyro.infer.Predictive` class to draw " + "samples from the posterior.", + FutureWarning, + ) self.model = model self.guide = guide @@ -60,18 +69,22 @@ def __init__(self, super().__init__(**kwargs) if not isinstance(optim, pyro.optim.PyroOptim): - raise ValueError("Optimizer should be an instance of pyro.optim.PyroOptim class.") + raise ValueError( + "Optimizer should be an instance of pyro.optim.PyroOptim class." + ) if isinstance(loss, ELBO): self.loss = loss.loss self.loss_and_grads = loss.loss_and_grads else: if loss_and_grads is None: + def _loss_and_grads(*args, **kwargs): loss_val = loss(*args, **kwargs) - if getattr(loss_val, 'requires_grad', False): + if getattr(loss_val, "requires_grad", False): loss_val.backward(retain_graph=True) return loss_val + loss_and_grads = _loss_and_grads self.loss = loss self.loss_and_grads = loss_and_grads @@ -83,10 +96,12 @@ def run(self, *args, **kwargs): For inference, use :meth:`step` directly, and for predictions, use the :class:`~pyro.infer.predictive.Predictive` class. """ - warnings.warn('The `SVI.run` method is deprecated and will be removed in a ' - 'future release. For inference, use `SVI.step` directly, ' - 'and for predictions, use the `pyro.infer.Predictive` class.', - FutureWarning) + warnings.warn( + "The `SVI.run` method is deprecated and will be removed in a " + "future release. 
For inference, use `SVI.step` directly, " + "and for predictions, use the `pyro.infer.Predictive` class.", + FutureWarning, + ) if self.num_steps > 0: with poutine.block(): for i in range(self.num_steps): @@ -96,7 +111,9 @@ def run(self, *args, **kwargs): def _traces(self, *args, **kwargs): for i in range(self.num_samples): guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs) - model_trace = poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace(*args, **kwargs) yield model_trace, 1.0 def evaluate_loss(self, *args, **kwargs): @@ -127,8 +144,9 @@ def step(self, *args, **kwargs): with poutine.trace(param_only=True) as param_capture: loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs) - params = set(site["value"].unconstrained() - for site in param_capture.trace.nodes.values()) + params = set( + site["value"].unconstrained() for site in param_capture.trace.nodes.values() + ) # actually perform gradient steps # torch.optim objects gets instantiated for any params that haven't been seen yet diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py index adb433c475..93041c92cc 100644 --- a/pyro/infer/trace_elbo.py +++ b/pyro/infer/trace_elbo.py @@ -55,7 +55,8 @@ def _get_trace(self, model, guide, args, kwargs): against it. """ model_trace, guide_trace = get_importance_trace( - "flat", self.max_plate_nesting, model, guide, args, kwargs) + "flat", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_if_enumerated(guide_trace) return model_trace, guide_trace @@ -69,7 +70,9 @@ def loss(self, model, guide, *args, **kwargs): """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum()) + elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item( + guide_trace.log_prob_sum() + ) elbo += elbo_particle / self.num_particles loss = -elbo @@ -94,13 +97,17 @@ def _differentiable_loss_particle(self, model_trace, guide_trace): elbo_particle = elbo_particle - torch_item(site["log_prob_sum"]) if not is_identically_zero(entropy_term): - surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum() + surrogate_elbo_particle = ( + surrogate_elbo_particle - entropy_term.sum() + ) if not is_identically_zero(score_function_term): if log_r is None: log_r = _compute_log_r(model_trace, guide_trace) site = log_r.sum_to(site["cond_indep_stack"]) - surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum() + surrogate_elbo_particle = ( + surrogate_elbo_particle + (site * score_function_term).sum() + ) return -elbo_particle, -surrogate_elbo_particle @@ -109,10 +116,12 @@ def differentiable_loss(self, model, guide, *args, **kwargs): Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters """ - loss = 0. - surrogate_loss = 0. 
+ loss = 0.0 + surrogate_loss = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace) + loss_particle, surrogate_loss_particle = self._differentiable_loss_particle( + model_trace, guide_trace + ) surrogate_loss += surrogate_loss_particle / self.num_particles loss += loss_particle / self.num_particles warn_if_nan(surrogate_loss, "loss") @@ -129,15 +138,21 @@ def loss_and_grads(self, model, guide, *args, **kwargs): loss = 0.0 # grab a trace from the generator for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace) + loss_particle, surrogate_loss_particle = self._differentiable_loss_particle( + model_trace, guide_trace + ) loss += loss_particle / self.num_particles # collect parameters to train from model and guide - trainable_params = any(site["type"] == "param" - for trace in (model_trace, guide_trace) - for site in trace.nodes.values()) - - if trainable_params and getattr(surrogate_loss_particle, 'requires_grad', False): + trainable_params = any( + site["type"] == "param" + for trace in (model_trace, guide_trace) + for site in trace.nodes.values() + ) + + if trainable_params and getattr( + surrogate_loss_particle, "requires_grad", False + ): surrogate_loss_particle = surrogate_loss_particle / self.num_particles surrogate_loss_particle.backward(retain_graph=self.retain_graph) warn_if_nan(loss, "loss") @@ -158,22 +173,26 @@ class JitTrace_ELBO(Trace_ELBO): ``**kwargs``, and compilation will be triggered once per unique ``**kwargs``. """ + def loss_and_surrogate_loss(self, model, guide, *args, **kwargs): - kwargs['_pyro_model_id'] = id(model) - kwargs['_pyro_guide_id'] = id(guide) - if getattr(self, '_loss_and_surrogate_loss', None) is None: + kwargs["_pyro_model_id"] = id(model) + kwargs["_pyro_guide_id"] = id(guide) + if getattr(self, "_loss_and_surrogate_loss", None) is None: # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings, - jit_options=self.jit_options) + @pyro.ops.jit.trace( + ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options + ) def loss_and_surrogate_loss(*args, **kwargs): - kwargs.pop('_pyro_model_id') - kwargs.pop('_pyro_guide_id') + kwargs.pop("_pyro_model_id") + kwargs.pop("_pyro_guide_id") self = weakself() loss = 0.0 surrogate_loss = 0.0 - for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): + for model_trace, guide_trace in self._get_traces( + model, guide, args, kwargs + ): elbo_particle = 0 surrogate_elbo_particle = 0 log_r = None @@ -182,25 +201,36 @@ def loss_and_surrogate_loss(*args, **kwargs): for name, site in model_trace.nodes.items(): if site["type"] == "sample": elbo_particle = elbo_particle + site["log_prob_sum"] - surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"] + surrogate_elbo_particle = ( + surrogate_elbo_particle + site["log_prob_sum"] + ) for name, site in guide_trace.nodes.items(): if site["type"] == "sample": - log_prob, score_function_term, entropy_term = site["score_parts"] + log_prob, score_function_term, entropy_term = site[ + "score_parts" + ] elbo_particle = elbo_particle - site["log_prob_sum"] if not is_identically_zero(entropy_term): - surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum() + surrogate_elbo_particle = ( + 
surrogate_elbo_particle - entropy_term.sum() + ) if not is_identically_zero(score_function_term): if log_r is None: log_r = _compute_log_r(model_trace, guide_trace) site = log_r.sum_to(site["cond_indep_stack"]) - surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum() + surrogate_elbo_particle = ( + surrogate_elbo_particle + + (site * score_function_term).sum() + ) loss = loss - elbo_particle / self.num_particles - surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles + surrogate_loss = ( + surrogate_loss - surrogate_elbo_particle / self.num_particles + ) return loss, surrogate_loss @@ -209,13 +239,17 @@ def loss_and_surrogate_loss(*args, **kwargs): return self._loss_and_surrogate_loss(*args, **kwargs) def differentiable_loss(self, model, guide, *args, **kwargs): - loss, surrogate_loss = self.loss_and_surrogate_loss(model, guide, *args, **kwargs) + loss, surrogate_loss = self.loss_and_surrogate_loss( + model, guide, *args, **kwargs + ) warn_if_nan(loss, "loss") return loss + (surrogate_loss - surrogate_loss.detach()) def loss_and_grads(self, model, guide, *args, **kwargs): - loss, surrogate_loss = self.loss_and_surrogate_loss(model, guide, *args, **kwargs) + loss, surrogate_loss = self.loss_and_surrogate_loss( + model, guide, *args, **kwargs + ) surrogate_loss.backward() loss = loss.item() diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py index 6b8329b1e0..f3b61c147b 100644 --- a/pyro/infer/trace_mean_field_elbo.py +++ b/pyro/infer/trace_mean_field_elbo.py @@ -23,17 +23,27 @@ def _check_mean_field_requirement(model_trace, guide_trace): Checks that the guide and model sample sites are ordered identically. This is sufficient but not necessary for correctness. """ - model_sites = [name for name, site in model_trace.nodes.items() - if site["type"] == "sample" and name in guide_trace.nodes] - guide_sites = [name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and name in model_trace.nodes] + model_sites = [ + name + for name, site in model_trace.nodes.items() + if site["type"] == "sample" and name in guide_trace.nodes + ] + guide_sites = [ + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and name in model_trace.nodes + ] assert set(model_sites) == set(guide_sites) if model_sites != guide_sites: - warnings.warn("Failed to verify mean field restriction on the guide. " - "To eliminate this warning, ensure model and guide sites " - "occur in the same order.\n" + - "Model sites:\n " + "\n ".join(model_sites) + - "Guide sites:\n " + "\n ".join(guide_sites)) + warnings.warn( + "Failed to verify mean field restriction on the guide. " + "To eliminate this warning, ensure model and guide sites " + "occur in the same order.\n" + + "Model sites:\n " + + "\n ".join(model_sites) + + "Guide sites:\n " + + "\n ".join(guide_sites) + ) class TraceMeanField_ELBO(Trace_ELBO): @@ -67,9 +77,9 @@ class TraceMeanField_ELBO(Trace_ELBO): condition is always satisfied if the model and guide have identical dependency structures. 
""" + def _get_trace(self, model, guide, args, kwargs): - model_trace, guide_trace = super()._get_trace( - model, guide, args, kwargs) + model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs) if is_validation_enabled(): _check_mean_field_requirement(model_trace, guide_trace) return model_trace, guide_trace @@ -83,7 +93,9 @@ def loss(self, model, guide, *args, **kwargs): """ loss = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - loss_particle, _ = self._differentiable_loss_particle(model_trace, guide_trace) + loss_particle, _ = self._differentiable_loss_particle( + model_trace, guide_trace + ) loss = loss + loss_particle / self.num_particles warn_if_nan(loss, "loss") @@ -104,16 +116,24 @@ def _differentiable_loss_particle(self, model_trace, guide_trace): # use kl divergence if available, else fall back on sampling try: kl_qp = kl_divergence(guide_site["fn"], model_site["fn"]) - kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"], mask=guide_site["mask"]) + kl_qp = scale_and_mask( + kl_qp, scale=guide_site["scale"], mask=guide_site["mask"] + ) if torch.is_tensor(kl_qp): assert kl_qp.shape == guide_site["fn"].batch_shape kl_qp_sum = kl_qp.sum() else: - kl_qp_sum = kl_qp * torch.Size(guide_site["fn"].batch_shape).numel() + kl_qp_sum = ( + kl_qp * torch.Size(guide_site["fn"].batch_shape).numel() + ) elbo_particle = elbo_particle - kl_qp_sum except NotImplementedError: entropy_term = guide_site["score_parts"].entropy_term - elbo_particle = elbo_particle + model_site["log_prob_sum"] - entropy_term.sum() + elbo_particle = ( + elbo_particle + + model_site["log_prob_sum"] + - entropy_term.sum() + ) # handle auxiliary sites in the guide for name, guide_site in guide_trace.nodes.items(): @@ -124,7 +144,11 @@ def _differentiable_loss_particle(self, model_trace, guide_trace): entropy_term = guide_site["score_parts"].entropy_term elbo_particle = elbo_particle - entropy_term.sum() - loss = -(elbo_particle.detach() if torch._C._get_tracing_state() else torch_item(elbo_particle)) + loss = -( + elbo_particle.detach() + if torch._C._get_tracing_state() + else torch_item(elbo_particle) + ) surrogate_loss = -elbo_particle return loss, surrogate_loss @@ -143,22 +167,28 @@ class JitTraceMeanField_ELBO(TraceMeanField_ELBO): ``**kwargs``, and compilation will be triggered once per unique ``**kwargs``. 
""" + def differentiable_loss(self, model, guide, *args, **kwargs): - kwargs['_pyro_model_id'] = id(model) - kwargs['_pyro_guide_id'] = id(guide) - if getattr(self, '_loss_and_surrogate_loss', None) is None: + kwargs["_pyro_model_id"] = id(model) + kwargs["_pyro_guide_id"] = id(guide) + if getattr(self, "_loss_and_surrogate_loss", None) is None: # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings, - jit_options=self.jit_options) + @pyro.ops.jit.trace( + ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options + ) def differentiable_loss(*args, **kwargs): - kwargs.pop('_pyro_model_id') - kwargs.pop('_pyro_guide_id') + kwargs.pop("_pyro_model_id") + kwargs.pop("_pyro_guide_id") self = weakself() loss = 0.0 - for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - _, loss_particle = self._differentiable_loss_particle(model_trace, guide_trace) + for model_trace, guide_trace in self._get_traces( + model, guide, args, kwargs + ): + _, loss_particle = self._differentiable_loss_particle( + model_trace, guide_trace + ) loss = loss + loss_particle / self.num_particles return loss diff --git a/pyro/infer/trace_mmd.py b/pyro/infer/trace_mmd.py index 661ff727c2..42b22297bc 100644 --- a/pyro/infer/trace_mmd.py +++ b/pyro/infer/trace_mmd.py @@ -61,18 +61,26 @@ class Trace_MMD(ELBO): Shengjia Zhao, Jiaming Song, Stefano Ermon """ - def __init__(self, - kernel, mmd_scale=1, - num_particles=10, - max_plate_nesting=float('inf'), - max_iarange_nesting=None, # DEPRECATED - vectorize_particles=True, - strict_enumeration_warning=True, - ignore_jit_warnings=False, - retain_graph=None): + def __init__( + self, + kernel, + mmd_scale=1, + num_particles=10, + max_plate_nesting=float("inf"), + max_iarange_nesting=None, # DEPRECATED + vectorize_particles=True, + strict_enumeration_warning=True, + ignore_jit_warnings=False, + retain_graph=None, + ): super().__init__( - num_particles, max_plate_nesting, max_iarange_nesting, vectorize_particles, - strict_enumeration_warning, ignore_jit_warnings, retain_graph, + num_particles, + max_plate_nesting, + max_iarange_nesting, + vectorize_particles, + strict_enumeration_warning, + ignore_jit_warnings, + retain_graph, ) self._kernel = None self._mmd_scale = None @@ -91,13 +99,17 @@ def kernel(self, kernel): if isinstance(k, pyro.contrib.gp.kernels.kernel.Kernel): k.requires_grad_(False) else: - raise TypeError("`kernel` values should be instances of `pyro.contrib.gp.kernels.kernel.Kernel`") + raise TypeError( + "`kernel` values should be instances of `pyro.contrib.gp.kernels.kernel.Kernel`" + ) self._kernel = kernel elif isinstance(kernel, pyro.contrib.gp.kernels.kernel.Kernel): kernel.requires_grad_(False) self._kernel = defaultdict(lambda: kernel) else: - raise TypeError("`kernel` should be an instance of `pyro.contrib.gp.kernels.kernel.Kernel`") + raise TypeError( + "`kernel` should be an instance of `pyro.contrib.gp.kernels.kernel.Kernel`" + ) @property def mmd_scale(self): @@ -118,7 +130,8 @@ def _get_trace(self, model, guide, args, kwargs): against it. 
""" model_trace, guide_trace = get_importance_trace( - "flat", self.max_plate_nesting, model, guide, args, kwargs) + "flat", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_if_enumerated(guide_trace) return model_trace, guide_trace @@ -135,30 +148,47 @@ def _differentiable_loss_parts(self, model, guide, args, kwargs): self._vectorized_num_particles(model) ).get_trace(*args, **kwargs) else: - model_trace_independent = poutine.trace(model, graph_type='flat').get_trace(*args, **kwargs) + model_trace_independent = poutine.trace( + model, graph_type="flat" + ).get_trace(*args, **kwargs) loglikelihood_particle = 0.0 for name, model_site in model_trace.nodes.items(): - if model_site['type'] == 'sample': - if name in guide_trace and not model_site['is_observed']: + if model_site["type"] == "sample": + if name in guide_trace and not model_site["is_observed"]: guide_site = guide_trace.nodes[name] independent_model_site = model_trace_independent.nodes[name] if not independent_model_site["fn"].has_rsample: - raise ValueError("Model site {} is not reparameterizable".format(name)) + raise ValueError( + "Model site {} is not reparameterizable".format(name) + ) if not guide_site["fn"].has_rsample: - raise ValueError("Guide site {} is not reparameterizable".format(name)) + raise ValueError( + "Guide site {} is not reparameterizable".format(name) + ) - particle_dim = -self.max_plate_nesting - independent_model_site["fn"].event_dim + particle_dim = ( + -self.max_plate_nesting + - independent_model_site["fn"].event_dim + ) - model_samples = independent_model_site['value'] - guide_samples = guide_site['value'] + model_samples = independent_model_site["value"] + guide_samples = guide_site["value"] if self.vectorize_particles: - model_samples = model_samples.transpose(-model_samples.dim(), particle_dim) - model_samples = model_samples.view(model_samples.shape[0], -1) - - guide_samples = guide_samples.transpose(-guide_samples.dim(), particle_dim) - guide_samples = guide_samples.view(guide_samples.shape[0], -1) + model_samples = model_samples.transpose( + -model_samples.dim(), particle_dim + ) + model_samples = model_samples.view( + model_samples.shape[0], -1 + ) + + guide_samples = guide_samples.transpose( + -guide_samples.dim(), particle_dim + ) + guide_samples = guide_samples.view( + guide_samples.shape[0], -1 + ) else: model_samples = model_samples.view(1, -1) guide_samples = guide_samples.view(1, -1) @@ -166,14 +196,20 @@ def _differentiable_loss_parts(self, model, guide, args, kwargs): all_model_samples[name].append(model_samples) all_guide_samples[name].append(guide_samples) else: - loglikelihood_particle = loglikelihood_particle + model_site['log_prob_sum'] + loglikelihood_particle = ( + loglikelihood_particle + model_site["log_prob_sum"] + ) loglikelihood = loglikelihood_particle / self.num_particles + loglikelihood for name in all_model_samples.keys(): all_model_samples[name] = torch.cat(all_model_samples[name]) all_guide_samples[name] = torch.cat(all_guide_samples[name]) - divergence = _compute_mmd(all_model_samples[name], all_guide_samples[name], kernel=self._kernel[name]) + divergence = _compute_mmd( + all_model_samples[name], + all_guide_samples[name], + kernel=self._kernel[name], + ) penalty = self._mmd_scale[name] * divergence + penalty warn_if_nan(loglikelihood, "loglikelihood") @@ -192,7 +228,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): Shengjia Zhao https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/ """ - 
loglikelihood, penalty = self._differentiable_loss_parts(model, guide, args, kwargs) + loglikelihood, penalty = self._differentiable_loss_parts( + model, guide, args, kwargs + ) loss = -loglikelihood + penalty warn_if_nan(loss, "loss") return loss diff --git a/pyro/infer/trace_tail_adaptive_elbo.py b/pyro/infer/trace_tail_adaptive_elbo.py index b05251a300..5ff15e4bee 100644 --- a/pyro/infer/trace_tail_adaptive_elbo.py +++ b/pyro/infer/trace_tail_adaptive_elbo.py @@ -31,21 +31,28 @@ class TraceTailAdaptive_ELBO(Trace_ELBO): Hao Liu, Qiang Liu, NeurIPS 2018 https://papers.nips.cc/paper/7816-variational-inference-with-tail-adaptive-f-divergence """ + def loss(self, model, guide, *args, **kwargs): """ It is not necessary to estimate the tail-adaptive f-divergence itself in order to compute the corresponding gradients. Consequently the loss method is left unimplemented. """ - raise NotImplementedError("Loss method for TraceTailAdaptive_ELBO not implemented") + raise NotImplementedError( + "Loss method for TraceTailAdaptive_ELBO not implemented" + ) def _differentiable_loss_particle(self, model_trace, guide_trace): if not self.vectorize_particles: - raise NotImplementedError("TraceTailAdaptive_ELBO only implemented for vectorize_particles==True") + raise NotImplementedError( + "TraceTailAdaptive_ELBO only implemented for vectorize_particles==True" + ) if self.num_particles == 1: - warnings.warn("For num_particles==1 TraceTailAdaptive_ELBO uses the same loss function as Trace_ELBO. " + - "Increase num_particles to get an adaptive f-divergence.") + warnings.warn( + "For num_particles==1 TraceTailAdaptive_ELBO uses the same loss function as Trace_ELBO. " + + "Increase num_particles to get an adaptive f-divergence." + ) log_p, log_q = 0, 0 @@ -64,11 +71,13 @@ def _differentiable_loss_particle(self, model_trace, guide_trace): # rank the particles according to p/q log_pq = log_p - log_q rank = torch.argsort(log_pq, descending=False) - rank = torch.index_select(torch.arange(self.num_particles, device=log_pq.device) + 1, -1, rank).type_as(log_pq) + rank = torch.index_select( + torch.arange(self.num_particles, device=log_pq.device) + 1, -1, rank + ).type_as(log_pq) # compute the particle-specific weights used to construct the surrogate loss gamma = torch.pow(rank, self.tail_adaptive_beta).detach() surrogate_loss = -(log_pq * gamma).sum() / gamma.sum() # we do not compute the loss, so return `inf` - return float('inf'), surrogate_loss + return float("inf"), surrogate_loss diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index e6e564b91f..019c7f821e 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -37,11 +37,13 @@ def _get_common_scale(scales): scales_set = set() for scale in scales: if isinstance(scale, torch.Tensor) and scale.dim(): - raise ValueError('enumeration only supports scalar poutine.scale') + raise ValueError("enumeration only supports scalar poutine.scale") scales_set.add(float(scale)) if len(scales_set) != 1: - raise ValueError("Expected all enumerated sample sites to share a common poutine.scale, " - "but found {} different scales.".format(len(scales_set))) + raise ValueError( + "Expected all enumerated sample sites to share a common poutine.scale, " + "but found {} different scales.".format(len(scales_set)) + ) return scales[0] @@ -50,51 +52,72 @@ def _check_model_guide_enumeration_constraint(model_enum_sites, guide_trace): for name, site in guide_trace.nodes.items(): if site["type"] == "sample" and site["infer"].get("_enumerate_dim") 
is not None: for f in site["cond_indep_stack"]: - if f.vectorized and guide_trace.plate_to_symbol[f.name] not in min_ordinal: - raise ValueError("Expected model enumeration to be no more global than guide enumeration, " - "but found model enumeration sites upstream of guide site '{}' in plate('{}'). " - "Try converting some model enumeration sites to guide enumeration sites." - .format(name, f.name)) + if ( + f.vectorized + and guide_trace.plate_to_symbol[f.name] not in min_ordinal + ): + raise ValueError( + "Expected model enumeration to be no more global than guide enumeration, " + "but found model enumeration sites upstream of guide site '{}' in plate('{}'). " + "Try converting some model enumeration sites to guide enumeration sites.".format( + name, f.name + ) + ) def _check_tmc_elbo_constraint(model_trace, guide_trace): num_samples = frozenset( site["infer"].get("num_samples") for site in guide_trace.nodes.values() - if site["type"] == "sample" and - site["infer"].get("enumerate") == "parallel" and - site["infer"].get("num_samples") is not None) + if site["type"] == "sample" + and site["infer"].get("enumerate") == "parallel" + and site["infer"].get("num_samples") is not None + ) if len(num_samples) > 1: - warnings.warn('\n'.join([ - "Using different numbers of Monte Carlo samples for different guide sites in TraceEnum_ELBO.", - "This may be biased if the guide is not factorized", - ]), UserWarning) + warnings.warn( + "\n".join( + [ + "Using different numbers of Monte Carlo samples for different guide sites in TraceEnum_ELBO.", + "This may be biased if the guide is not factorized", + ] + ), + UserWarning, + ) for name, site in model_trace.nodes.items(): - if site["type"] == "sample" and \ - site["infer"].get("enumerate", None) == "parallel" and \ - site["infer"].get("num_samples", None) and \ - name not in guide_trace: - warnings.warn('\n'.join([ - "Site {} is multiply sampled in model,".format(site["name"]), - "expect incorrect gradient estimates from TraceEnum_ELBO.", - "Consider using exact enumeration or guide sampling if possible.", - ]), RuntimeWarning) + if ( + site["type"] == "sample" + and site["infer"].get("enumerate", None) == "parallel" + and site["infer"].get("num_samples", None) + and name not in guide_trace + ): + warnings.warn( + "\n".join( + [ + "Site {} is multiply sampled in model,".format(site["name"]), + "expect incorrect gradient estimates from TraceEnum_ELBO.", + "Consider using exact enumeration or guide sampling if possible.", + ] + ), + RuntimeWarning, + ) def _find_ordinal(trace, site): - return frozenset(trace.plate_to_symbol[f.name] - for f in site["cond_indep_stack"] - if f.vectorized) + return frozenset( + trace.plate_to_symbol[f.name] for f in site["cond_indep_stack"] if f.vectorized + ) # TODO move this logic into a poutine def _compute_model_factors(model_trace, guide_trace): # y depends on x iff ordering[x] <= ordering[y] # TODO refine this coarse dependency ordering using time. - ordering = {name: _find_ordinal(trace, site) - for trace in (model_trace, guide_trace) - for name, site in trace.nodes.items() - if site["type"] == "sample"} + ordering = { + name: _find_ordinal(trace, site) + for trace in (model_trace, guide_trace) + for name, site in trace.nodes.items() + if site["type"] == "sample" + } # Collect model sites that may have been enumerated in the model. 
cost_sites = OrderedDict() @@ -105,7 +128,9 @@ def _compute_model_factors(model_trace, guide_trace): if site["type"] == "sample": if name in guide_trace.nodes: cost_sites.setdefault(ordering[name], []).append(site) - non_enum_dims.update(guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims) + non_enum_dims.update( + guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims + ) elif site["infer"].get("_enumerate_dim") is None: cost_sites.setdefault(ordering[name], []).append(site) else: @@ -115,8 +140,10 @@ def _compute_model_factors(model_trace, guide_trace): log_factors = OrderedDict() scale = 1 if not enum_sites: - marginal_costs = OrderedDict((t, [site["packed"]["log_prob"] for site in sites_t]) - for t, sites_t in cost_sites.items()) + marginal_costs = OrderedDict( + (t, [site["packed"]["log_prob"] for site in sites_t]) + for t, sites_t in cost_sites.items() + ) return marginal_costs, log_factors, ordering, enum_dims, scale _check_model_guide_enumeration_constraint(enum_sites, guide_trace) @@ -133,7 +160,8 @@ def _compute_model_factors(model_trace, guide_trace): # the mask inside- and the scale outside- of the log expectation. if "masked_log_prob" not in site["packed"]: site["packed"]["masked_log_prob"] = packed.scale_and_mask( - site["packed"]["unscaled_log_prob"], mask=site["packed"]["mask"]) + site["packed"]["unscaled_log_prob"], mask=site["packed"]["mask"] + ) cost = site["packed"]["masked_log_prob"] log_factors.setdefault(t, []).append(cost) scales.append(site["scale"]) @@ -150,7 +178,8 @@ def _compute_model_factors(model_trace, guide_trace): def _compute_dice_elbo(model_trace, guide_trace): # Accumulate marginal model costs. marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors( - model_trace, guide_trace) + model_trace, guide_trace + ) if log_factors: dim_to_size = {} for terms in log_factors.values(): @@ -199,17 +228,23 @@ def _compute_marginals(model_trace, guide_trace): marginal_dists = OrderedDict() with shared_intermediates() as cache: for name, site in model_trace.nodes.items(): - if (site["type"] != "sample" or - name in guide_trace.nodes or - site["infer"].get("_enumerate_dim") is None): + if ( + site["type"] != "sample" + or name in guide_trace.nodes + or site["infer"].get("_enumerate_dim") is None + ): continue enum_dim = site["infer"]["_enumerate_dim"] enum_symbol = site["infer"]["_enumerate_symbol"] ordinal = _find_ordinal(model_trace, site) - logits = contract_to_tensor(log_factors, sum_dims, - target_ordinal=ordinal, target_dims={enum_symbol}, - cache=cache) + logits = contract_to_tensor( + log_factors, + sum_dims, + target_ordinal=ordinal, + target_dims={enum_symbol}, + cache=cache, + ) logits = packed.unpack(logits, model_trace.symbol_to_dim) logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1) while logits.shape[0] == 1: @@ -223,6 +258,7 @@ class BackwardSampleMessenger(pyro.poutine.messenger.Messenger): Implements forward filtering / backward sampling for sampling from the joint posterior distribution """ + def __init__(self, enum_trace, guide_trace): self.enum_trace = enum_trace args = _compute_model_factors(enum_trace, guide_trace) @@ -248,9 +284,13 @@ def _pyro_sample(self, msg): enum_dim = enum_msg["infer"]["_enumerate_dim"] with shared_intermediates(self.cache): ordinal = _find_ordinal(self.enum_trace, msg) - logits = contract_to_tensor(self.log_factors, self.sum_dims, - target_ordinal=ordinal, target_dims={enum_symbol}, - cache=self.cache) + logits = contract_to_tensor( + self.log_factors, + self.sum_dims, + 
target_ordinal=ordinal, + target_dims={enum_symbol}, + cache=self.cache, + ) logits = packed.unpack(logits, self.enum_trace.symbol_to_dim) logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1) while logits.shape[0] == 1: @@ -297,22 +337,27 @@ def _get_trace(self, model, guide, args, kwargs): against it. """ model_trace, guide_trace = get_importance_trace( - "flat", self.max_plate_nesting, model, guide, args, kwargs) + "flat", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_traceenum_requirements(model_trace, guide_trace) _check_tmc_elbo_constraint(model_trace, guide_trace) - has_enumerated_sites = any(site["infer"].get("enumerate") - for trace in (guide_trace, model_trace) - for name, site in trace.nodes.items() - if site["type"] == "sample") + has_enumerated_sites = any( + site["infer"].get("enumerate") + for trace in (guide_trace, model_trace) + for name, site in trace.nodes.items() + if site["type"] == "sample" + ) if self.strict_enumeration_warning and not has_enumerated_sites: - warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. ' - 'If you want to enumerate sites, you need to @config_enumerate or set ' - 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' - 'If you do not want to enumerate, consider using Trace_ELBO instead.') + warnings.warn( + "TraceEnum_ELBO found no sample sites configured for enumeration. " + "If you want to enumerate sites, you need to @config_enumerate or set " + 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' + "If you do not want to enumerate, consider using Trace_ELBO instead." + ) guide_trace.pack_tensors() model_trace.pack_tensors(guide_trace.plate_to_symbol) @@ -323,7 +368,7 @@ def _get_traces(self, model, guide, args, kwargs): Runs the guide and runs the model against the guide with the result packaged as a trace generator. 
""" - if self.max_plate_nesting == float('inf'): + if self.max_plate_nesting == float("inf"): self._guess_max_plate_nesting(model, guide, args, kwargs) if self.vectorize_particles: guide = self._vectorized_num_particles(guide) @@ -338,9 +383,9 @@ def _get_traces(self, model, guide, args, kwargs): model = model_enum(model) q = queue.LifoQueue() - guide = poutine.queue(guide, q, - escape_fn=iter_discrete_escape, - extend_fn=iter_discrete_extend) + guide = poutine.queue( + guide, q, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend + ) for i in range(1 if self.vectorize_particles else self.num_particles): q.put(poutine.Trace()) while not q.empty(): @@ -386,7 +431,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo = elbo / self.num_particles if not torch.is_tensor(elbo) or not elbo.requires_grad: - raise ValueError('ELBO is cannot be differentiated: {}'.format(elbo)) + raise ValueError("ELBO is cannot be differentiated: {}".format(elbo)) loss = -elbo warn_if_nan(loss, "loss") @@ -409,9 +454,11 @@ def loss_and_grads(self, model, guide, *args, **kwargs): elbo += elbo_particle.item() / self.num_particles # collect parameters to train from model and guide - trainable_params = any(site["type"] == "param" - for trace in (model_trace, guide_trace) - for site in trace.nodes.values()) + trainable_params = any( + site["type"] == "param" + for trace in (model_trace, guide_trace) + for site in trace.nodes.values() + ) if trainable_params and elbo_particle.requires_grad: loss_particle = -elbo_particle @@ -429,14 +476,18 @@ def compute_marginals(self, model, guide, *args, **kwargs): :rtype: OrderedDict """ if self.num_particles != 1: - raise NotImplementedError("TraceEnum_ELBO.compute_marginals() is not " - "compatible with multiple particles.") + raise NotImplementedError( + "TraceEnum_ELBO.compute_marginals() is not " + "compatible with multiple particles." + ) model_trace, guide_trace = next(self._get_traces(model, guide, args, kwargs)) for site in guide_trace.nodes.values(): if site["type"] == "sample": if "_enumerate_dim" in site["infer"] or "_enum_total" in site["infer"]: - raise NotImplementedError("TraceEnum_ELBO.compute_marginals() is not " - "compatible with guide enumeration.") + raise NotImplementedError( + "TraceEnum_ELBO.compute_marginals() is not " + "compatible with guide enumeration." + ) return _compute_marginals(model_trace, guide_trace) def sample_posterior(self, model, guide, *args, **kwargs): @@ -444,17 +495,23 @@ def sample_posterior(self, model, guide, *args, **kwargs): Sample from the joint posterior distribution of all model-enumerated sites given all observations """ if self.num_particles != 1: - raise NotImplementedError("TraceEnum_ELBO.sample_posterior() is not " - "compatible with multiple particles.") + raise NotImplementedError( + "TraceEnum_ELBO.sample_posterior() is not " + "compatible with multiple particles." 
+ ) with poutine.block(), warnings.catch_warnings(): warnings.filterwarnings("ignore", "Found vars in model but not guide") - model_trace, guide_trace = next(self._get_traces(model, guide, args, kwargs)) + model_trace, guide_trace = next( + self._get_traces(model, guide, args, kwargs) + ) for name, site in guide_trace.nodes.items(): if site["type"] == "sample": if "_enumerate_dim" in site["infer"] or "_enum_total" in site["infer"]: - raise NotImplementedError("TraceEnum_ELBO.sample_posterior() is not " - "compatible with guide enumeration.") + raise NotImplementedError( + "TraceEnum_ELBO.sample_posterior() is not " + "compatible with guide enumeration." + ) # TODO replace BackwardSample with torch_sample backend to ubersum with BackwardSampleMessenger(model_trace, guide_trace): @@ -475,21 +532,25 @@ class JitTraceEnum_ELBO(TraceEnum_ELBO): ``**kwargs``, and compilation will be triggered once per unique ``**kwargs``. """ + def differentiable_loss(self, model, guide, *args, **kwargs): - kwargs['_model_id'] = id(model) - kwargs['_guide_id'] = id(guide) - if getattr(self, '_differentiable_loss', None) is None: + kwargs["_model_id"] = id(model) + kwargs["_guide_id"] = id(guide) + if getattr(self, "_differentiable_loss", None) is None: # build a closure for differentiable_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings, - jit_options=self.jit_options) + @pyro.ops.jit.trace( + ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options + ) def differentiable_loss(*args, **kwargs): - kwargs.pop('_model_id') - kwargs.pop('_guide_id') + kwargs.pop("_model_id") + kwargs.pop("_guide_id") self = weakself() elbo = 0.0 - for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): + for model_trace, guide_trace in self._get_traces( + model, guide, args, kwargs + ): elbo = elbo + _compute_dice_elbo(model_trace, guide_trace) return elbo * (-1.0 / self.num_particles) diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index 6ac3df08c7..d0c12532ff 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -27,13 +27,17 @@ def _get_baseline_options(site): """ # XXX default for baseline_beta currently set here options_dict = site["infer"].get("baseline", {}).copy() - options_tuple = (options_dict.pop('nn_baseline', None), - options_dict.pop('nn_baseline_input', None), - options_dict.pop('use_decaying_avg_baseline', False), - options_dict.pop('baseline_beta', 0.90), - options_dict.pop('baseline_value', None)) + options_tuple = ( + options_dict.pop("nn_baseline", None), + options_dict.pop("nn_baseline_input", None), + options_dict.pop("use_decaying_avg_baseline", False), + options_dict.pop("baseline_beta", 0.90), + options_dict.pop("baseline_value", None), + ) if options_dict: - raise ValueError("Unrecognized baseline options: {}".format(options_dict.keys())) + raise ValueError( + "Unrecognized baseline options: {}".format(options_dict.keys()) + ) return options_tuple @@ -44,24 +48,32 @@ def _construct_baseline(node, guide_site, downstream_cost): baseline = 0.0 baseline_loss = 0.0 - (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta, - baseline_value) = _get_baseline_options(guide_site) + ( + nn_baseline, + nn_baseline_input, + use_decaying_avg_baseline, + baseline_beta, + baseline_value, + ) = _get_baseline_options(guide_site) use_nn_baseline = nn_baseline is not None use_baseline_value = baseline_value is not None use_baseline = use_nn_baseline or 
use_decaying_avg_baseline or use_baseline_value - assert(not (use_nn_baseline and use_baseline_value)), \ - "cannot use baseline_value and nn_baseline simultaneously" + assert not ( + use_nn_baseline and use_baseline_value + ), "cannot use baseline_value and nn_baseline simultaneously" if use_decaying_avg_baseline: dc_shape = downstream_cost.shape param_name = "__baseline_avg_downstream_cost_" + node with torch.no_grad(): - avg_downstream_cost_old = pyro.param(param_name, - torch.zeros(dc_shape, device=guide_site['value'].device)) - avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \ - baseline_beta * avg_downstream_cost_old + avg_downstream_cost_old = pyro.param( + param_name, torch.zeros(dc_shape, device=guide_site["value"].device) + ) + avg_downstream_cost_new = ( + 1 - baseline_beta + ) * downstream_cost + baseline_beta * avg_downstream_cost_old pyro.get_param_store()[param_name] = avg_downstream_cost_new baseline += avg_downstream_cost_old if use_nn_baseline: @@ -76,22 +88,25 @@ def _construct_baseline(node, guide_site, downstream_cost): if use_baseline: if downstream_cost.shape != baseline.shape: - raise ValueError("Expected baseline at site {} to be {} instead got {}".format( - node, downstream_cost.shape, baseline.shape)) + raise ValueError( + "Expected baseline at site {} to be {} instead got {}".format( + node, downstream_cost.shape, baseline.shape + ) + ) return use_baseline, baseline_loss, baseline -def _compute_downstream_costs(model_trace, guide_trace, # - non_reparam_nodes): +def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes): # # recursively compute downstream cost nodes for all sample sites in model and guide # (even though ultimately just need for non-reparameterizable sample sites) # 1. downstream costs used for rao-blackwellization # 2. 
model observe sites (as well as terms that arise from the model and guide having different # dependency structures) are taken care of via 'children_in_model' below topo_sort_guide_nodes = guide_trace.topological_sort(reverse=True) - topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes - if guide_trace.nodes[x]["type"] == "sample"] + topo_sort_guide_nodes = [ + x for x in topo_sort_guide_nodes if guide_trace.nodes[x]["type"] == "sample" + ] ordered_guide_nodes_dict = {n: i for i, n in enumerate(topo_sort_guide_nodes)} downstream_guide_cost_nodes = {} @@ -99,13 +114,19 @@ def _compute_downstream_costs(model_trace, guide_trace, # stacks = get_plate_stacks(model_trace) for node in topo_sort_guide_nodes: - downstream_costs[node] = MultiFrameTensor((stacks[node], - model_trace.nodes[node]['log_prob'] - - guide_trace.nodes[node]['log_prob'])) + downstream_costs[node] = MultiFrameTensor( + ( + stacks[node], + model_trace.nodes[node]["log_prob"] + - guide_trace.nodes[node]["log_prob"], + ) + ) nodes_included_in_sum = set([node]) downstream_guide_cost_nodes[node] = set([node]) # make more efficient by ordering children appropriately (higher children first) - children = [(k, -ordered_guide_nodes_dict[k]) for k in guide_trace.successors(node)] + children = [ + (k, -ordered_guide_nodes_dict[k]) for k in guide_trace.successors(node) + ] sorted_children = sorted(children, key=itemgetter(1)) for child, _ in sorted_children: child_cost_nodes = downstream_guide_cost_nodes[child] @@ -115,12 +136,18 @@ def _compute_downstream_costs(model_trace, guide_trace, # # XXX nodes_included_in_sum logic could be more fine-grained, possibly leading # to speed-ups in case there are many duplicates nodes_included_in_sum.update(child_cost_nodes) - missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum + missing_downstream_costs = ( + downstream_guide_cost_nodes[node] - nodes_included_in_sum + ) # include terms we missed because we had to avoid duplicates for missing_node in missing_downstream_costs: - downstream_costs[node].add((stacks[missing_node], - model_trace.nodes[missing_node]['log_prob'] - - guide_trace.nodes[missing_node]['log_prob'])) + downstream_costs[node].add( + ( + stacks[missing_node], + model_trace.nodes[missing_node]["log_prob"] + - guide_trace.nodes[missing_node]["log_prob"], + ) + ) # finish assembling complete downstream costs # (the above computation may be missing terms from model) @@ -131,13 +158,16 @@ def _compute_downstream_costs(model_trace, guide_trace, # # remove terms accounted for above children_in_model.difference_update(downstream_guide_cost_nodes[site]) for child in children_in_model: - assert (model_trace.nodes[child]["type"] == "sample") - downstream_costs[site].add((stacks[child], - model_trace.nodes[child]['log_prob'])) + assert model_trace.nodes[child]["type"] == "sample" + downstream_costs[site].add( + (stacks[child], model_trace.nodes[child]["log_prob"]) + ) downstream_guide_cost_nodes[site].update([child]) for k in non_reparam_nodes: - downstream_costs[k] = downstream_costs[k].sum_to(guide_trace.nodes[k]["cond_indep_stack"]) + downstream_costs[k] = downstream_costs[k].sum_to( + guide_trace.nodes[k]["cond_indep_stack"] + ) return downstream_costs, downstream_guide_cost_nodes @@ -187,7 +217,9 @@ def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs): downstream_cost = downstream_costs[node] score_function = guide_site["score_parts"].score_function - use_baseline, baseline_loss_term, baseline = _construct_baseline(node, 
guide_site, downstream_cost) + use_baseline, baseline_loss_term, baseline = _construct_baseline( + node, guide_site, downstream_cost + ) if use_baseline: downstream_cost = downstream_cost - baseline @@ -227,7 +259,8 @@ def _get_trace(self, model, guide, args, kwargs): against it. """ model_trace, guide_trace = get_importance_trace( - "dense", self.max_plate_nesting, model, guide, args, kwargs) + "dense", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_if_enumerated(guide_trace) return model_trace, guide_trace @@ -241,7 +274,9 @@ def loss(self, model, guide, *args, **kwargs): """ elbo = 0.0 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): - elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(guide_trace.log_prob_sum()) + elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item( + guide_trace.log_prob_sum() + ) elbo += elbo_particle / float(self.num_particles) loss = -elbo @@ -291,10 +326,12 @@ def _loss_and_surrogate_loss_particle(self, model_trace, guide_trace): # the following computations are only necessary if we have non-reparameterizable nodes non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) if non_reparam_nodes: - downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes) - surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace, - non_reparam_nodes, - downstream_costs) + downstream_costs, _ = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) + surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam( + guide_trace, non_reparam_nodes, downstream_costs + ) surrogate_elbo += surrogate_elbo_term surrogate_loss = -surrogate_elbo + baseline_loss @@ -318,17 +355,18 @@ class JitTraceGraph_ELBO(TraceGraph_ELBO): """ def loss_and_grads(self, model, guide, *args, **kwargs): - kwargs['_pyro_model_id'] = id(model) - kwargs['_pyro_guide_id'] = id(guide) - if getattr(self, '_jit_loss_and_surrogate_loss', None) is None: + kwargs["_pyro_model_id"] = id(model) + kwargs["_pyro_guide_id"] = id(guide) + if getattr(self, "_jit_loss_and_surrogate_loss", None) is None: # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings, - jit_options=self.jit_options) + @pyro.ops.jit.trace( + ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options + ) def jit_loss_and_surrogate_loss(*args, **kwargs): - kwargs.pop('_pyro_model_id') - kwargs.pop('_pyro_guide_id') + kwargs.pop("_pyro_model_id") + kwargs.pop("_pyro_guide_id") self = weakself() return self._loss_and_surrogate_loss(model, guide, args, kwargs) @@ -336,7 +374,9 @@ def jit_loss_and_surrogate_loss(*args, **kwargs): loss, surrogate_loss = self._jit_loss_and_surrogate_loss(*args, **kwargs) - surrogate_loss.backward(retain_graph=self.retain_graph) # triggers jit compilation + surrogate_loss.backward( + retain_graph=self.retain_graph + ) # triggers jit compilation loss = loss.item() warn_if_nan(loss, "loss") diff --git a/pyro/infer/tracetmc_elbo.py b/pyro/infer/tracetmc_elbo.py index f76b8c50b3..c949dd9b68 100644 --- a/pyro/infer/tracetmc_elbo.py +++ b/pyro/infer/tracetmc_elbo.py @@ -59,10 +59,12 @@ def _compute_tmc_factors(model_trace, guide_trace): for name, site in model_trace.nodes.items(): if site["type"] != "sample": continue - if site["name"] not in guide_trace and \ - not site["is_observed"] and \ - site["infer"].get("enumerate", None) == "parallel" and \ - 
site["infer"].get("num_samples", -1) > 0: + if ( + site["name"] not in guide_trace + and not site["is_observed"] + and site["infer"].get("enumerate", None) == "parallel" + and site["infer"].get("num_samples", -1) > 0 + ): # site was sampled from the prior log_proposal = packed.neg(site["packed"]["log_prob"]) log_factors.append(log_proposal) @@ -80,15 +82,23 @@ def _compute_tmc_estimate(model_trace, guide_trace): log_factors += _compute_dice_factors(model_trace, guide_trace) if not log_factors: - return 0. + return 0.0 # loss eqn = ",".join([f._pyro_dims for f in log_factors]) + "->" - plates = "".join(frozenset().union(list(model_trace.plate_to_symbol.values()), - list(guide_trace.plate_to_symbol.values()))) - tmc, = einsum(eqn, *log_factors, plates=plates, - backend="pyro.ops.einsum.torch_log", - modulo_total=False) + plates = "".join( + frozenset().union( + list(model_trace.plate_to_symbol.values()), + list(guide_trace.plate_to_symbol.values()), + ) + ) + (tmc,) = einsum( + eqn, + *log_factors, + plates=plates, + backend="pyro.ops.einsum.torch_log", + modulo_total=False + ) return tmc @@ -126,21 +136,26 @@ def _get_trace(self, model, guide, args, kwargs): against it. """ model_trace, guide_trace = get_importance_trace( - "flat", self.max_plate_nesting, model, guide, args, kwargs) + "flat", self.max_plate_nesting, model, guide, args, kwargs + ) if is_validation_enabled(): check_traceenum_requirements(model_trace, guide_trace) - has_enumerated_sites = any(site["infer"].get("enumerate") - for trace in (guide_trace, model_trace) - for name, site in trace.nodes.items() - if site["type"] == "sample") + has_enumerated_sites = any( + site["infer"].get("enumerate") + for trace in (guide_trace, model_trace) + for name, site in trace.nodes.items() + if site["type"] == "sample" + ) if self.strict_enumeration_warning and not has_enumerated_sites: - warnings.warn('Found no sample sites configured for enumeration. ' - 'If you want to enumerate sites, you need to @config_enumerate or set ' - 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' - 'If you do not want to enumerate, consider using Trace_ELBO instead.') + warnings.warn( + "Found no sample sites configured for enumeration. " + "If you want to enumerate sites, you need to @config_enumerate or set " + 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' + "If you do not want to enumerate, consider using Trace_ELBO instead." + ) model_trace.compute_score_parts() guide_trace.pack_tensors() @@ -152,7 +167,7 @@ def _get_traces(self, model, guide, args, kwargs): Runs the guide and runs the model against the guide with the result packaged as a trace generator. 
""" - if self.max_plate_nesting == float('inf'): + if self.max_plate_nesting == float("inf"): self._guess_max_plate_nesting(model, guide, args, kwargs) if self.vectorize_particles: guide = self._vectorized_num_particles(guide) @@ -167,9 +182,9 @@ def _get_traces(self, model, guide, args, kwargs): model = model_enum(model) q = queue.LifoQueue() - guide = poutine.queue(guide, q, - escape_fn=iter_discrete_escape, - extend_fn=iter_discrete_extend) + guide = poutine.queue( + guide, q, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend + ) for i in range(1 if self.vectorize_particles else self.num_particles): q.put(poutine.Trace()) while not q.empty(): diff --git a/pyro/infer/util.py b/pyro/infer/util.py index de94b7d0a6..3ee94b884d 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -92,17 +92,20 @@ def get_plate_stacks(trace): an :class:`plate`. This information is used by :class:`Trace_ELBO` and :class:`TraceGraph_ELBO`. """ - return {name: [f for f in node["cond_indep_stack"] if f.vectorized] - for name, node in trace.nodes.items() - if node["type"] == "sample" and not site_is_subsample(node)} + return { + name: [f for f in node["cond_indep_stack"] if f.vectorized] + for name, node in trace.nodes.items() + if node["type"] == "sample" and not site_is_subsample(node) + } def get_dependent_plate_dims(sites): """ Return a list of unique dims for plates that are not common to all sites. """ - plate_sets = [site["cond_indep_stack"] - for site in sites if site["type"] == "sample"] + plate_sets = [ + site["cond_indep_stack"] for site in sites if site["type"] == "sample" + ] all_plates = set().union(*plate_sets) common_plates = all_plates.intersection(*plate_sets) sum_plates = all_plates - common_plates @@ -125,6 +128,7 @@ class MultiFrameTensor(dict): downstream_cost.add(*other_costs.items()) # add in bulk summed = downstream_cost.sum_to(target_site["cond_indep_stack"]) """ + def __init__(self, *items): super().__init__() self.add(*items) @@ -152,11 +156,13 @@ def sum_to(self, target_frames): while value.shape and value.shape[0] == 1: value = value.squeeze(0) total = value if total is None else total + value - return 0. if total is None else total + return 0.0 if total is None else total def __repr__(self): - return '%s(%s)' % (type(self).__name__, ",\n\t".join([ - '({}, ...)'.format(frames) for frames in self])) + return "%s(%s)" % ( + type(self).__name__, + ",\n\t".join(["({}, ...)".format(frames) for frames in self]), + ) def compute_site_dice_factor(site): @@ -212,8 +218,11 @@ class Dice: Ordinal values may be any type that is (1) ``<=`` comparable and (2) hashable; the canonical ordinal is a ``frozenset`` of site names. """ + def __init__(self, guide_trace, ordering): - log_denoms = defaultdict(float) # avoids double-counting when sequentially enumerating + log_denoms = defaultdict( + float + ) # avoids double-counting when sequentially enumerating log_probs = defaultdict(list) # accounts for upstream probabilties for name, site in guide_trace.nodes.items(): @@ -257,10 +266,12 @@ def compute_expectation(self, costs): # Share computation across all cost terms. with shared_intermediates() as cache: ring = MarginalRing(cache=cache) - expected_cost = 0. 
+ expected_cost = 0.0 for ordinal, cost_terms in costs.items(): log_factors = self._get_log_factors(ordinal) - scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor))) + scale = math.exp( + sum(x for x in log_factors if not isinstance(x, torch.Tensor)) + ) log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)] # Collect log_prob terms to query for marginal probability. @@ -285,7 +296,10 @@ def compute_expectation(self, costs): require_backward(query) root = ring.sumproduct(log_factors, sum_dims) root._pyro_backward() - probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()} + probs = { + key: query._pyro_backward_result.exp() + for key, query in queries.items() + } # Aggregate prob * cost terms. for cost in cost_terms: @@ -315,7 +329,12 @@ def _fulldot(x, y): def check_fully_reparametrized(guide_site): log_prob, score_function_term, entropy_term = guide_site["score_parts"] - fully_rep = (guide_site["fn"].has_rsample and not is_identically_zero(entropy_term) and - is_identically_zero(score_function_term)) + fully_rep = ( + guide_site["fn"].has_rsample + and not is_identically_zero(entropy_term) + and is_identically_zero(score_function_term) + ) if not fully_rep: - raise NotImplementedError("All distributions in the guide must be fully reparameterized.") + raise NotImplementedError( + "All distributions in the guide must be fully reparameterized." + ) diff --git a/pyro/logger.py b/pyro/logger.py index 64a8c70c46..e5d99748ab 100644 --- a/pyro/logger.py +++ b/pyro/logger.py @@ -3,7 +3,7 @@ import logging -default_format = '%(levelname)s \t %(message)s' +default_format = "%(levelname)s \t %(message)s" log = logging.getLogger("pyro") log.setLevel(logging.INFO) diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py index c33d9f8ec8..3ae06bd055 100644 --- a/pyro/nn/auto_reg_nn.py +++ b/pyro/nn/auto_reg_nn.py @@ -19,7 +19,9 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True): :param simple: True to space fractional indices by rounding to nearest int, false round randomly :type simple: bool """ - indices = torch.linspace(1, input_dim, steps=hidden_dim, device='cpu').to(torch.Tensor().device) + indices = torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to( + torch.Tensor().device + ) if simple: # Simple procedure tries to space fractional indices evenly by rounding to nearest int return torch.round(indices) @@ -30,7 +32,9 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True): return ints -def create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier): +def create_mask( + input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier +): """ Creates MADE masks for a conditional distribution @@ -64,15 +68,29 @@ def create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_mul output_indices = (var_index + 1).repeat(output_dim_multiplier) # Create mask from input to output for the skips connections - mask_skip = (output_indices.unsqueeze(-1) > input_indices.unsqueeze(0)).type_as(var_index) + mask_skip = (output_indices.unsqueeze(-1) > input_indices.unsqueeze(0)).type_as( + var_index + ) # Create mask from input to first hidden layer, and between subsequent hidden layers - masks = [(hidden_indices[0].unsqueeze(-1) >= input_indices.unsqueeze(0)).type_as(var_index)] + masks = [ + (hidden_indices[0].unsqueeze(-1) >= input_indices.unsqueeze(0)).type_as( + var_index + ) + ] for i in range(1, len(hidden_dims)): - masks.append((hidden_indices[i].unsqueeze(-1) 
>= hidden_indices[i - 1].unsqueeze(0)).type_as(var_index)) + masks.append( + ( + hidden_indices[i].unsqueeze(-1) >= hidden_indices[i - 1].unsqueeze(0) + ).type_as(var_index) + ) # Create mask from last hidden layer to output layer - masks.append((output_indices.unsqueeze(-1) > hidden_indices[-1].unsqueeze(0)).type_as(var_index)) + masks.append( + (output_indices.unsqueeze(-1) > hidden_indices[-1].unsqueeze(0)).type_as( + var_index + ) + ) return masks, mask_skip @@ -93,7 +111,7 @@ class MaskedLinear(nn.Linear): def __init__(self, in_features, out_features, mask, bias=True): super().__init__(in_features, out_features, bias) - self.register_buffer('mask', mask.data) + self.register_buffer("mask", mask.data) def forward(self, _input): masked_weight = self.weight * self.mask @@ -147,17 +165,20 @@ class ConditionalAutoRegressiveNN(nn.Module): """ def __init__( - self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU()): + self, + input_dim, + context_dim, + hidden_dims, + param_dims=[1, 1], + permutation=None, + skip_connections=False, + nonlinearity=nn.ReLU(), + ): super().__init__() if input_dim == 1: - warnings.warn('ConditionalAutoRegressiveNN input_dim = 1. Consider using an affine transformation instead.') + warnings.warn( + "ConditionalAutoRegressiveNN input_dim = 1. Consider using an affine transformation instead." + ) self.input_dim = input_dim self.context_dim = context_dim self.hidden_dims = hidden_dims @@ -175,36 +196,47 @@ def __init__( # possible to connect to the outputs correctly for h in hidden_dims: if h < input_dim: - raise ValueError('Hidden dimension must not be less than input dimension.') + raise ValueError( + "Hidden dimension must not be less than input dimension." 
+ ) if permutation is None: # By default set a random permutation of variables, which is important for performance with multiple steps - P = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device) + P = torch.randperm(input_dim, device="cpu").to(torch.Tensor().device) else: # The permutation is chosen by the user P = permutation.type(dtype=torch.int64) - self.register_buffer('permutation', P) + self.register_buffer("permutation", P) # Create masks self.masks, self.mask_skip = create_mask( - input_dim=input_dim, context_dim=context_dim, hidden_dims=hidden_dims, permutation=self.permutation, - output_dim_multiplier=self.output_multiplier) + input_dim=input_dim, + context_dim=context_dim, + hidden_dims=hidden_dims, + permutation=self.permutation, + output_dim_multiplier=self.output_multiplier, + ) # Create masked layers layers = [MaskedLinear(input_dim + context_dim, hidden_dims[0], self.masks[0])] for i in range(1, len(hidden_dims)): - layers.append(MaskedLinear(hidden_dims[i - 1], hidden_dims[i], self.masks[i])) - layers.append(MaskedLinear(hidden_dims[-1], input_dim * self.output_multiplier, self.masks[-1])) + layers.append( + MaskedLinear(hidden_dims[i - 1], hidden_dims[i], self.masks[i]) + ) + layers.append( + MaskedLinear( + hidden_dims[-1], input_dim * self.output_multiplier, self.masks[-1] + ) + ) self.layers = nn.ModuleList(layers) if skip_connections: self.skip_layer = MaskedLinear( - input_dim + - context_dim, - input_dim * - self.output_multiplier, + input_dim + context_dim, + input_dim * self.output_multiplier, self.mask_skip, - bias=False) + bias=False, + ) else: self.skip_layer = None @@ -239,7 +271,9 @@ def _forward(self, x): if self.output_multiplier == 1: return h else: - h = h.reshape(list(x.size()[:-1]) + [self.output_multiplier, self.input_dim]) + h = h.reshape( + list(x.size()[:-1]) + [self.output_multiplier, self.input_dim] + ) # Squeeze dimension if all parameters are one dimensional if self.count_params == 1: @@ -293,23 +327,23 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN): """ def __init__( - self, - input_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU()): - super( - AutoRegressiveNN, - self).__init__( + self, + input_dim, + hidden_dims, + param_dims=[1, 1], + permutation=None, + skip_connections=False, + nonlinearity=nn.ReLU(), + ): + super(AutoRegressiveNN, self).__init__( input_dim, 0, hidden_dims, param_dims=param_dims, permutation=permutation, skip_connections=skip_connections, - nonlinearity=nonlinearity) + nonlinearity=nonlinearity, + ) def forward(self, x): return self._forward(x) diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py index b8af6ef0f1..a7a9a7e645 100644 --- a/pyro/nn/dense_nn.py +++ b/pyro/nn/dense_nn.py @@ -34,12 +34,13 @@ class ConditionalDenseNN(torch.nn.Module): """ def __init__( - self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - nonlinearity=torch.nn.ReLU()): + self, + input_dim, + context_dim, + hidden_dims, + param_dims=[1, 1], + nonlinearity=torch.nn.ReLU(), + ): super().__init__() self.input_dim = input_dim @@ -55,7 +56,7 @@ def __init__( self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)] # Create masked layers - layers = [torch.nn.Linear(input_dim+context_dim, hidden_dims[0])] + layers = [torch.nn.Linear(input_dim + context_dim, hidden_dims[0])] for i in range(1, len(hidden_dims)): layers.append(torch.nn.Linear(hidden_dims[i - 1], hidden_dims[i])) layers.append(torch.nn.Linear(hidden_dims[-1], 
self.output_multiplier)) @@ -66,7 +67,7 @@ def __init__( def forward(self, x, context): # We must be able to broadcast the size of the context over the input - context = context.expand(x.size()[:-1]+(context.size(-1),)) + context = context.expand(x.size()[:-1] + (context.size(-1),)) x = torch.cat([context, x], dim=-1) return self._forward(x) @@ -121,17 +122,10 @@ class DenseNN(ConditionalDenseNN): """ def __init__( - self, - input_dim, - hidden_dims, - param_dims=[1, 1], - nonlinearity=torch.nn.ReLU()): + self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU() + ): super(DenseNN, self).__init__( - input_dim, - 0, - hidden_dims, - param_dims=param_dims, - nonlinearity=nonlinearity + input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity ) def forward(self, x): diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 10d4c779b0..ad11567537 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -68,6 +68,7 @@ def forward(self): subsampled. If unspecified, all dimensions will be considered event dims and no subsampling will be performed. """ + # Support use as a decorator. def __get__(self, obj, obj_type): assert issubclass(obj_type, PyroModule) @@ -120,12 +121,15 @@ def forward(self): :class:`PyroModule` instance ``self`` and returns a distribution object. """ + def __init__(self, prior): super().__init__() if not hasattr(prior, "sample"): # if not a distribution - assert 1 == sum(1 for p in inspect.signature(prior).parameters.values() - if p.default is inspect.Parameter.empty), \ - "prior should take the single argument 'self'" + assert 1 == sum( + 1 + for p in inspect.signature(prior).parameters.values() + if p.default is inspect.Parameter.empty + ), "prior should take the single argument 'self'" self.name = getattr(prior, "__name__", None) if self.name is not None: # Ensure decorated function is accessible for pickling. @@ -168,6 +172,7 @@ class _Context: """ Sometimes-active cache for ``PyroModule.__call__()`` contexts. """ + def __init__(self): self.active = 0 self.cache = {} @@ -194,8 +199,11 @@ def set(self, name, value): def _get_pyro_params(module): for name in module._parameters: if name.endswith("_unconstrained"): - constrained_name = name[:-len("_unconstrained")] - if isinstance(module, PyroModule) and constrained_name in module._pyro_params: + constrained_name = name[: -len("_unconstrained")] + if ( + isinstance(module, PyroModule) + and constrained_name in module._pyro_params + ): yield constrained_name, getattr(module, constrained_name) continue yield name, module._parameters[name] @@ -218,13 +226,14 @@ def __getitem__(cls, Module): return PyroModule if Module in _PyroModuleMeta._pyro_mixin_cache: return _PyroModuleMeta._pyro_mixin_cache[Module] - bases = [PyroModule[b] for b in Module.__bases__ - if issubclass(b, torch.nn.Module)] + bases = [ + PyroModule[b] for b in Module.__bases__ if issubclass(b, torch.nn.Module) + ] class result(Module, *bases): # Unpickling helper to load an object of type PyroModule[Module]. def __reduce__(self): - state = getattr(self, '__getstate__', self.__dict__.copy)() + state = getattr(self, "__getstate__", self.__dict__.copy)() return _PyroModuleMeta._New, (Module,), state result.__name__ = "Pyro" + Module.__name__ @@ -365,6 +374,7 @@ class PyroLinear(nn.Linear, PyroModule): :param str name: Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule. 
""" + def __init__(self, name=""): self._pyro_name = name self._pyro_context = _Context() # shared among sub-PyroModules @@ -377,10 +387,12 @@ def add_module(self, name, module): Adds a child module to the current module. """ if isinstance(module, PyroModule): - module._pyro_set_supermodule(_make_name(self._pyro_name, name), self._pyro_context) + module._pyro_set_supermodule( + _make_name(self._pyro_name, name), self._pyro_context + ) super().add_module(name, module) - def named_pyro_params(self, prefix='', recurse=True): + def named_pyro_params(self, prefix="", recurse=True): """ Returns an iterator over PyroModule parameters, yielding both the name of the parameter as well as the parameter itself. @@ -400,13 +412,14 @@ def _pyro_set_supermodule(self, name, context): self._pyro_context = context for key, value in self._modules.items(): if isinstance(value, PyroModule): - assert not value._pyro_context.used, \ - "submodule {} has executed outside of supermodule".format(name) + assert ( + not value._pyro_context.used + ), "submodule {} has executed outside of supermodule".format(name) value._pyro_set_supermodule(_make_name(name, key), context) def _pyro_get_fullname(self, name): - assert self.__dict__['_pyro_context'].used, "fullname is not yet defined" - return _make_name(self.__dict__['_pyro_name'], name) + assert self.__dict__["_pyro_context"].used, "fullname is not yet defined" + return _make_name(self.__dict__["_pyro_name"], name) def __call__(self, *args, **kwargs): with self._pyro_context: @@ -414,23 +427,34 @@ def __call__(self, *args, **kwargs): def __getattr__(self, name): # PyroParams trigger pyro.param statements. - if '_pyro_params' in self.__dict__: - _pyro_params = self.__dict__['_pyro_params'] + if "_pyro_params" in self.__dict__: + _pyro_params = self.__dict__["_pyro_params"] if name in _pyro_params: constraint, event_dim = _pyro_params[name] unconstrained_value = getattr(self, name + "_unconstrained") if self._pyro_context.active: fullname = self._pyro_get_fullname(name) if fullname in _PYRO_PARAM_STORE: - if _PYRO_PARAM_STORE._params[fullname] is not unconstrained_value: + if ( + _PYRO_PARAM_STORE._params[fullname] + is not unconstrained_value + ): # Update PyroModule <--- ParamStore. unconstrained_value = _PYRO_PARAM_STORE._params[fullname] if not isinstance(unconstrained_value, torch.nn.Parameter): # Update PyroModule ---> ParamStore (type only; data is preserved). - unconstrained_value = torch.nn.Parameter(unconstrained_value) - _PYRO_PARAM_STORE._params[fullname] = unconstrained_value - _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname - super().__setattr__(name + "_unconstrained", unconstrained_value) + unconstrained_value = torch.nn.Parameter( + unconstrained_value + ) + _PYRO_PARAM_STORE._params[ + fullname + ] = unconstrained_value + _PYRO_PARAM_STORE._param_to_name[ + unconstrained_value + ] = fullname + super().__setattr__( + name + "_unconstrained", unconstrained_value + ) else: # Update PyroModule ---> ParamStore. _PYRO_PARAM_STORE._constraints[fullname] = constraint @@ -441,8 +465,8 @@ def __getattr__(self, name): return transform_to(constraint)(unconstrained_value) # PyroSample trigger pyro.sample statements. 
- if '_pyro_samples' in self.__dict__: - _pyro_samples = self.__dict__['_pyro_samples'] + if "_pyro_samples" in self.__dict__: + _pyro_samples = self.__dict__["_pyro_samples"] if name in _pyro_samples: prior = _pyro_samples[name] context = self._pyro_context @@ -463,7 +487,9 @@ def __getattr__(self, name): result = super().__getattr__(name) # Regular nn.Parameters trigger pyro.param statements. - if isinstance(result, torch.nn.Parameter) and not name.endswith("_unconstrained"): + if isinstance(result, torch.nn.Parameter) and not name.endswith( + "_unconstrained" + ): if self._pyro_context.active: pyro.param(self._pyro_get_fullname(name), result) @@ -471,7 +497,9 @@ def __getattr__(self, name): if isinstance(result, PyroModule): if not result._pyro_name: # Update sub-PyroModules that were converted from nn.Modules in-place. - result._pyro_set_supermodule(_make_name(self._pyro_name, name), self._pyro_context) + result._pyro_set_supermodule( + _make_name(self._pyro_name, name), self._pyro_context + ) else: # Regular nn.Modules trigger pyro.module statements. if self._pyro_context.active: @@ -499,7 +527,12 @@ def __setattr__(self, name, value): self._pyro_params[name] = constraint, event_dim if self._pyro_context.active: fullname = self._pyro_get_fullname(name) - pyro.param(fullname, constrained_value, constraint=constraint, event_dim=event_dim) + pyro.param( + fullname, + constrained_value, + constraint=constraint, + event_dim=event_dim, + ) constrained_value = pyro.param(fullname) unconstrained_value = constrained_value.unconstrained() if not isinstance(unconstrained_value, torch.nn.Parameter): @@ -535,7 +568,9 @@ def __setattr__(self, name, value): constraint, event_dim = self._pyro_params[name] unconstrained_value = getattr(self, name + "_unconstrained") with torch.no_grad(): - unconstrained_value.data = transform_to(constraint).inv(value.detach()) + unconstrained_value.data = transform_to(constraint).inv( + value.detach() + ) return if isinstance(value, PyroSample): @@ -544,7 +579,7 @@ def __setattr__(self, name, value): delattr(self, name) except AttributeError: pass - _pyro_samples = self.__dict__['_pyro_samples'] + _pyro_samples = self.__dict__["_pyro_samples"] _pyro_samples[name] = value.prior return diff --git a/pyro/ops/arrowhead.py b/pyro/ops/arrowhead.py index 8bd14af6cc..e97c8872b3 100644 --- a/pyro/ops/arrowhead.py +++ b/pyro/ops/arrowhead.py @@ -32,7 +32,9 @@ def sqrt(x): num_attempts = 6 for i in range(num_attempts): B_Dsqrt = B / Dsqrt.unsqueeze(-2) # shape: head_size x N - schur_complement = A - B_Dsqrt.matmul(B_Dsqrt.t()) # complexity: head_size^2 x N + schur_complement = A - B_Dsqrt.matmul( + B_Dsqrt.t() + ) # complexity: head_size^2 x N # we will decompose schur_complement to U @ U.T (so that the sqrt matrix # is upper triangular) using some `flip` operators: # flip(cholesky(flip(schur_complement))) @@ -44,8 +46,10 @@ def sqrt(x): except RuntimeError: B = B / 2 continue - raise RuntimeError("Singular schur complement in computing Cholesky of the input" - " arrowhead matrix") + raise RuntimeError( + "Singular schur complement in computing Cholesky of the input" + " arrowhead matrix" + ) top_right = B_Dsqrt top = torch.cat([top_left, top_right], -1) diff --git a/pyro/ops/contract.py b/pyro/ops/contract.py index fcffe8ce59..d10f98a726 100644 --- a/pyro/ops/contract.py +++ b/pyro/ops/contract.py @@ -15,9 +15,12 @@ def _check_plates_are_sensible(output_dims, nonoutput_ordinal): if output_dims and nonoutput_ordinal: - raise ValueError(u"It is nonsensical to preserve a plated 
dim without preserving " - u"all of that dim's plates, but found '{}' without '{}'" - .format(output_dims, ','.join(nonoutput_ordinal))) + raise ValueError( + u"It is nonsensical to preserve a plated dim without preserving " + u"all of that dim's plates, but found '{}' without '{}'".format( + output_dims, ",".join(nonoutput_ordinal) + ) + ) def _check_tree_structure(parent, leaf): @@ -26,8 +29,10 @@ def _check_tree_structure(parent, leaf): "Expected tree-structured plate nesting, but found " "dependencies on independent plates [{}]. " "Try converting one of the vectorized plates to a sequential plate (but beware " - "exponential cost in the size of the sequence)" - .format(', '.join(getattr(f, 'name', str(f)) for f in leaf))) + "exponential cost in the size of the sequence)".format( + ", ".join(getattr(f, "name", str(f)) for f in leaf) + ) + ) def _partition_terms(ring, terms, dims): @@ -64,7 +69,9 @@ def _partition_terms(ring, terms, dims): # Split this connected component into tensors and dims. component_terms = [v for v in component if isinstance(v, torch.Tensor)] if component_terms: - component_dims = set(v for v in component if not isinstance(v, torch.Tensor)) + component_dims = set( + v for v in component if not isinstance(v, torch.Tensor) + ) components.append((component_terms, component_dims)) return components @@ -122,12 +129,16 @@ def _contract_component(ring, tensor_tree, sum_dims, target_dims): parent = leaf else: pending_dims = sum_dims.intersection(term._pyro_dims) - parent = frozenset.union(*(t for t, d in dims_tree.items() if d & pending_dims)) + parent = frozenset.union( + *(t for t, d in dims_tree.items() if d & pending_dims) + ) _check_tree_structure(parent, leaf) contract_frames = leaf - parent contract_dims = dims & local_dims if contract_dims: - term, local_term = ring.global_local(term, contract_dims, contract_frames) + term, local_term = ring.global_local( + term, contract_dims, contract_frames + ) local_terms.append(local_term) local_dims |= sum_dims.intersection(local_term._pyro_dims) local_ordinal |= leaf @@ -192,8 +203,9 @@ def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None): return contracted_tree -def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None, target_dims=None, - cache=None, ring=None): +def contract_to_tensor( + tensor_tree, sum_dims, target_ordinal=None, target_dims=None, cache=None, ring=None +): """ Contract out ``sum_dims`` in a tree of tensors, via message passing. This reduces all terms down to a single tensor in the plate @@ -244,8 +256,9 @@ def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None, target_dims=N # Contract this connected component down to a single tensor. ordinal, term = _contract_component(ring, component, dims, target_dims & dims) - _check_plates_are_sensible(target_dims.intersection(term._pyro_dims), - ordinal - target_ordinal) + _check_plates_are_sensible( + target_dims.intersection(term._pyro_dims), ordinal - target_ordinal + ) # Eliminate extra plate dims via product contractions. contract_frames = ordinal - target_ordinal @@ -342,28 +355,33 @@ def einsum(equation, *operands, **kwargs): the size of any input tensor. """ # Extract kwargs. 
- cache = kwargs.pop('cache', None) - plates = kwargs.pop('plates', '') - backend = kwargs.pop('backend', 'torch') - modulo_total = kwargs.pop('modulo_total', False) + cache = kwargs.pop("cache", None) + plates = kwargs.pop("plates", "") + backend = kwargs.pop("backend", "torch") + modulo_total = kwargs.pop("modulo_total", False) try: Ring = BACKEND_TO_RING[backend] except KeyError as e: - raise NotImplementedError('\n'.join( - ['Only the following pyro backends are currently implemented:'] + - list(BACKEND_TO_RING))) from e + raise NotImplementedError( + "\n".join( + ["Only the following pyro backends are currently implemented:"] + + list(BACKEND_TO_RING) + ) + ) from e # Parse generalized einsum equation. - if '.' in equation: - raise NotImplementedError('ubsersum does not yet support ellipsis notation') - inputs, outputs = equation.split('->') - inputs = inputs.split(',') - outputs = outputs.split(',') + if "." in equation: + raise NotImplementedError("ubsersum does not yet support ellipsis notation") + inputs, outputs = equation.split("->") + inputs = inputs.split(",") + outputs = outputs.split(",") assert len(inputs) == len(operands) assert all(isinstance(x, torch.Tensor) for x in operands) if not modulo_total and any(outputs): - raise NotImplementedError('Try setting modulo_total=True and ensuring that your use case ' - 'allows an arbitrary scale factor on each result plate.') + raise NotImplementedError( + "Try setting modulo_total=True and ensuring that your use case " + "allows an arbitrary scale factor on each result plate." + ) if len(operands) != len(set(operands)): operands = [x[...] for x in operands] # ensure tensors are unique @@ -374,8 +392,11 @@ def einsum(equation, *operands, **kwargs): for dim, size in zip(dims, map(int, term.shape)): old = dim_to_size.setdefault(dim, size) if old != size: - raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}" - .format(dim, size, old)) + raise ValueError( + u"Dimension size mismatch at dim '{}': {} vs {}".format( + dim, size, old + ) + ) # Construct a tensor tree shared by all outputs. tensor_tree = OrderedDict() @@ -392,10 +413,13 @@ def einsum(equation, *operands, **kwargs): ring = Ring(cache, dim_to_size=dim_to_size) for output in outputs: sum_dims = set(output).union(*inputs) - set(plates) - term = contract_to_tensor(tensor_tree, sum_dims, - target_ordinal=plates.intersection(output), - target_dims=sum_dims.intersection(output), - ring=ring) + term = contract_to_tensor( + tensor_tree, + sum_dims, + target_ordinal=plates.intersection(output), + target_dims=sum_dims.intersection(output), + ring=ring, + ) if term._pyro_dims != output: term = term.permute(*map(term._pyro_dims.index, output)) term._pyro_dims = output @@ -407,13 +431,16 @@ def ubersum(equation, *operands, **kwargs): """ Deprecated, use :func:`einsum` instead. 
""" - warnings.warn("'ubersum' is deprecated, use 'pyro.ops.contract.einsum' instead", - DeprecationWarning) - if 'batch_dims' in kwargs: - warnings.warn("'batch_dims' is deprecated, use 'plates' instead", - DeprecationWarning) - kwargs['plates'] = kwargs.pop('batch_dims') - kwargs.setdefault('backend', 'pyro.ops.einsum.torch_log') + warnings.warn( + "'ubersum' is deprecated, use 'pyro.ops.contract.einsum' instead", + DeprecationWarning, + ) + if "batch_dims" in kwargs: + warnings.warn( + "'batch_dims' is deprecated, use 'plates' instead", DeprecationWarning + ) + kwargs["plates"] = kwargs.pop("batch_dims") + kwargs.setdefault("backend", "pyro.ops.einsum.torch_log") return einsum(equation, *operands, **kwargs) @@ -430,8 +457,11 @@ class _DimUnroller: :param dict dim_to_ordinal: a mapping from contraction dim to the set of plates over which the contraction dim is plated. """ + def __init__(self, dim_to_ordinal): - self._plates = {d: tuple(sorted(ordinal)) for d, ordinal in dim_to_ordinal.items()} + self._plates = { + d: tuple(sorted(ordinal)) for d, ordinal in dim_to_ordinal.items() + } self._symbols = map(opt_einsum.get_symbol, itertools.count()) self._map = {} @@ -463,17 +493,19 @@ def naive_ubersum(equation, *operands, **kwargs): :func:`ubersum` does not raise ``NotImplementedError``. """ # Parse equation, without loss of generality assuming a single output. - inputs, outputs = equation.split('->') - outputs = outputs.split(',') + inputs, outputs = equation.split("->") + outputs = outputs.split(",") if len(outputs) > 1: - return tuple(naive_ubersum(inputs + '->' + output, *operands, **kwargs)[0] - for output in outputs) - output, = outputs - inputs = inputs.split(',') - backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log') + return tuple( + naive_ubersum(inputs + "->" + output, *operands, **kwargs)[0] + for output in outputs + ) + (output,) = outputs + inputs = inputs.split(",") + backend = kwargs.pop("backend", "pyro.ops.einsum.torch_log") # Split dims into plate dims, contraction dims, and dims to keep. - plates = set(kwargs.pop('plates', '')) + plates = set(kwargs.pop("plates", "")) if not plates: result = opt_einsum.contract(equation, *operands, backend=backend) return (result,) @@ -485,8 +517,11 @@ def naive_ubersum(equation, *operands, **kwargs): for dim, size in zip(input_, operand.shape): old = sizes.setdefault(dim, size) if old != size: - raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}" - .format(dim, size, old)) + raise ValueError( + u"Dimension size mismatch at dim '{}': {} vs {}".format( + dim, size, old + ) + ) # Compute plate context for each non-plate dim, by convention the # intersection over all plate contexts of tensors in which the dim appears. @@ -506,20 +541,33 @@ def naive_ubersum(equation, *operands, **kwargs): local_dims = [d for d in input_ if d in plates] offsets = [input_.index(d) - len(input_) for d in local_dims] for index in itertools.product(*(range(sizes[d]) for d in local_dims)): - flat_inputs.append(''.join(unroll_dim(d, dict(zip(local_dims, index))) - for d in input_ if d not in plates)) + flat_inputs.append( + "".join( + unroll_dim(d, dict(zip(local_dims, index))) + for d in input_ + if d not in plates + ) + ) flat_operands.append(_select(operand, offsets, index)) # Defer to unplated einsum. 
- result = torch.empty(torch.Size(sizes[d] for d in output), - dtype=operands[0].dtype, device=operands[0].device) + result = torch.empty( + torch.Size(sizes[d] for d in output), + dtype=operands[0].dtype, + device=operands[0].device, + ) local_dims = [d for d in output if d in plates] offsets = [output.index(d) - len(output) for d in local_dims] for index in itertools.product(*(range(sizes[d]) for d in local_dims)): - flat_output = ''.join(unroll_dim(d, dict(zip(local_dims, index))) - for d in output if d not in plates) - flat_equation = ','.join(flat_inputs) + '->' + flat_output - flat_result = opt_einsum.contract(flat_equation, *flat_operands, backend=backend) + flat_output = "".join( + unroll_dim(d, dict(zip(local_dims, index))) + for d in output + if d not in plates + ) + flat_equation = ",".join(flat_inputs) + "->" + flat_output + flat_result = opt_einsum.contract( + flat_equation, *flat_operands, backend=backend + ) if not local_dims: result = flat_result break diff --git a/pyro/ops/dual_averaging.py b/pyro/ops/dual_averaging.py index 390220f941..f5951ec1db 100644 --- a/pyro/ops/dual_averaging.py +++ b/pyro/ops/dual_averaging.py @@ -60,7 +60,9 @@ def step(self, g): """ self._t += 1 # g_avg = (g_1 + ... + g_t) / t - self._g_avg = (1 - 1/(self._t + self.t0)) * self._g_avg + g / (self._t + self.t0) + self._g_avg = (1 - 1 / (self._t + self.t0)) * self._g_avg + g / ( + self._t + self.t0 + ) # According to formula (3.4) of [1], we have # x_t = argmin{ g_avg . x + loc_t . |x - x0|^2 }, # where loc_t := beta_t / t, beta_t := (gamma/2) * sqrt(t) diff --git a/pyro/ops/einsum/__init__.py b/pyro/ops/einsum/__init__.py index 3de44e0b56..93f568c0e9 100644 --- a/pyro/ops/einsum/__init__.py +++ b/pyro/ops/einsum/__init__.py @@ -17,7 +17,7 @@ def contract_expression(equation, *shapes, **kwargs): Defaults to True. """ # memoize the contraction path - cache_path = kwargs.pop('cache_path', True) + cache_path = kwargs.pop("cache_path", True) if cache_path: kwargs_key = tuple(kwargs.items()) key = equation, shapes, kwargs_key @@ -38,12 +38,12 @@ def contract(equation, *operands, **kwargs): :param bool cache_path: whether to cache the contraction path. Defaults to True. 
""" - backend = kwargs.pop('backend', 'numpy') - out = kwargs.pop('out', None) + backend = kwargs.pop("backend", "numpy") + out = kwargs.pop("out", None) shapes = [tuple(t.shape) for t in operands] with ignore_jit_warnings(): expr = contract_expression(equation, *shapes) return expr(*operands, backend=backend, out=out) -__all__ = ['contract', 'contract_expression'] +__all__ = ["contract", "contract_expression"] diff --git a/pyro/ops/einsum/adjoint.py b/pyro/ops/einsum/adjoint.py index 36f2940465..b98bead03f 100644 --- a/pyro/ops/einsum/adjoint.py +++ b/pyro/ops/einsum/adjoint.py @@ -38,7 +38,7 @@ def __init__(self, target): def process(self, message): target = self.target() - assert message is not target, 'memory leak' + assert message is not target, "memory leak" target._pyro_backward_result = message return () @@ -68,9 +68,9 @@ def process(self, message): # this requires https://github.com/dgasmith/opt_einsum/pull/74 def transpose(a, axes): result = a.permute(axes) - if hasattr(a, '_pyro_backward'): + if hasattr(a, "_pyro_backward"): result._pyro_backward = _TransposeBackward(a, axes) - result._pyro_name = getattr(a, '_pyro_name', '?') + "'" + result._pyro_name = getattr(a, "_pyro_name", "?") + "'" return result @@ -105,7 +105,7 @@ def einsum_backward_sample(operands, sample1, sample2): # Select sample dimensions to pass on to downstream sites. for x in operands: - if not hasattr(x, '_pyro_backward'): + if not hasattr(x, "_pyro_backward"): continue if sample is None: yield x._pyro_backward, None @@ -117,9 +117,10 @@ def einsum_backward_sample(operands, sample1, sample2): if x_sample_dims == set(sample._pyro_sample_dims): yield x._pyro_backward, sample continue - x_sample_dims = ''.join(sorted(x_sample_dims)) - x_sample = sample[[sample._pyro_sample_dims.index(dim) - for dim in x_sample_dims]] + x_sample_dims = "".join(sorted(x_sample_dims)) + x_sample = sample[ + [sample._pyro_sample_dims.index(dim) for dim in x_sample_dims] + ] x_sample._pyro_dims = sample._pyro_dims x_sample._pyro_sample_dims = x_sample_dims assert x_sample.dim() == len(x_sample._pyro_dims) diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index bb8d53890d..01ff601433 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -17,14 +17,14 @@ def einsum(equation, *operands): """ # rename symbols to support PyTorch 0.4.1 and earlier, # which allow only symbols a-z. - symbols = sorted(set(equation) - set(',->')) - rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz')) - equation = ''.join(rename.get(s, s) for s in equation) + symbols = sorted(set(equation) - set(",->")) + rename = dict(zip(symbols, "abcdefghijklmnopqrstuvwxyz")) + equation = "".join(rename.get(s, s) for s in equation) - inputs, output = equation.split('->') + inputs, output = equation.split("->") if inputs == output: return operands[0][...] 
# create a new object - inputs = inputs.split(',') + inputs = inputs.split(",") shifts = [] exp_operands = [] @@ -38,8 +38,9 @@ def einsum(equation, *operands): exp_operands.append((operand - shift).exp()) # permute shift to match output - shift = shift.reshape(torch.Size(size for size, dim in zip(operand.shape, dims) - if dim in output)) + shift = shift.reshape( + torch.Size(size for size, dim in zip(operand.shape, dims) if dim in output) + ) if shift.dim(): shift = shift.reshape((1,) * (len(output) - shift.dim()) + shift.shape) dims = [dim for dim in dims if dim in output] diff --git a/pyro/ops/einsum/torch_map.py b/pyro/ops/einsum/torch_map.py index 5eef63aa82..cf47d70a6b 100644 --- a/pyro/ops/einsum/torch_map.py +++ b/pyro/ops/einsum/torch_map.py @@ -31,16 +31,18 @@ def einsum(equation, *operands): This assumes all operands have a ``._pyro_dims`` attribute set. """ equation = packed.rename_equation(equation, *operands) - inputs, output = equation.split('->') - any_requires_backward = any(hasattr(x, '_pyro_backward') for x in operands) + inputs, output = equation.split("->") + any_requires_backward = any(hasattr(x, "_pyro_backward") for x in operands) - contract_dims = ''.join(sorted(set().union(*(x._pyro_dims for x in operands)) - set(output))) + contract_dims = "".join( + sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)) + ) dims = output + contract_dims result = reduce(operator.add, packed.broadcast_all(*operands, dims=dims)) argmax = None # work around lack of pytorch support for zero-sized tensors if contract_dims: - output_shape = result.shape[:len(output)] - contract_shape = result.shape[len(output):] + output_shape = result.shape[: len(output)] + contract_shape = result.shape[len(output) :] result, argmax = result.reshape(output_shape + (-1,)).max(-1) if any_requires_backward: argmax = unflatten(argmax, output, contract_dims, contract_shape) diff --git a/pyro/ops/einsum/torch_marginal.py b/pyro/ops/einsum/torch_marginal.py index 25be1e3234..b197afbdd2 100644 --- a/pyro/ops/einsum/torch_marginal.py +++ b/pyro/ops/einsum/torch_marginal.py @@ -14,8 +14,8 @@ def __init__(self, equation, operands): def process(self, message): # Create extended lists of inputs and operands. 
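# --- illustrative sketch, not part of the diff hunks above: the log-space einsum in
# torch_log.py shifts each operand by a max before exponentiating, so its result
# equals the log of the ordinary einsum of the exponentiated operands. Tensor names
# and the tolerance below are assumptions.
import torch
from pyro.ops.einsum.torch_log import einsum as log_einsum

x, y = torch.randn(3, 4), torch.randn(4, 5)
lhs = log_einsum("ab,bc->ac", x, y)
rhs = (x.exp() @ y.exp()).log()
assert torch.allclose(lhs, rhs, atol=1e-5)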
operands = list(self.operands) - inputs, output = self.equation.split('->') - inputs = inputs.split(',') + inputs, output = self.equation.split("->") + inputs = inputs.split(",") if message is not None: assert message.dim() == len(output) inputs.append(output) @@ -32,9 +32,9 @@ def process(self, message): del inputs_i[i] del operands_i[i] if operands_i: - inputs_i = ','.join(inputs_i) - output_i = ''.join(dim for dim in output_i if dim in inputs_i) - equation = inputs_i + '->' + output_i + inputs_i = ",".join(inputs_i) + output_i = "".join(dim for dim in output_i if dim in inputs_i) + equation = inputs_i + "->" + output_i message_i = pyro.ops.einsum.torch_log.einsum(equation, *operands_i) if output_i != inputs[i]: for pos, dim in enumerate(inputs[i]): @@ -52,7 +52,7 @@ def einsum(equation, *operands): """ result = pyro.ops.einsum.torch_log.einsum(equation, *operands) - if any(hasattr(x, '_pyro_backward') for x in operands): + if any(hasattr(x, "_pyro_backward") for x in operands): result._pyro_backward = _EinsumBackward(equation, operands) return result diff --git a/pyro/ops/einsum/torch_sample.py b/pyro/ops/einsum/torch_sample.py index 9c2b66d926..162553a321 100644 --- a/pyro/ops/einsum/torch_sample.py +++ b/pyro/ops/einsum/torch_sample.py @@ -25,14 +25,16 @@ def __init__(self, output, operands): def process(self, message): output = self.output operands = list(self.operands) - contract_dims = ''.join(sorted(set().union(*(x._pyro_dims for x in operands)) - set(output))) + contract_dims = "".join( + sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)) + ) batch_dims = output # Slice down operands before combining terms. sample2 = message if sample2 is not None: for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)): - batch_dims = batch_dims.replace(dim, '') + batch_dims = batch_dims.replace(dim, "") for i, x in enumerate(operands): if dim in x._pyro_dims: index._pyro_dims = sample2._pyro_dims[1:] @@ -46,8 +48,8 @@ def process(self, message): # Sample. sample1 = None # work around lack of pytorch support for zero-sized tensors if contract_dims: - output_shape = logits.shape[:len(batch_dims)] - contract_shape = logits.shape[len(batch_dims):] + output_shape = logits.shape[: len(batch_dims)] + contract_shape = logits.shape[len(batch_dims) :] flat_logits = logits.reshape(output_shape + (-1,)) flat_sample = dist.Categorical(logits=flat_logits).sample() sample1 = unflatten(flat_sample, batch_dims, contract_dims, contract_shape) @@ -62,12 +64,12 @@ def einsum(equation, *operands): This assumes all operands have a ``._pyro_dims`` attribute set. """ equation = packed.rename_equation(equation, *operands) - inputs, output = equation.split('->') + inputs, output = equation.split("->") result = pyro.ops.einsum.torch_log.einsum(equation, *operands) result._pyro_dims = output assert result.dim() == len(result._pyro_dims) - if any(hasattr(x, '_pyro_backward') for x in operands): + if any(hasattr(x, "_pyro_backward") for x in operands): result._pyro_backward = _EinsumBackward(output, operands) return result diff --git a/pyro/ops/einsum/util.py b/pyro/ops/einsum/util.py index 6bb0cdb21d..ee598d87a1 100644 --- a/pyro/ops/einsum/util.py +++ b/pyro/ops/einsum/util.py @@ -1,13 +1,14 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 AND MIT -EINSUM_SYMBOLS_BASE = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' +EINSUM_SYMBOLS_BASE = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" class Tensordot: """ Creates a tensordot implementation from an einsum implementation. """ + def __init__(self, einsum): self.einsum = einsum diff --git a/pyro/ops/gamma_gaussian.py b/pyro/ops/gamma_gaussian.py index c4d7c6a6ac..66807aa577 100644 --- a/pyro/ops/gamma_gaussian.py +++ b/pyro/ops/gamma_gaussian.py @@ -19,6 +19,7 @@ class Gamma: Gamma(concentration, rate) ~ (concentration - 1) * log(s) - rate * s """ + def __init__(self, log_normalizer, concentration, rate): self.log_normalizer = log_normalizer self.concentration = concentration @@ -36,8 +37,11 @@ def logsumexp(self): """ Integrates out the latent variable. """ - return self.log_normalizer + torch.lgamma(self.concentration) - \ - self.concentration * self.rate.log() + return ( + self.log_normalizer + + torch.lgamma(self.concentration) + - self.concentration * self.rate.log() + ) class GammaGaussian: @@ -77,6 +81,7 @@ class GammaGaussian: beta = Gamma.rate + 0.5 * info_vec.T @ inv(precision) @ info_vec """ + def __init__(self, log_normalizer, info_vec, precision, alpha, beta): # NB: using info_vec instead of mean to deal with rank-deficient problem assert info_vec.dim() >= 1 @@ -93,11 +98,13 @@ def dim(self): @lazy_property def batch_shape(self): - return broadcast_shape(self.log_normalizer.shape, - self.info_vec.shape[:-1], - self.precision.shape[:-2], - self.alpha.shape, - self.beta.shape) + return broadcast_shape( + self.log_normalizer.shape, + self.info_vec.shape[:-1], + self.precision.shape[:-2], + self.alpha.shape, + self.beta.shape, + ) def expand(self, batch_shape): n = self.dim() @@ -136,8 +143,10 @@ def cat(parts, dim=0): """ if dim < 0: dim += len(parts[0].batch_shape) - args = [torch.cat([getattr(g, attr) for g in parts], dim=dim) - for attr in ["log_normalizer", "info_vec", "precision", "alpha", "beta"]] + args = [ + torch.cat([getattr(g, attr) for g in parts], dim=dim) + for attr in ["log_normalizer", "info_vec", "precision", "alpha", "beta"] + ] return GammaGaussian(*args) def event_pad(self, left=0, right=0): @@ -151,7 +160,9 @@ def event_pad(self, left=0, right=0): # otherwise, we need to change alpha (similar for beta) to # keep the term (alpha + 0.5 * dim - 1) * log(s) constant # (note that `dim` has been changed due to padding) - return GammaGaussian(self.log_normalizer, info_vec, precision, self.alpha, self.beta) + return GammaGaussian( + self.log_normalizer, info_vec, precision, self.alpha, self.beta + ) def event_permute(self, perm): """ @@ -161,7 +172,9 @@ def event_permute(self, perm): assert perm.shape == (self.dim(),) info_vec = self.info_vec[..., perm] precision = self.precision[..., perm][..., perm, :] - return GammaGaussian(self.log_normalizer, info_vec, precision, self.alpha, self.beta) + return GammaGaussian( + self.log_normalizer, info_vec, precision, self.alpha, self.beta + ) def __add__(self, other): """ @@ -169,11 +182,13 @@ def __add__(self, other): """ assert isinstance(other, GammaGaussian) assert self.dim() == other.dim() - return GammaGaussian(self.log_normalizer + other.log_normalizer, - self.info_vec + other.info_vec, - self.precision + other.precision, - self.alpha + other.alpha, - self.beta + other.beta) + return GammaGaussian( + self.log_normalizer + other.log_normalizer, + self.info_vec + other.info_vec, + self.precision + other.precision, + self.alpha + other.alpha, + self.beta + 
other.beta, + ) def log_density(self, value, s): """ @@ -185,7 +200,11 @@ def log_density(self, value, s): """ if value.size(-1) == 0: batch_shape = broadcast_shape(value.shape[:-1], s.shape, self.batch_shape) - return self.alpha * s.log() - self.beta * s + self.log_normalizer.expand(batch_shape) + return ( + self.alpha * s.log() + - self.beta * s + + self.log_normalizer.expand(batch_shape) + ) result = (-0.5) * self.precision.matmul(value.unsqueeze(-1)).squeeze(-1) result = result + self.info_vec result = (value * result).sum(-1) @@ -222,7 +241,11 @@ def condition(self, value): log_normalizer = self.log_normalizer alpha = self.alpha - beta = self.beta + 0.5 * P_bb.matmul(b.unsqueeze(-1)).squeeze(-1).mul(b).sum(-1) - b.mul(info_b).sum(-1) + beta = ( + self.beta + + 0.5 * P_bb.matmul(b.unsqueeze(-1)).squeeze(-1).mul(b).sum(-1) + - b.mul(info_b).sum(-1) + ) return GammaGaussian(log_normalizer, info_vec, precision, alpha, beta) def marginalize(self, left=0, right=0): @@ -265,9 +288,11 @@ def marginalize(self, left=0, right=0): alpha = self.alpha - 0.5 * n_b beta = self.beta - 0.5 * b_tmp.squeeze(-1).pow(2).sum(-1) - log_normalizer = (self.log_normalizer + - 0.5 * n_b * math.log(2 * math.pi) - - P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1)) + log_normalizer = ( + self.log_normalizer + + 0.5 * n_b * math.log(2 * math.pi) + - P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) return GammaGaussian(log_normalizer, info_vec, precision, alpha, beta) def compound(self): @@ -277,12 +302,16 @@ def compound(self): """ concentration = self.alpha - 0.5 * self.dim() + 1 scale_tril = precision_to_scale_tril(self.precision) - scale_tril_t_u = scale_tril.transpose(-1, -2).matmul(self.info_vec.unsqueeze(-1)).squeeze(-1) + scale_tril_t_u = ( + scale_tril.transpose(-1, -2).matmul(self.info_vec.unsqueeze(-1)).squeeze(-1) + ) u_Pinv_u = scale_tril_t_u.pow(2).sum(-1) rate = self.beta - 0.5 * u_Pinv_u loc = scale_tril.matmul(scale_tril_t_u.unsqueeze(-1)).squeeze(-1) - scale_tril = scale_tril * (rate / concentration).sqrt().unsqueeze(-1).unsqueeze(-1) + scale_tril = scale_tril * (rate / concentration).sqrt().unsqueeze(-1).unsqueeze( + -1 + ) return MultivariateStudentT(2 * concentration, loc, scale_tril) def event_logsumexp(self): @@ -291,7 +320,11 @@ def event_logsumexp(self): """ n = self.dim() chol_P = torch.linalg.cholesky(self.precision) - chol_P_u = self.info_vec.unsqueeze(-1).triangular_solve(chol_P, upper=False).solution.squeeze(-1) + chol_P_u = ( + self.info_vec.unsqueeze(-1) + .triangular_solve(chol_P, upper=False) + .solution.squeeze(-1) + ) u_P_u = chol_P_u.pow(2).sum(-1) # considering GammaGaussian as a Gaussian with precision = s * precision, info_vec = s * info_vec, # marginalize x variable, we get @@ -303,7 +336,9 @@ def event_logsumexp(self): # Gamma(concentration, rate) concentration = self.alpha - 0.5 * n + 1 rate = self.beta - 0.5 * u_P_u - log_normalizer_tmp = 0.5 * n * math.log(2 * math.pi) - chol_P.diagonal(dim1=-2, dim2=-1).log().sum(-1) + log_normalizer_tmp = 0.5 * n * math.log(2 * math.pi) - chol_P.diagonal( + dim1=-2, dim2=-1 + ).log().sum(-1) return Gamma(self.log_normalizer + log_normalizer_tmp, concentration, rate) @@ -330,9 +365,12 @@ def gamma_and_mvn_to_gamma_gaussian(gamma, mvn): # reparameterized version of concentration, rate in GaussianGamma alpha = gamma.concentration + (0.5 * n - 1) beta = gamma.rate + 0.5 * (info_vec * mvn.loc).sum(-1) - gaussian_logsumexp = 0.5 * n * math.log(2 * math.pi) + \ - mvn.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) - log_normalizer = 
-Gamma(gaussian_logsumexp, gamma.concentration, gamma.rate).logsumexp() + gaussian_logsumexp = 0.5 * n * math.log(2 * math.pi) + mvn.scale_tril.diagonal( + dim1=-2, dim2=-1 + ).log().sum(-1) + log_normalizer = -Gamma( + gaussian_logsumexp, gamma.concentration, gamma.rate + ).logsumexp() return GammaGaussian(log_normalizer, info_vec, precision, alpha, beta) @@ -377,12 +415,15 @@ def matrix_and_mvn_to_gamma_gaussian(matrix, mvn): P_xy = -neg_P_xy P_yx = P_xy.transpose(-1, -2) P_xx = neg_P_xy.matmul(matrix.transpose(-1, -2)) - precision = torch.cat([torch.cat([P_xx, P_xy], -1), - torch.cat([P_yx, P_yy], -1)], -2) + precision = torch.cat( + [torch.cat([P_xx, P_xy], -1), torch.cat([P_yx, P_yy], -1)], -2 + ) info_y = P_yy.matmul(mvn.loc.unsqueeze(-1)).squeeze(-1) info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1) info_vec = torch.cat([info_x, info_y], -1) - log_normalizer = -0.5 * y_dim * math.log(2 * math.pi) - mvn.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + log_normalizer = -0.5 * y_dim * math.log(2 * math.pi) - mvn.scale_tril.diagonal( + dim1=-2, dim2=-1 + ).log().sum(-1) beta = 0.5 * (info_y * mvn.loc).sum(-1) alpha = beta.new_full(beta.shape, 0.5 * y_dim) @@ -415,8 +456,15 @@ def gamma_gaussian_tensordot(x, y, dims=0): assert nc >= 0 device = x.info_vec.device - perm = torch.cat([ - torch.arange(na, device=device), - torch.arange(x.dim(), x.dim() + nc, device=device), - torch.arange(na, x.dim(), device=device)]) - return (x.event_pad(right=nc) + y.event_pad(left=na)).event_permute(perm).marginalize(right=nb) + perm = torch.cat( + [ + torch.arange(na, device=device), + torch.arange(x.dim(), x.dim() + nc, device=device), + torch.arange(na, x.dim(), device=device), + ] + ) + return ( + (x.event_pad(right=nc) + y.event_pad(left=na)) + .event_permute(perm) + .marginalize(right=nb) + ) diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py index 34998298e5..3712aa8af5 100644 --- a/pyro/ops/gaussian.py +++ b/pyro/ops/gaussian.py @@ -27,6 +27,7 @@ class Gaussian: fast and stable. :param torch.Tensor precision: precision matrix of this gaussian. 
""" + def __init__(self, log_normalizer, info_vec, precision): # NB: using info_vec instead of mean to deal with rank-deficient problem assert info_vec.dim() >= 1 @@ -41,9 +42,11 @@ def dim(self): @lazy_property def batch_shape(self): - return broadcast_shape(self.log_normalizer.shape, - self.info_vec.shape[:-1], - self.precision.shape[:-2]) + return broadcast_shape( + self.log_normalizer.shape, + self.info_vec.shape[:-1], + self.precision.shape[:-2], + ) def expand(self, batch_shape): n = self.dim() @@ -76,8 +79,10 @@ def cat(parts, dim=0): """ if dim < 0: dim += len(parts[0].batch_shape) - args = [torch.cat([getattr(g, attr) for g in parts], dim=dim) - for attr in ["log_normalizer", "info_vec", "precision"]] + args = [ + torch.cat([getattr(g, attr) for g in parts], dim=dim) + for attr in ["log_normalizer", "info_vec", "precision"] + ] return Gaussian(*args) def event_pad(self, left=0, right=0): @@ -106,9 +111,11 @@ def __add__(self, other): """ if isinstance(other, Gaussian): assert self.dim() == other.dim() - return Gaussian(self.log_normalizer + other.log_normalizer, - self.info_vec + other.info_vec, - self.precision + other.precision) + return Gaussian( + self.log_normalizer + other.log_normalizer, + self.info_vec + other.info_vec, + self.precision + other.precision, + ) if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer + other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) @@ -175,9 +182,11 @@ def condition(self, value): info_vec = info_a - matvecmul(P_ab, b) precision = P_aa - log_normalizer = (self.log_normalizer + - -0.5 * matvecmul(P_bb, b).mul(b).sum(-1) + - b.mul(info_b).sum(-1)) + log_normalizer = ( + self.log_normalizer + + -0.5 * matvecmul(P_bb, b).mul(b).sum(-1) + + b.mul(info_b).sum(-1) + ) return Gaussian(log_normalizer, info_vec, precision) def left_condition(self, value): @@ -200,8 +209,12 @@ def left_condition(self, value): dim = self.dim() assert left <= dim - perm = torch.cat([torch.arange(left, dim, device=value.device), - torch.arange(left, device=value.device)]) + perm = torch.cat( + [ + torch.arange(left, dim, device=value.device), + torch.arange(left, device=value.device), + ] + ) return self.event_permute(perm).condition(value) def marginalize(self, left=0, right=0): @@ -238,10 +251,12 @@ def marginalize(self, left=0, right=0): b_tmp = triangular_solve(info_b.unsqueeze(-1), P_b, upper=False) info_vec = info_a - matmul(P_at, b_tmp).squeeze(-1) - log_normalizer = (self.log_normalizer + - 0.5 * n_b * math.log(2 * math.pi) - - P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1) + - 0.5 * b_tmp.squeeze(-1).pow(2).sum(-1)) + log_normalizer = ( + self.log_normalizer + + 0.5 * n_b * math.log(2 * math.pi) + - P_b.diagonal(dim1=-2, dim2=-1).log().sum(-1) + + 0.5 * b_tmp.squeeze(-1).pow(2).sum(-1) + ) return Gaussian(log_normalizer, info_vec, precision) def event_logsumexp(self): @@ -250,10 +265,16 @@ def event_logsumexp(self): """ n = self.dim() chol_P = cholesky(self.precision) - chol_P_u = triangular_solve(self.info_vec.unsqueeze(-1), chol_P, upper=False).squeeze(-1) + chol_P_u = triangular_solve( + self.info_vec.unsqueeze(-1), chol_P, upper=False + ).squeeze(-1) u_P_u = chol_P_u.pow(2).sum(-1) - return (self.log_normalizer + 0.5 * n * math.log(2 * math.pi) + 0.5 * u_P_u - - chol_P.diagonal(dim1=-2, dim2=-1).log().sum(-1)) + return ( + self.log_normalizer + + 0.5 * n * math.log(2 * math.pi) + + 0.5 * u_P_u + - chol_P.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) class AffineNormal: @@ -275,6 
+296,7 @@ class AffineNormal: :param torch.Tensor scale: Standard deviation for ``Y``. Should have rightmost shape ``(y_dim,)``. """ + def __init__(self, matrix, loc, scale): assert loc.shape == scale.shape assert matrix.shape[:-2] == loc.shape[:-1] @@ -294,8 +316,11 @@ def condition(self, value): precision = matmul(prec_sqrt, prec_sqrt.transpose(-1, -2)) delta = (value - self.loc) / self.scale info_vec = matvecmul(prec_sqrt, delta) - log_normalizer = (-0.5 * self.loc.size(-1) * math.log(2 * math.pi) - - 0.5 * delta.pow(2).sum(-1) - self.scale.log().sum(-1)) + log_normalizer = ( + -0.5 * self.loc.size(-1) * math.log(2 * math.pi) + - 0.5 * delta.pow(2).sum(-1) + - self.scale.log().sum(-1) + ) return Gaussian(log_normalizer, info_vec, precision) else: return self.to_gaussian().condition(value) @@ -327,7 +352,8 @@ def rsample(self, sample_shape=torch.Size()): def to_gaussian(self): if self._gaussian is None: mvn = torch.distributions.Independent( - torch.distributions.Normal(self.loc, scale=self.scale), 1) + torch.distributions.Normal(self.loc, scale=self.scale), 1 + ) y_gaussian = mvn_to_gaussian(mvn) self._gaussian = _matrix_and_gaussian_to_gaussian(self.matrix, y_gaussian) return self._gaussian @@ -376,9 +402,10 @@ def mvn_to_gaussian(mvn): :return: An equivalent Gaussian object. :rtype: ~pyro.ops.gaussian.Gaussian """ - assert (isinstance(mvn, torch.distributions.MultivariateNormal) or - (isinstance(mvn, torch.distributions.Independent) and - isinstance(mvn.base_dist, torch.distributions.Normal))) + assert isinstance(mvn, torch.distributions.MultivariateNormal) or ( + isinstance(mvn, torch.distributions.Independent) + and isinstance(mvn.base_dist, torch.distributions.Normal) + ) if isinstance(mvn, torch.distributions.Independent): mvn = mvn.base_dist precision_diag = mvn.scale.pow(-2) @@ -391,9 +418,11 @@ def mvn_to_gaussian(mvn): scale_diag = mvn.scale_tril.diagonal(dim1=-2, dim2=-1) n = mvn.loc.size(-1) - log_normalizer = (-0.5 * n * math.log(2 * math.pi) + - -0.5 * (info_vec * mvn.loc).sum(-1) - - scale_diag.log().sum(-1)) + log_normalizer = ( + -0.5 * n * math.log(2 * math.pi) + + -0.5 * (info_vec * mvn.loc).sum(-1) + - scale_diag.log().sum(-1) + ) return Gaussian(log_normalizer, info_vec, precision) @@ -403,8 +432,9 @@ def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian): P_xy = -neg_P_xy P_yx = P_xy.transpose(-1, -2) P_xx = matmul(neg_P_xy, matrix.transpose(-1, -2)) - precision = torch.cat([torch.cat([P_xx, P_xy], -1), - torch.cat([P_yx, P_yy], -1)], -2) + precision = torch.cat( + [torch.cat([P_xx, P_xy], -1), torch.cat([P_yx, P_yy], -1)], -2 + ) info_y = y_gaussian.info_vec info_x = -matvecmul(matrix, info_y) info_vec = torch.cat([info_x, info_y], -1) @@ -425,9 +455,10 @@ def matrix_and_mvn_to_gaussian(matrix, mvn): :return: A Gaussian with broadcasted batch shape and ``.dim() == x_dim + y_dim``. 
:rtype: ~pyro.ops.gaussian.Gaussian """ - assert (isinstance(mvn, torch.distributions.MultivariateNormal) or - (isinstance(mvn, torch.distributions.Independent) and - isinstance(mvn.base_dist, torch.distributions.Normal))) + assert isinstance(mvn, torch.distributions.MultivariateNormal) or ( + isinstance(mvn, torch.distributions.Independent) + and isinstance(mvn.base_dist, torch.distributions.Normal) + ) assert isinstance(matrix, torch.Tensor) x_dim, y_dim = matrix.shape[-2:] assert mvn.event_shape == (y_dim,) @@ -468,8 +499,16 @@ def gaussian_tensordot(x, y, dims=0): assert nb >= 0 assert nc >= 0 - Paa, Pba, Pbb = x.precision[..., :na, :na], x.precision[..., na:, :na], x.precision[..., na:, na:] - Qbb, Qbc, Qcc = y.precision[..., :nb, :nb], y.precision[..., :nb, nb:], y.precision[..., nb:, nb:] + Paa, Pba, Pbb = ( + x.precision[..., :na, :na], + x.precision[..., na:, :na], + x.precision[..., na:, na:], + ) + Qbb, Qbc, Qcc = ( + y.precision[..., :nb, :nb], + y.precision[..., :nb, nb:], + y.precision[..., nb:, nb:], + ) xa, xb = x.info_vec[..., :na], x.info_vec[..., na:] # x.precision @ x.mean yb, yc = y.info_vec[..., :nb], y.info_vec[..., nb:] # y.precision @ y.mean @@ -491,7 +530,11 @@ def gaussian_tensordot(x, y, dims=0): if na + nc > 0: info_vec = info_vec - matmul(LinvBt, Linvb).squeeze(-1) logdet = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1) - diff = 0.5 * nb * math.log(2 * math.pi) + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1) - logdet + diff = ( + 0.5 * nb * math.log(2 * math.pi) + + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1) + - logdet + ) log_normalizer = log_normalizer + diff return Gaussian(log_normalizer, info_vec, precision) diff --git a/pyro/ops/hessian.py b/pyro/ops/hessian.py index a96aa1b2cb..4f166e6c07 100644 --- a/pyro/ops/hessian.py +++ b/pyro/ops/hessian.py @@ -14,7 +14,9 @@ def hessian(y, xs): flat_dy = torch.cat([dy.reshape(-1) for dy in dys]) H = [] for dyi in flat_dy: - Hi = torch.cat([Hij.reshape(-1) for Hij in torch.autograd.grad(dyi, xs, retain_graph=True)]) + Hi = torch.cat( + [Hij.reshape(-1) for Hij in torch.autograd.grad(dyi, xs, retain_graph=True)] + ) H.append(Hi) H = torch.stack(H) return H diff --git a/pyro/ops/indexing.py b/pyro/ops/indexing.py index 7fc3c82e29..2fc57aa9f8 100644 --- a/pyro/ops/indexing.py +++ b/pyro/ops/indexing.py @@ -71,6 +71,7 @@ class Index: :param torch.Tensor tensor: A tensor to be indexed. :return: An object with a special :meth:`__getitem__` method. """ + def __init__(self, tensor): self._tensor = tensor @@ -208,6 +209,7 @@ class Vindex: :param torch.Tensor tensor: A tensor to be indexed. :return: An object with a special :meth:`__getitem__` method. """ + def __init__(self, tensor): self._tensor = tensor diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index 7970264738..6d17cb4452 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -4,7 +4,9 @@ from torch.autograd import grad -def velocity_verlet(z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None): +def velocity_verlet( + z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None +): r""" Second order symplectic integrator that uses the velocity verlet algorithm. 
@@ -27,12 +29,9 @@ def velocity_verlet(z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_ z_next = z.copy() r_next = r.copy() for _ in range(num_steps): - z_next, r_next, z_grads, potential_energy = _single_step_verlet(z_next, - r_next, - potential_fn, - kinetic_grad, - step_size, - z_grads) + z_next, r_next, z_grads, potential_energy = _single_step_verlet( + z_next, r_next, potential_fn, kinetic_grad, step_size, z_grads + ) return z_next, r_next, z_grads, potential_energy @@ -44,7 +43,9 @@ def _single_step_verlet(z, r, potential_fn, kinetic_grad, step_size, z_grads=Non z_grads = potential_grad(potential_fn, z)[0] if z_grads is None else z_grads for site_name in r: - r[site_name] = r[site_name] + 0.5 * step_size * (-z_grads[site_name]) # r(n+1/2) + r[site_name] = r[site_name] + 0.5 * step_size * ( + -z_grads[site_name] + ) # r(n+1/2) r_grads = kinetic_grad(r) for site_name in z: @@ -77,7 +78,7 @@ def potential_grad(potential_fn, z): except RuntimeError as e: if "singular U" in str(e): grads = {k: v.new_zeros(v.shape) for k, v in z.items()} - return grads, z_nodes[0].new_tensor(float('nan')) + return grads, z_nodes[0].new_tensor(float("nan")) else: raise e diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 247165fc11..5c28eb8d43 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -20,7 +20,11 @@ def _hash(value, allow_id): if isinstance(value, list): return tuple(_hash(x, allow_id) for x in value) elif isinstance(value, dict): - return tuple(sorted((_hash(x, allow_id), _hash(y, allow_id)) for x, y in value.items())) + return tuple( + sorted( + (_hash(x, allow_id), _hash(y, allow_id)) for x, y in value.items() + ) + ) elif isinstance(value, set): return frozenset(_hash(x, allow_id) for x in value) elif isinstance(value, argparse.Namespace): @@ -51,12 +55,13 @@ class CompiledFunction: The actual PyTorch compilation artifact is stored in :attr:`compiled`. Call diagnostic methods on this attribute. 
""" + def __init__(self, fn, ignore_warnings=False, jit_options=None): self.fn = fn self.compiled = {} # len(args) -> callable self.ignore_warnings = ignore_warnings self.jit_options = {} if jit_options is None else jit_options - self.jit_options.setdefault('check_trace', False) + self.jit_options.setdefault("check_trace", False) self.compile_time = None self._param_names = None @@ -71,33 +76,43 @@ def __call__(self, *args, **kwargs): self.fn(*args, **kwargs) self._param_names = list(set(first_param_capture.trace.nodes.keys())) - unconstrained_params = tuple(pyro.param(name).unconstrained() - for name in self._param_names) + unconstrained_params = tuple( + pyro.param(name).unconstrained() for name in self._param_names + ) params_and_args = unconstrained_params + args weakself = weakref.ref(self) def compiled(*params_and_args): self = weakself() - unconstrained_params = params_and_args[:len(self._param_names)] - args = params_and_args[len(self._param_names):] + unconstrained_params = params_and_args[: len(self._param_names)] + args = params_and_args[len(self._param_names) :] constrained_params = {} - for name, unconstrained_param in zip(self._param_names, unconstrained_params): - constrained_param = pyro.param(name) # assume param has been initialized + for name, unconstrained_param in zip( + self._param_names, unconstrained_params + ): + constrained_param = pyro.param( + name + ) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param constrained_params[name] = constrained_param - return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) + return poutine.replay(self.fn, params=constrained_params)( + *args, **kwargs + ) if self.ignore_warnings: compiled = ignore_jit_warnings()(compiled) with pyro.validation_enabled(False): time_compilation = self.jit_options.pop("time_compilation", False) with optional(timed(), time_compilation) as t: - self.compiled[key] = torch.jit.trace(compiled, params_and_args, **self.jit_options) + self.compiled[key] = torch.jit.trace( + compiled, params_and_args, **self.jit_options + ) if time_compilation: self.compile_time = t.elapsed else: - unconstrained_params = [pyro.param(name).unconstrained() - for name in self._param_names] + unconstrained_params = [ + pyro.param(name).unconstrained() for name in self._param_names + ] params_and_args = unconstrained_params + list(args) with poutine.block(hide=self._param_names): @@ -106,8 +121,10 @@ def compiled(*params_and_args): for name in param_capture.trace.nodes.keys(): if name not in self._param_names: - raise NotImplementedError('pyro.ops.jit.trace assumes all params are created on ' - 'first invocation, but found new param: {}'.format(name)) + raise NotImplementedError( + "pyro.ops.jit.trace assumes all params are created on " + "first invocation, but found new param: {}".format(name) + ) return ret @@ -138,5 +155,9 @@ def model_log_prob_fn(x, y): :func:`torch.jit.trace` , e.g. ``{"optimize": False}``. 
""" if fn is None: - return lambda fn: trace(fn, ignore_warnings=ignore_warnings, jit_options=jit_options) - return CompiledFunction(fn, ignore_warnings=ignore_warnings, jit_options=jit_options) + return lambda fn: trace( + fn, ignore_warnings=ignore_warnings, jit_options=jit_options + ) + return CompiledFunction( + fn, ignore_warnings=ignore_warnings, jit_options=jit_options + ) diff --git a/pyro/ops/linalg.py b/pyro/ops/linalg.py index 0b2c29a43c..d6d247c1fd 100644 --- a/pyro/ops/linalg.py +++ b/pyro/ops/linalg.py @@ -15,9 +15,9 @@ def rinverse(M, sym=False): """ assert M.shape[-1] == M.shape[-2] if M.shape[-1] == 1: - return 1./M + return 1.0 / M elif M.shape[-1] == 2: - det = M[..., 0, 0]*M[..., 1, 1] - M[..., 1, 0]*M[..., 0, 1] + det = M[..., 0, 0] * M[..., 1, 1] - M[..., 1, 0] * M[..., 0, 1] inv = torch.empty_like(M) inv[..., 0, 0] = M[..., 1, 1] inv[..., 1, 1] = M[..., 0, 0] @@ -34,9 +34,11 @@ def determinant_3d(H): """ Returns the determinants of a batched 3-D matrix """ - detH = (H[..., 0, 0] * (H[..., 1, 1] * H[..., 2, 2] - H[..., 2, 1] * H[..., 1, 2]) + - H[..., 0, 1] * (H[..., 1, 2] * H[..., 2, 0] - H[..., 1, 0] * H[..., 2, 2]) + - H[..., 0, 2] * (H[..., 1, 0] * H[..., 2, 1] - H[..., 2, 0] * H[..., 1, 1])) + detH = ( + H[..., 0, 0] * (H[..., 1, 1] * H[..., 2, 2] - H[..., 2, 1] * H[..., 1, 2]) + + H[..., 0, 1] * (H[..., 1, 2] * H[..., 2, 0] - H[..., 1, 0] * H[..., 2, 2]) + + H[..., 0, 2] * (H[..., 1, 0] * H[..., 2, 1] - H[..., 2, 0] * H[..., 1, 1]) + ) return detH @@ -46,16 +48,23 @@ def eig_3d(H): """ p1 = H[..., 0, 1].pow(2) + H[..., 0, 2].pow(2) + H[..., 1, 2].pow(2) q = (H[..., 0, 0] + H[..., 1, 1] + H[..., 2, 2]) / 3 - p2 = (H[..., 0, 0] - q).pow(2) + (H[..., 1, 1] - q).pow(2) + (H[..., 2, 2] - q).pow(2) + 2 * p1 + p2 = ( + (H[..., 0, 0] - q).pow(2) + + (H[..., 1, 1] - q).pow(2) + + (H[..., 2, 2] - q).pow(2) + + 2 * p1 + ) p = torch.sqrt(p2 / 6) - B = (1 / p).unsqueeze(-1).unsqueeze(-1) * (H - q.unsqueeze(-1).unsqueeze(-1) * torch.eye(3)) + B = (1 / p).unsqueeze(-1).unsqueeze(-1) * ( + H - q.unsqueeze(-1).unsqueeze(-1) * torch.eye(3) + ) r = determinant_3d(B) / 2 phi = (r.acos() / 3).unsqueeze(-1).unsqueeze(-1).expand(r.shape + (3, 3)).clone() phi[r < -1 + 1e-6] = math.pi / 3 - phi[r > 1 - 1e-6] = 0. 
+ phi[r > 1 - 1e-6] = 0.0 eig1 = q + 2 * p * torch.cos(phi[..., 0, 0]) - eig2 = q + 2 * p * torch.cos(phi[..., 0, 0] + (2 * math.pi/3)) + eig2 = q + 2 * p * torch.cos(phi[..., 0, 0] + (2 * math.pi / 3)) eig3 = 3 * q - eig1 - eig2 # eig2 <= eig3 <= eig1 return eig2, eig3, eig1 diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py index e651b4284e..4907df37c0 100644 --- a/pyro/ops/newton.py +++ b/pyro/ops/newton.py @@ -60,7 +60,9 @@ def newton_step(loss, x, trust_radius=None): :rtype: tuple """ if x.dim() < 1: - raise ValueError('Expected x to have at least one dimension, actual shape {}'.format(x.shape)) + raise ValueError( + "Expected x to have at least one dimension, actual shape {}".format(x.shape) + ) dim = x.shape[-1] if dim == 1: return newton_step_1d(loss, x, trust_radius) @@ -69,7 +71,7 @@ def newton_step(loss, x, trust_radius=None): elif dim == 3: return newton_step_3d(loss, x, trust_radius) else: - raise NotImplementedError('newton_step_nd is not implemented') + raise NotImplementedError("newton_step_nd is not implemented") def newton_step_1d(loss, x, trust_radius=None): @@ -91,15 +93,19 @@ def newton_step_1d(loss, x, trust_radius=None): :rtype: tuple """ if loss.shape != (): - raise ValueError('Expected loss to be a scalar, actual shape {}'.format(loss.shape)) + raise ValueError( + "Expected loss to be a scalar, actual shape {}".format(loss.shape) + ) if x.dim() < 1 or x.shape[-1] != 1: - raise ValueError('Expected x to have rightmost size 1, actual shape {}'.format(x.shape)) + raise ValueError( + "Expected x to have rightmost size 1, actual shape {}".format(x.shape) + ) # compute derivatives g = grad(loss, [x], create_graph=True)[0] H = grad(g.sum(), [x], create_graph=True)[0] - warn_if_nan(g, 'g') - warn_if_nan(H, 'H') + warn_if_nan(g, "g") + warn_if_nan(H, "H") Hinv = H.clamp(min=1e-8).reciprocal() dx = -g * Hinv dx[~(dx == dx)] = 0 @@ -131,31 +137,44 @@ def newton_step_2d(loss, x, trust_radius=None): :rtype: tuple """ if loss.shape != (): - raise ValueError('Expected loss to be a scalar, actual shape {}'.format(loss.shape)) + raise ValueError( + "Expected loss to be a scalar, actual shape {}".format(loss.shape) + ) if x.dim() < 1 or x.shape[-1] != 2: - raise ValueError('Expected x to have rightmost size 2, actual shape {}'.format(x.shape)) + raise ValueError( + "Expected x to have rightmost size 2, actual shape {}".format(x.shape) + ) # compute derivatives g = grad(loss, [x], create_graph=True)[0] - H = torch.stack([grad(g[..., 0].sum(), [x], create_graph=True)[0], - grad(g[..., 1].sum(), [x], create_graph=True)[0]], -1) + H = torch.stack( + [ + grad(g[..., 0].sum(), [x], create_graph=True)[0], + grad(g[..., 1].sum(), [x], create_graph=True)[0], + ], + -1, + ) assert g.shape[-1:] == (2,) assert H.shape[-2:] == (2, 2) - warn_if_nan(g, 'g') - warn_if_nan(H, 'H') + warn_if_nan(g, "g") + warn_if_nan(H, "H") if trust_radius is not None: # regularize to keep update within ball of given trust_radius detH = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0] mean_eig = (H[..., 0, 0] + H[..., 1, 1]) / 2 min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt() - regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) - warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(2, dtype=H.dtype, device=H.device) + regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_( + min=1e-8 + ) + warn_if_nan(regularizer, "regularizer") + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye( + 
2, dtype=H.dtype, device=H.device + ) # compute newton update Hinv = rinverse(H, sym=True) - warn_if_nan(Hinv, 'Hinv') + warn_if_nan(Hinv, "Hinv") # apply update x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1) @@ -182,31 +201,44 @@ def newton_step_3d(loss, x, trust_radius=None): :rtype: tuple """ if loss.shape != (): - raise ValueError('Expected loss to be a scalar, actual shape {}'.format(loss.shape)) + raise ValueError( + "Expected loss to be a scalar, actual shape {}".format(loss.shape) + ) if x.dim() < 1 or x.shape[-1] != 3: - raise ValueError('Expected x to have rightmost size 3, actual shape {}'.format(x.shape)) + raise ValueError( + "Expected x to have rightmost size 3, actual shape {}".format(x.shape) + ) # compute derivatives g = grad(loss, [x], create_graph=True)[0] - H = torch.stack([grad(g[..., 0].sum(), [x], create_graph=True)[0], - grad(g[..., 1].sum(), [x], create_graph=True)[0], - grad(g[..., 2].sum(), [x], create_graph=True)[0]], -1) + H = torch.stack( + [ + grad(g[..., 0].sum(), [x], create_graph=True)[0], + grad(g[..., 1].sum(), [x], create_graph=True)[0], + grad(g[..., 2].sum(), [x], create_graph=True)[0], + ], + -1, + ) assert g.shape[-1:] == (3,) assert H.shape[-2:] == (3, 3) - warn_if_nan(g, 'g') - warn_if_nan(H, 'H') + warn_if_nan(g, "g") + warn_if_nan(H, "H") if trust_radius is not None: # regularize to keep update within ball of given trust_radius # calculate eigenvalues of symmetric matrix min_eig, _, _ = eig_3d(H) - regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) - warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(3, dtype=H.dtype, device=H.device) + regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_( + min=1e-8 + ) + warn_if_nan(regularizer, "regularizer") + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye( + 3, dtype=H.dtype, device=H.device + ) # compute newton update Hinv = rinverse(H, sym=True) - warn_if_nan(Hinv, 'Hinv') + warn_if_nan(Hinv, "Hinv") # apply update x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1) diff --git a/pyro/ops/packed.py b/pyro/ops/packed.py index 8dc8fc96ab..0570d66869 100644 --- a/pyro/ops/packed.py +++ b/pyro/ops/packed.py @@ -17,20 +17,31 @@ def pack(value, dim_to_symbol): :param dim_to_symbol: a map from negative integers to characters """ if isinstance(value, torch.Tensor): - assert not hasattr(value, '_pyro_dims'), 'tried to pack an already-packed tensor' + assert not hasattr( + value, "_pyro_dims" + ), "tried to pack an already-packed tensor" shape = value.shape shift = len(shape) try: with ignore_jit_warnings(): - dims = ''.join(dim_to_symbol[dim - shift] - for dim, size in enumerate(shape) - if size > 1) + dims = "".join( + dim_to_symbol[dim - shift] + for dim, size in enumerate(shape) + if size > 1 + ) except KeyError as e: - raise ValueError('\n '.join([ - 'Invalid tensor shape.', - 'Allowed dims: {}'.format(', '.join(map(str, sorted(dim_to_symbol)))), - 'Actual shape: {}'.format(tuple(value.shape)), - "Try adding shape assertions for your model's sample values and distribution parameters."])) from e + raise ValueError( + "\n ".join( + [ + "Invalid tensor shape.", + "Allowed dims: {}".format( + ", ".join(map(str, sorted(dim_to_symbol))) + ), + "Actual shape: {}".format(tuple(value.shape)), + "Try adding shape assertions for your model's sample values and distribution parameters.", + ] + ) + ) from e value = value.squeeze() value._pyro_dims = dims assert value.dim() == len(value._pyro_dims) @@ 
-63,10 +74,14 @@ def broadcast_all(*values, **kwargs): """ Packed broadcasting of multiple tensors. """ - dims = kwargs.get('dims') - sizes = {dim: size for value in values for dim, size in zip(value._pyro_dims, value.shape)} + dims = kwargs.get("dims") + sizes = { + dim: size + for value in values + for dim, size in zip(value._pyro_dims, value.shape) + } if dims is None: - dims = ''.join(sorted(sizes)) + dims = "".join(sorted(sizes)) else: assert set(dims) == set(sizes) shape = torch.Size(sizes[dim] for dim in dims) @@ -90,7 +105,7 @@ def gather(value, index, dim): assert dim in value._pyro_dims assert dim not in index._pyro_dims value, index = broadcast_all(value, index) - dims = value._pyro_dims.replace(dim, '') + dims = value._pyro_dims.replace(dim, "") pos = value._pyro_dims.index(dim) with ignore_jit_warnings(): zero = torch.zeros(1, dtype=torch.long, device=index.device) @@ -106,9 +121,9 @@ def mul(lhs, rhs): Packed broadcasted multiplication. """ if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor): - dims = ''.join(sorted(set(lhs._pyro_dims + rhs._pyro_dims))) - equation = lhs._pyro_dims + ',' + rhs._pyro_dims + '->' + dims - result = torch.einsum(equation, lhs, rhs, backend='torch') + dims = "".join(sorted(set(lhs._pyro_dims + rhs._pyro_dims))) + equation = lhs._pyro_dims + "," + rhs._pyro_dims + "->" + dims + result = torch.einsum(equation, lhs, rhs, backend="torch") result._pyro_dims = dims return result result = lhs * rhs @@ -130,7 +145,7 @@ def scale_and_mask(tensor, scale=1.0, mask=None): :type mask: torch.BoolTensor, bool, or None """ if isinstance(scale, torch.Tensor) and scale.dim(): - raise NotImplementedError('non-scalar scale is not supported') + raise NotImplementedError("non-scalar scale is not supported") if mask is None or mask is True: if is_identically_one(scale): return tensor @@ -174,10 +189,12 @@ def rename_equation(equation, *operands): Renames symbols in an einsum/ubersum equation to match the ``.pyro_dims`` attributes of packed ``operands``. """ - inputs, outputs = equation.split('->') - inputs = inputs.split(',') + inputs, outputs = equation.split("->") + inputs = inputs.split(",") assert len(inputs) == len(operands) - rename = {old: new - for input_, operand in zip(inputs, operands) - for old, new in zip(input_, operand._pyro_dims)} - return ''.join(rename.get(s, s) for s in equation) + rename = { + old: new + for input_, operand in zip(inputs, operands) + for old, new in zip(input_, operand._pyro_dims) + } + return "".join(rename.get(s, s) for s in equation) diff --git a/pyro/ops/rings.py b/pyro/ops/rings.py index 0c4c11d68e..58c68e865d 100644 --- a/pyro/ops/rings.py +++ b/pyro/ops/rings.py @@ -27,6 +27,7 @@ class Ring(object, metaclass=ABCMeta): :param dict cache: an optional :func:`~opt_einsum.shared_intermediates` cache. """ + def __init__(self, cache=None): self._cache = {} if cache is None else cache @@ -36,7 +37,7 @@ def _hash_by_id(self, tensor): used as a key in the cache without risk of the id being recycled. 
""" result = id(tensor) - assert self._cache.setdefault(('tensor', result), tensor) is tensor + assert self._cache.setdefault(("tensor", result), tensor) is tensor return result @abstractmethod @@ -70,9 +71,9 @@ def broadcast(self, term, ordinal): :param frozenset ordinal: an ordinal specifying plates """ dims = term._pyro_dims - missing_dims = ''.join(sorted(set(ordinal) - set(dims))) + missing_dims = "".join(sorted(set(ordinal) - set(dims))) if missing_dims: - key = 'broadcast', self._hash_by_id(term), missing_dims + key = "broadcast", self._hash_by_id(term), missing_dims if key in self._cache: term = self._cache[key] else: @@ -107,8 +108,8 @@ def global_local(self, term, dims, ordinal): :return: a tuple ``(global_part, local_part)`` as defined above :rtype: tuple """ - assert dims, 'dims was empty, use .product() instead' - key = 'global_local', self._hash_by_id(term), frozenset(dims), ordinal + assert dims, "dims was empty, use .product() instead" + key = "global_local", self._hash_by_id(term), frozenset(dims), ordinal if key in self._cache: return self._cache[key] @@ -130,7 +131,8 @@ class LinearRing(Ring): ``._pyro_dims`` attribute, which is a string of dimension names aligned with the tensor's shape. """ - _backend = 'torch' + + _backend = "torch" def __init__(self, cache=None, dim_to_size=None): super().__init__(cache=cache) @@ -138,8 +140,8 @@ def __init__(self, cache=None, dim_to_size=None): def sumproduct(self, terms, dims): inputs = [term._pyro_dims for term in terms] - output = ''.join(sorted(set(''.join(inputs)) - set(dims))) - equation = ','.join(inputs) + '->' + output + output = "".join(sorted(set("".join(inputs)) - set(dims))) + equation = ",".join(inputs) + "->" + output term = contract(equation, *terms, backend=self._backend) term._pyro_dims = output return term @@ -149,23 +151,25 @@ def product(self, term, ordinal): for dim in sorted(ordinal, reverse=True): pos = dims.find(dim) if pos != -1: - key = 'product', self._hash_by_id(term), dim + key = "product", self._hash_by_id(term), dim if key in self._cache: term = self._cache[key] else: term = term.prod(pos) - dims = dims.replace(dim, '') + dims = dims.replace(dim, "") self._cache[key] = term term._pyro_dims = dims return term def inv(self, term): - key = 'inv', self._hash_by_id(term) + key = "inv", self._hash_by_id(term) if key in self._cache: return self._cache[key] result = term.reciprocal() - result = result.clamp(max=torch.finfo(result.dtype).max) # avoid nan due to inf / inf + result = result.clamp( + max=torch.finfo(result.dtype).max + ) # avoid nan due to inf / inf result._pyro_dims = term._pyro_dims self._cache[key] = result return result @@ -181,7 +185,8 @@ class LogRing(Ring): ``._pyro_dims`` attribute, which is a string of dimension names aligned with the tensor's shape. 
""" - _backend = 'pyro.ops.einsum.torch_log' + + _backend = "pyro.ops.einsum.torch_log" def __init__(self, cache=None, dim_to_size=None): super().__init__(cache=cache) @@ -189,8 +194,8 @@ def __init__(self, cache=None, dim_to_size=None): def sumproduct(self, terms, dims): inputs = [term._pyro_dims for term in terms] - output = ''.join(sorted(set(''.join(inputs)) - set(dims))) - equation = ','.join(inputs) + '->' + output + output = "".join(sorted(set("".join(inputs)) - set(dims))) + equation = ",".join(inputs) + "->" + output term = contract(equation, *terms, backend=self._backend) term._pyro_dims = output return term @@ -200,23 +205,25 @@ def product(self, term, ordinal): for dim in sorted(ordinal, reverse=True): pos = dims.find(dim) if pos != -1: - key = 'product', self._hash_by_id(term), dim + key = "product", self._hash_by_id(term), dim if key in self._cache: term = self._cache[key] else: term = term.sum(pos) - dims = dims.replace(dim, '') + dims = dims.replace(dim, "") self._cache[key] = term term._pyro_dims = dims return term def inv(self, term): - key = 'inv', self._hash_by_id(term) + key = "inv", self._hash_by_id(term) if key in self._cache: return self._cache[key] result = -term - result = result.clamp(max=torch.finfo(result.dtype).max) # avoid nan due to inf - inf + result = result.clamp( + max=torch.finfo(result.dtype).max + ) # avoid nan due to inf - inf result._pyro_dims = term._pyro_dims self._cache[key] = result return result @@ -230,6 +237,7 @@ class _SampleProductBackward(Backward): :class:`MapRing` (temperature 0 sampling) and :class:`SampleRing` (temperature 1 sampling). """ + def __init__(self, ring, term, ordinal): self.ring = ring self.term = term @@ -240,7 +248,7 @@ def process(self, message): sample_dims = message._pyro_sample_dims message = self.ring.broadcast(message, self.ordinal) if message._pyro_dims.index(SAMPLE_SYMBOL) != 0: - dims = SAMPLE_SYMBOL + message._pyro_dims.replace(SAMPLE_SYMBOL, '') + dims = SAMPLE_SYMBOL + message._pyro_dims.replace(SAMPLE_SYMBOL, "") message = message.permute(tuple(map(message._pyro_dims.find, dims))) message._pyro_dims = dims assert message.dim() == len(message._pyro_dims) @@ -253,11 +261,12 @@ class MapRing(LogRing): """ Ring of forward-maxsum backward-argmax operations. """ - _backend = 'pyro.ops.einsum.torch_map' + + _backend = "pyro.ops.einsum.torch_map" def product(self, term, ordinal): result = super().product(term, ordinal) - if hasattr(term, '_pyro_backward'): + if hasattr(term, "_pyro_backward"): result._pyro_backward = _SampleProductBackward(self, term, ordinal) return result @@ -266,11 +275,12 @@ class SampleRing(LogRing): """ Ring of forward-sumproduct backward-sample operations in log space. """ - _backend = 'pyro.ops.einsum.torch_sample' + + _backend = "pyro.ops.einsum.torch_sample" def product(self, term, ordinal): result = super().product(term, ordinal) - if hasattr(term, '_pyro_backward'): + if hasattr(term, "_pyro_backward"): result._pyro_backward = _SampleProductBackward(self, term, ordinal) return result @@ -279,6 +289,7 @@ class _MarginalProductBackward(Backward): """ Backward-marginal implementation of product, using inclusion-exclusion. """ + def __init__(self, ring, term, ordinal, result): self.ring = ring self.term = term @@ -306,19 +317,22 @@ class MarginalRing(LogRing): """ Ring of forward-sumproduct backward-marginal operations in log space. 
""" - _backend = 'pyro.ops.einsum.torch_marginal' + + _backend = "pyro.ops.einsum.torch_marginal" def product(self, term, ordinal): result = super().product(term, ordinal) - if hasattr(term, '_pyro_backward'): - result._pyro_backward = _MarginalProductBackward(self, term, ordinal, result) + if hasattr(term, "_pyro_backward"): + result._pyro_backward = _MarginalProductBackward( + self, term, ordinal, result + ) return result BACKEND_TO_RING = { - 'torch': LinearRing, - 'pyro.ops.einsum.torch_log': LogRing, - 'pyro.ops.einsum.torch_map': MapRing, - 'pyro.ops.einsum.torch_sample': SampleRing, - 'pyro.ops.einsum.torch_marginal': MarginalRing, + "torch": LinearRing, + "pyro.ops.einsum.torch_log": LogRing, + "pyro.ops.einsum.torch_map": MapRing, + "pyro.ops.einsum.torch_sample": SampleRing, + "pyro.ops.einsum.torch_marginal": MarginalRing, } diff --git a/pyro/ops/special.py b/pyro/ops/special.py index 1291eeadf1..5e1bf109e4 100644 --- a/pyro/ops/special.py +++ b/pyro/ops/special.py @@ -16,7 +16,7 @@ def forward(ctx, x): @staticmethod def backward(ctx, grad): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors return grad / x.clamp(min=torch.finfo(x.dtype).eps) @@ -28,7 +28,7 @@ def safe_log(x): return _SafeLog.apply(x) -def log_beta(x, y, tol=0.): +def log_beta(x, y, tol=0.0): """ Computes log Beta function. @@ -76,12 +76,17 @@ def log_beta(x, y, tol=0.): log_factor = functools.reduce(operator.mul, factors).log() - return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log() - - (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift)) + return ( + log_factor + + (x - 0.5) * x.log() + + (y - 0.5) * y.log() + - (xy - 0.5) * xy.log() + + (math.log(2 * math.pi) / 2 - shift) + ) @torch.no_grad() -def log_binomial(n, k, tol=0.): +def log_binomial(n, k, tol=0.0): """ Computes log binomial coefficient. @@ -102,7 +107,7 @@ def log_binomial(n, k, tol=0.): def log_I1(orders: int, value: torch.Tensor, terms=250): - r""" Compute first n log modified bessel function of first kind + r"""Compute first n log modified bessel function of first kind .. math :: \log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk) @@ -136,7 +141,11 @@ def log_I1(orders: int, value: torch.Tensor, terms=250): indices = k[:orders].view(-1, 1) + k.view(1, -1) assert indices.shape == (orders, terms) - seqs = (2 * lvalues[None, :, :] - lfactorials[None, None, :] - lgammas.gather(1, indices)[:, None, :]).logsumexp(-1) + seqs = ( + 2 * lvalues[None, :, :] + - lfactorials[None, None, :] + - lgammas.gather(1, indices)[:, None, :] + ).logsumexp(-1) assert seqs.shape == (orders, vshape.numel()) i1s = lvalues[..., :orders].T + seqs diff --git a/pyro/ops/ssm_gp.py b/pyro/ops/ssm_gp.py index 89abcb2912..160ec30431 100644 --- a/pyro/ops/ssm_gp.py +++ b/pyro/ops/ssm_gp.py @@ -32,9 +32,14 @@ class MaternKernel(PyroModule): [2] `Stochastic Differential Equation Methods for Spatio-Temporal Gaussian Process Regression`, Arno Solin. 
""" - def __init__(self, nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init=None): + + def __init__( + self, nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init=None + ): if nu not in [0.5, 1.5, 2.5]: - raise NotImplementedError("The only supported values of nu are 0.5, 1.5 and 2.5") + raise NotImplementedError( + "The only supported values of nu are 0.5, 1.5 and 2.5" + ) self.nu = nu self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu] self.num_gps = num_gps @@ -49,8 +54,12 @@ def __init__(self, nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init= super().__init__() - self.length_scale = PyroParam(length_scale_init, constraint=constraints.positive) - self.kernel_scale = PyroParam(kernel_scale_init, constraint=constraints.positive) + self.length_scale = PyroParam( + length_scale_init, constraint=constraints.positive + ) + self.kernel_scale = PyroParam( + kernel_scale_init, constraint=constraints.positive + ) if self.state_dim > 1: for x in range(self.state_dim): @@ -78,10 +87,12 @@ def transition_matrix(self, dt): elif self.nu == 1.5: rho = self.length_scale.unsqueeze(-1).unsqueeze(-1) dt_rho = dt / rho - trans = (1.0 + root_three * dt_rho) * self.mask00 + \ - (-3.0 * dt_rho / rho) * self.mask01 + \ - dt * self.mask10 + \ - (1.0 - root_three * dt_rho) * self.mask11 + trans = ( + (1.0 + root_three * dt_rho) * self.mask00 + + (-3.0 * dt_rho / rho) * self.mask01 + + dt * self.mask10 + + (1.0 - root_three * dt_rho) * self.mask11 + ) return torch.exp(-root_three * dt_rho) * trans elif self.nu == 2.5: rho = self.length_scale.unsqueeze(-1).unsqueeze(-1) @@ -90,15 +101,17 @@ def transition_matrix(self, dt): dt_rho_cu = dt_rho.pow(3.0) dt_rho_qu = dt_rho.pow(4.0) dt_sq = dt ** 2.0 - trans = (1.0 + dt_rho + 0.5 * dt_rho_sq) * self.mask00 + \ - (-0.5 * dt_rho_cu / dt) * self.mask01 + \ - ((0.5 * dt_rho_qu - dt_rho_cu) / dt_sq) * self.mask02 + \ - ((dt_rho + 1.0) * dt) * self.mask10 + \ - (1.0 + dt_rho - dt_rho_sq) * self.mask11 + \ - ((dt_rho_cu - 3.0 * dt_rho_sq) / dt) * self.mask12 + \ - (0.5 * dt_sq) * self.mask20 + \ - ((1.0 - 0.5 * dt_rho) * dt) * self.mask21 + \ - (1.0 - 2.0 * dt_rho + 0.5 * dt_rho_sq) * self.mask22 + trans = ( + (1.0 + dt_rho + 0.5 * dt_rho_sq) * self.mask00 + + (-0.5 * dt_rho_cu / dt) * self.mask01 + + ((0.5 * dt_rho_qu - dt_rho_cu) / dt_sq) * self.mask02 + + ((dt_rho + 1.0) * dt) * self.mask10 + + (1.0 + dt_rho - dt_rho_sq) * self.mask11 + + ((dt_rho_cu - 3.0 * dt_rho_sq) / dt) * self.mask12 + + (0.5 * dt_sq) * self.mask20 + + ((1.0 - 0.5 * dt_rho) * dt) * self.mask21 + + (1.0 - 2.0 * dt_rho + 0.5 * dt_rho_sq) * self.mask22 + ) return torch.exp(-dt_rho) * trans @pyro_method @@ -121,9 +134,11 @@ def stationary_covariance(self): sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1) rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1) p_infinity = 0.0 - p_infinity = self.mask00 + \ - (five_thirds / rhosq) * (self.mask11 - self.mask02 - self.mask20) + \ - (25.0 / rhosq.pow(2.0)) * self.mask22 + p_infinity = ( + self.mask00 + + (five_thirds / rhosq) * (self.mask11 - self.mask02 - self.mask20) + + (25.0 / rhosq.pow(2.0)) * self.mask22 + ) return sigmasq * p_infinity @pyro_method diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 144c24a9fa..dc2aa5d808 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -78,7 +78,9 @@ def split_gelman_rubin(input, chain_dim=0, sample_dim=1): new_input = torch.stack([input[:, :N_half], input[:, -N_half:]], dim=1) new_input = new_input.reshape((-1, N_half) + input.shape[2:]) split_rhat = 
gelman_rubin(new_input) - return split_rhat.squeeze(max(sample_dim, chain_dim)).squeeze(min(sample_dim, chain_dim)) + return split_rhat.squeeze(max(sample_dim, chain_dim)).squeeze( + min(sample_dim, chain_dim) + ) def autocorrelation(input, dim=0): @@ -92,8 +94,9 @@ def autocorrelation(input, dim=0): :returns torch.Tensor: autocorrelation of ``input``. """ if (not input.is_cuda) and (not torch.backends.mkl.is_available()): - raise NotImplementedError("For CPU tensor, this method is only supported " - "with MKL installed.") + raise NotImplementedError( + "For CPU tensor, this method is only supported " "with MKL installed." + ) # Adapted from Stan implementation # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp @@ -116,7 +119,9 @@ def autocorrelation(input, dim=0): # truncate and normalize the result, then transpose back to original shape autocorr = autocorr[..., :N] - autocorr = autocorr / torch.tensor(range(N, 0, -1), dtype=input.dtype, device=input.device) + autocorr = autocorr / torch.tensor( + range(N, 0, -1), dtype=input.dtype, device=input.device + ) autocorr = autocorr / autocorr[..., :1] return autocorr.transpose(dim, -1) @@ -142,8 +147,11 @@ def _cummin(input): # FIXME: is there a better trick to find accumulate min of a sequence? N = input.size(0) input_tril = input.unsqueeze(0).repeat((N,) + (1,) * input.dim()) - triu_mask = (torch.ones(N, N, dtype=input.dtype, device=input.device) - .triu(diagonal=1).reshape((N, N) + (1,) * (input.dim() - 1))) + triu_mask = ( + torch.ones(N, N, dtype=input.dtype, device=input.device) + .triu(diagonal=1) + .reshape((N, N) + (1,) * (input.dim() - 1)) + ) triu_mask = triu_mask.expand((N, N) + input.shape[1:]) > 0.5 input_tril.masked_fill_(triu_mask, input.max()) return input_tril.min(dim=1)[0] @@ -279,9 +287,13 @@ def hpdi(input, prob, dim=0): mass = input.size(dim) index_length = int(prob * mass) intervals_left = sorted_input.index_select( - dim, torch.tensor(range(mass - index_length), dtype=torch.long, device=input.device)) + dim, + torch.tensor(range(mass - index_length), dtype=torch.long, device=input.device), + ) intervals_right = sorted_input.index_select( - dim, torch.tensor(range(index_length, mass), dtype=torch.long, device=input.device)) + dim, + torch.tensor(range(index_length, mass), dtype=torch.long, device=input.device), + ) intervals_length = intervals_right - intervals_left index_start = intervals_length.argmin(dim) indices = torch.stack([index_start, index_start + index_length], dim) @@ -298,8 +310,10 @@ def _weighted_mean(input, log_weights, dim=0, keepdim=False): def _weighted_variance(input, log_weights, dim=0, keepdim=False, unbiased=True): # Ref: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Frequency_weights - deviation_squared = (input - _weighted_mean(input, log_weights, dim, keepdim=True)).pow(2) - correction = log_weights.size(0) / (log_weights.size(0) - 1.) if unbiased else 1. + deviation_squared = ( + input - _weighted_mean(input, log_weights, dim, keepdim=True) + ).pow(2) + correction = log_weights.size(0) / (log_weights.size(0) - 1.0) if unbiased else 1.0 return _weighted_mean(deviation_squared, log_weights, dim, keepdim) * correction @@ -319,7 +333,9 @@ def waic(input, log_weights=None, pointwise=False, dim=0): :returns tuple: tuple of WAIC and effective number of parameters. 
""" if log_weights is None: - log_weights = torch.zeros(input.size(dim), dtype=input.dtype, device=input.device) + log_weights = torch.zeros( + input.size(dim), dtype=input.dtype, device=input.device + ) # computes log pointwise predictive density: formula (3) of [1] dim = input.dim() + dim if dim < 0 else dim @@ -362,7 +378,7 @@ def fit_generalized_pareto(X): # b = k / sigma bs = 1.0 - math.sqrt(M) / (torch.arange(1, M + 1, dtype=torch.double) - 0.5).sqrt() - bs /= 3.0 * X[int(N/4 - 0.5)] + bs /= 3.0 * X[int(N / 4 - 0.5)] bs += 1 / X[-1] ks = torch.log1p(-bs.unsqueeze(-1) * X).mean(-1) @@ -410,8 +426,10 @@ def crps_empirical(pred, truth): :rtype: torch.Tensor """ if pred.shape[1:] != (1,) * (pred.dim() - truth.dim() - 1) + truth.shape: - raise ValueError("Expected pred to have one extra sample dim on left. " - "Actual shapes: {} versus {}".format(pred.shape, truth.shape)) + raise ValueError( + "Expected pred to have one extra sample dim on left. " + "Actual shapes: {} versus {}".format(pred.shape, truth.shape) + ) opts = dict(device=pred.device, dtype=pred.dtype) num_samples = pred.size(0) if num_samples == 1: @@ -419,8 +437,9 @@ def crps_empirical(pred, truth): pred = pred.sort(dim=0).values diff = pred[1:] - pred[:-1] - weight = (torch.arange(1, num_samples, **opts) * - torch.arange(num_samples - 1, 0, -1, **opts)) + weight = torch.arange(1, num_samples, **opts) * torch.arange( + num_samples - 1, 0, -1, **opts + ) weight = weight.reshape(weight.shape + (1,) * (diff.dim() - 1)) - return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2 + return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples ** 2 diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index f438650009..507582162e 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -53,7 +53,7 @@ def block_diagonal(mat, block_size): mat = mat.reshape(mat.shape[:-2] + (B, M, B, N)) mat = mat.transpose(-2, -3) mat = mat.reshape(mat.shape[:-4] + (B * B, M, N)) - return mat[..., ::B + 1, :, :] + return mat[..., :: B + 1, :, :] def periodic_repeat(tensor, size, dim): @@ -117,7 +117,9 @@ def periodic_cumsum(tensor, period, dim): tensor = torch.nn.functional.pad(tensor, (0, 0) * (-1 - dim) + (0, padding)) # Accumulate. - shape = tensor.shape[:dim] + (repeats, period) + tensor.shape[tensor.dim() + dim + 1:] + shape = ( + tensor.shape[:dim] + (repeats, period) + tensor.shape[tensor.dim() + dim + 1 :] + ) result = tensor.reshape(shape).cumsum(dim=dim - 1).reshape(tensor.shape) # Truncate to original size. @@ -161,7 +163,9 @@ def periodic_features(duration, max_period=None, min_period=None, **options): t = torch.arange(float(duration), **options).unsqueeze(-1).unsqueeze(-1) phase = torch.tensor([0, math.pi / 2], **options).unsqueeze(-1) - freq = torch.arange(1, max_period / min_period, **options).mul_(2 * math.pi / max_period) + freq = torch.arange(1, max_period / min_period, **options).mul_( + 2 * math.pi / max_period + ) result = (freq * t + phase).cos_().reshape(duration, -1).contiguous() return result @@ -197,7 +201,7 @@ def next_fast_len(size): next_size += 1 -def convolve(signal, kernel, mode='full'): +def convolve(signal, kernel, mode="full"): """ Computes the 1-d convolution of signal by kernel using FFTs. 
The two arguments should have the same rightmost dim, but may otherwise be @@ -215,14 +219,14 @@ def convolve(signal, kernel, mode='full'): """ m = signal.size(-1) n = kernel.size(-1) - if mode == 'full': + if mode == "full": truncate = m + n - 1 - elif mode == 'valid': + elif mode == "valid": truncate = max(m, n) - min(m, n) + 1 - elif mode == 'same': + elif mode == "same": truncate = max(m, n) else: - raise ValueError('Unknown mode: {}'.format(mode)) + raise ValueError("Unknown mode: {}".format(mode)) # Compute convolution using fft. padded_size = m + n - 1 @@ -234,7 +238,7 @@ def convolve(signal, kernel, mode='full'): result = irfft(f_result, n=fast_ftt_size) start_idx = (padded_size - truncate) // 2 - return result[..., start_idx: start_idx + truncate] + return result[..., start_idx : start_idx + truncate] def repeated_matmul(M, n): @@ -247,7 +251,9 @@ def repeated_matmul(M, n): :param int n: The order of the largest product :math:`M^n` :returns torch.Tensor: A batch of square tensors of shape (n, ..., N, N) """ - assert M.size(-1) == M.size(-2), "Input tensors must satisfy M.size(-1) == M.size(-2)." + assert M.size(-1) == M.size( + -2 + ), "Input tensors must satisfy M.size(-1) == M.size(-2)." assert n > 0, "argument n to parallel_scan_repeated_matmul must be 1 or larger" doubling_rounds = 0 if n <= 2 else math.ceil(math.log(n, 2)) - 1 @@ -278,7 +284,7 @@ def dct(x, dim=-1): if dim >= 0: dim -= x.dim() if dim != -1: - y = x.reshape(x.shape[:dim + 1] + (-1,)).transpose(-1, -2) + y = x.reshape(x.shape[: dim + 1] + (-1,)).transpose(-1, -2) return dct(y).transpose(-1, -2).reshape(x.shape) # Ref: http://fourier.eng.hmc.edu/e161/lectures/dct/node2.html @@ -288,16 +294,20 @@ def dct(x, dim=-1): # Step 2 Y = rfft(y, n=N) # Step 3 - coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) + coef_real = torch.cos( + torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device) + ) M = Y.size(-1) coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1) X = as_complex(coef) * Y # NB: if we use the full-length version Y_full = fft(y, n=N), then # the real part of the later half of X will be the flip # of the negative of the imaginary part of the first half - X = torch.cat([X.real, -X.imag[..., 1:(N - M + 1)].flip(-1)], dim=-1) + X = torch.cat([X.real, -X.imag[..., 1 : (N - M + 1)].flip(-1)], dim=-1) # orthogonalize - scale = torch.cat([x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))]) + scale = torch.cat( + [x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))] + ) return X / scale @@ -315,11 +325,13 @@ def idct(x, dim=-1): if dim >= 0: dim -= x.dim() if dim != -1: - y = x.reshape(x.shape[:dim + 1] + (-1,)).transpose(-1, -2) + y = x.reshape(x.shape[: dim + 1] + (-1,)).transpose(-1, -2) return idct(y).transpose(-1, -2).reshape(x.shape) N = x.size(-1) - scale = torch.cat([x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))]) + scale = torch.cat( + [x.new_tensor([math.sqrt(N)]), x.new_full((N - 1,), math.sqrt(0.5 * N))] + ) x = x * scale # Step 1, solve X = cos(k) * Yr + sin(k) * Yi # We know that Y[1:] is conjugate to Y[:0:-1], hence @@ -329,9 +341,11 @@ def idct(x, dim=-1): # In addition, Yi[0] = 0, Yr[0] = X[0] # In other words, Y = complex_mul(e^ik, X - i[0, X[:0:-1]]) M = N // 2 + 1 # half size - xi = torch.nn.functional.pad(-x[..., N - M + 1:], (0, 1)).flip(-1) + xi = torch.nn.functional.pad(-x[..., N - M + 1 :], (0, 1)).flip(-1) X = torch.stack([x[..., :M], xi], dim=-1) - 
coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device)) + coef_real = torch.cos( + torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device) + ) coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1) Y = as_complex(coef) * as_complex(X) # Step 2 @@ -351,7 +365,7 @@ def haar_transform(x): :rtype: Tensor """ n = x.size(-1) // 2 - even, odd, end = x[..., 0:n+n:2], x[..., 1:n+n:2], x[..., n+n:] + even, odd, end = x[..., 0 : n + n : 2], x[..., 1 : n + n : 2], x[..., n + n :] hi = _ROOT_TWO_INVERSE * (even - odd) lo = _ROOT_TWO_INVERSE * (even + odd) if n >= 2: @@ -369,7 +383,7 @@ def inverse_haar_transform(x): :rtype: Tensor """ n = x.size(-1) // 2 - lo, hi, end = x[..., :n], x[..., n:n+n], x[..., n+n:] + lo, hi, end = x[..., :n], x[..., n : n + n], x[..., n + n :] if n >= 2: lo = inverse_haar_transform(lo) even = _ROOT_TWO_INVERSE * (lo + hi) @@ -412,8 +426,9 @@ def triangular_solve(x, y, upper=False, transpose=False): def precision_to_scale_tril(P): Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) - L = torch.triangular_solve(torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), - L_inv, upper=False)[0] + L = torch.triangular_solve( + torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False + )[0] return L diff --git a/pyro/ops/welford.py b/pyro/ops/welford.py index 4de3f4d56d..0affd404c8 100644 --- a/pyro/ops/welford.py +++ b/pyro/ops/welford.py @@ -14,13 +14,14 @@ class WelfordCovariance: [1] `The Art of Computer Programming`, Donald E. Knuth """ + def __init__(self, diagonal=True): self.diagonal = diagonal self.reset() def reset(self): - self._mean = 0. - self._m2 = 0. + self._mean = 0.0 + self._m2 = 0.0 self.n_samples = 0 def update(self, sample): @@ -36,16 +37,16 @@ def update(self, sample): def get_covariance(self, regularize=True): if self.n_samples < 2: - raise RuntimeError('Insufficient samples to estimate covariance') + raise RuntimeError("Insufficient samples to estimate covariance") cov = self._m2 / (self.n_samples - 1) if regularize: # Regularization from stan - scaled_cov = (self.n_samples / (self.n_samples + 5.)) * cov - shrinkage = 1e-3 * (5. / (self.n_samples + 5.0)) + scaled_cov = (self.n_samples / (self.n_samples + 5.0)) * cov + shrinkage = 1e-3 * (5.0 / (self.n_samples + 5.0)) if self.diagonal: cov = scaled_cov + shrinkage else: - scaled_cov.view(-1)[::scaled_cov.size(0) + 1] += shrinkage + scaled_cov.view(-1)[:: scaled_cov.size(0) + 1] += shrinkage cov = scaled_cov return cov @@ -54,14 +55,15 @@ class WelfordArrowheadCovariance: """ Likes :class:`WelfordCovariance` but generalized to the arrowhead structure. """ + def __init__(self, head_size=0): self.head_size = head_size self.reset() def reset(self): - self._mean = 0. - self._m2_top = 0. # upper part, shape: head_size x matrix_size - self._m2_bottom_diag = 0. 
# lower right part, shape: (matrix_size - head_size) + self._mean = 0.0 + self._m2_top = 0.0 # upper part, shape: head_size x matrix_size + self._m2_bottom_diag = 0.0 # lower right part, shape: (matrix_size - head_size) self.n_samples = 0 def update(self, sample): @@ -70,10 +72,15 @@ def update(self, sample): self._mean = self._mean + delta_pre / self.n_samples delta_post = sample - self._mean if self.head_size > 0: - self._m2_top = self._m2_top + torch.ger(delta_post[:self.head_size], delta_pre) + self._m2_top = self._m2_top + torch.ger( + delta_post[: self.head_size], delta_pre + ) else: self._m2_top = sample.new_empty(0, sample.size(0)) - self._m2_bottom_diag = self._m2_bottom_diag + delta_post[self.head_size:] * delta_pre[self.head_size:] + self._m2_bottom_diag = ( + self._m2_bottom_diag + + delta_post[self.head_size :] * delta_pre[self.head_size :] + ) def get_covariance(self, regularize=True): """ @@ -81,14 +88,14 @@ def get_covariance(self, regularize=True): and `bottom_diag = cov.diag()[head_size:]`. """ if self.n_samples < 2: - raise RuntimeError('Insufficient samples to estimate covariance') + raise RuntimeError("Insufficient samples to estimate covariance") top = self._m2_top / (self.n_samples - 1) bottom_diag = self._m2_bottom_diag / (self.n_samples - 1) if regularize: - top = top * (self.n_samples / (self.n_samples + 5.)) - bottom_diag = bottom_diag * (self.n_samples / (self.n_samples + 5.)) - shrinkage = 1e-3 * (5. / (self.n_samples + 5.0)) - top.view(-1)[::top.size(-1) + 1] += shrinkage + top = top * (self.n_samples / (self.n_samples + 5.0)) + bottom_diag = bottom_diag * (self.n_samples / (self.n_samples + 5.0)) + shrinkage = 1e-3 * (5.0 / (self.n_samples + 5.0)) + top.view(-1)[:: top.size(-1) + 1] += shrinkage bottom_diag = bottom_diag + shrinkage return top, bottom_diag diff --git a/pyro/optim/adagrad_rmsprop.py b/pyro/optim/adagrad_rmsprop.py index 8730352ab3..5a12a923e0 100644 --- a/pyro/optim/adagrad_rmsprop.py +++ b/pyro/optim/adagrad_rmsprop.py @@ -33,21 +33,23 @@ class AdagradRMSProp(Optimizer): :type delta: float """ - def __init__(self, params, eta: float = 1.0, delta: float = 1.0e-16, t: float = 0.1): + def __init__( + self, params, eta: float = 1.0, delta: float = 1.0e-16, t: float = 0.1 + ): defaults = dict(eta=eta, delta=delta, t=t) super().__init__(params, defaults) for group in self.param_groups: - for p in group['params']: + for p in group["params"]: state = self.state[p] - state['step'] = 0 - state['sum'] = torch.zeros_like(p.data) + state["step"] = 0 + state["sum"] = torch.zeros_like(p.data) def share_memory(self) -> None: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: state = self.state[p] - state['sum'].share_memory_() + state["sum"].share_memory_() def step(self, closure: Optional[Callable] = None) -> Optional[Any]: """ @@ -60,7 +62,7 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]: loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue @@ -70,16 +72,16 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]: raise NotImplementedError state = self.state[p] - state['step'] += 1 - if state['step'] == 1: + state["step"] += 1 + if state["step"] == 1: # if first step, initialize variance bit to grad^2 - state['sum'] = grad * grad + state["sum"] = grad * grad else: - state['sum'] *= (1.0 - group['t']) - state['sum'] += group['t'] * grad * grad + state["sum"] *= 1.0 - group["t"] + state["sum"] += 
group["t"] * grad * grad - lr = group['eta'] * (state['step'] ** (-0.5 + group['delta'])) - std = state['sum'].sqrt() + lr = group["eta"] * (state["step"] ** (-0.5 + group["delta"])) + std = state["sum"].sqrt() p.data.addcdiv_(grad, 1.0 + std, value=-lr) return loss diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py index 32e15e77c9..14a6a06656 100644 --- a/pyro/optim/clipped_adam.py +++ b/pyro/optim/clipped_adam.py @@ -28,12 +28,25 @@ class ClippedAdam(Optimizer): `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba https://arxiv.org/abs/1412.6980 """ - def __init__(self, params, lr: float = 1e-3, betas: Tuple = (0.9, 0.999), - eps: float = 1e-8, weight_decay=0, clip_norm: float = 10.0, - lrd: float = 1.0): - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - clip_norm=clip_norm, lrd=lrd) + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple = (0.9, 0.999), + eps: float = 1e-8, + weight_decay=0, + clip_norm: float = 10.0, + lrd: float = 1.0, + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + clip_norm=clip_norm, + lrd=lrd, + ) super().__init__(params, defaults) def step(self, closure: Optional[Callable] = None) -> Optional[Any]: @@ -47,40 +60,40 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]: loss = closure() for group in self.param_groups: - group['lr'] *= group['lrd'] + group["lr"] *= group["lrd"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data - grad.clamp_(-group['clip_norm'], group['clip_norm']) + grad.clamp_(-group["clip_norm"], group["clip_norm"]) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(grad) + state["exp_avg"] = torch.zeros_like(grad) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(grad) + state["exp_avg_sq"] = torch.zeros_like(grad) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 - if group['weight_decay'] != 0: - grad = grad.add(p.data, alpha=group['weight_decay']) + if group["weight_decay"] != 0: + grad = grad.add(p.data, alpha=group["weight_decay"]) # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = exp_avg_sq.sqrt().add_(group["eps"]) - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 p.data.addcdiv_(exp_avg, denom, value=-step_size) diff --git a/pyro/optim/dct_adam.py b/pyro/optim/dct_adam.py index ac1ceb5590..1a0047c239 100644 --- a/pyro/optim/dct_adam.py +++ b/pyro/optim/dct_adam.py @@ -73,10 +73,25 @@ class DCTAdam(Optimizer): :param bool subsample_aware: whether to update gradient statistics only for those elements that appear in a subsample (default: False). 
""" - def __init__(self, params, lr: float = 1e-3, betas: Tuple = (0.9, 0.999), eps: float = 1e-8, - clip_norm: float = 10.0, lrd: float = 1.0, subsample_aware: bool = False): - defaults = dict(lr=lr, betas=betas, eps=eps, clip_norm=clip_norm, lrd=lrd, - subsample_aware=subsample_aware) + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple = (0.9, 0.999), + eps: float = 1e-8, + clip_norm: float = 10.0, + lrd: float = 1.0, + subsample_aware: bool = False, + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + clip_norm=clip_norm, + lrd=lrd, + subsample_aware=subsample_aware, + ) super().__init__(params, defaults) def step(self, closure: Optional[Callable] = None) -> Optional[float]: @@ -90,14 +105,14 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]: loss = closure() for group in self.param_groups: - group['lr'] *= group['lrd'] + group["lr"] *= group["lrd"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue subsample = getattr(p, "_pyro_subsample", {}) - if subsample and group['subsample_aware']: + if subsample and group["subsample_aware"]: self._step_param_subsample(group, p, subsample) else: self._step_param(group, p) @@ -106,7 +121,7 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]: def _step_param(self, group: Dict, p) -> None: grad = p.grad.data - grad.clamp_(-group['clip_norm'], group['clip_norm']) + grad.clamp_(-group["clip_norm"], group["clip_norm"]) # Transform selected parameters via dct. time_dim = getattr(p, "_pyro_dct_dim", None) @@ -118,26 +133,26 @@ def _step_param(self, group: Dict, p) -> None: # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(grad) + state["exp_avg"] = torch.zeros_like(grad) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(grad) + state["exp_avg_sq"] = torch.zeros_like(grad) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = exp_avg_sq.sqrt().add_(group["eps"]) - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 if time_dim is None: p.data.addcdiv_(exp_avg, denom, value=-step_size) @@ -149,7 +164,7 @@ def _step_param_subsample(self, group: Dict, p, subsample) -> None: mask = _get_mask(p, subsample) grad = p.grad.data.masked_select(mask) - grad.clamp_(-group['clip_norm'], group['clip_norm']) + grad.clamp_(-group["clip_norm"], group["clip_norm"]) # Transform selected parameters via dct. 
time_dim = getattr(p, "_pyro_dct_dim", None) @@ -161,29 +176,36 @@ def _step_param_subsample(self, group: Dict, p, subsample) -> None: # State initialization if len(state) == 0: - state['step'] = torch.zeros_like(p) + state["step"] = torch.zeros_like(p) # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) - beta1, beta2 = group['betas'] + beta1, beta2 = group["betas"] - state_step = state['step'].masked_select(mask).add_(1) - state['step'].masked_scatter_(mask, state_step) + state_step = state["step"].masked_select(mask).add_(1) + state["step"].masked_scatter_(mask, state_step) # Decay the first and second moment running average coefficient - exp_avg = state['exp_avg'].masked_select(mask).mul_(beta1).add_(grad, alpha=1 - beta1) - state['exp_avg'].masked_scatter_(mask, exp_avg) + exp_avg = ( + state["exp_avg"].masked_select(mask).mul_(beta1).add_(grad, alpha=1 - beta1) + ) + state["exp_avg"].masked_scatter_(mask, exp_avg) - exp_avg_sq = state['exp_avg_sq'].masked_select(mask).mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - state['exp_avg_sq'].masked_scatter_(mask, exp_avg_sq) + exp_avg_sq = ( + state["exp_avg_sq"] + .masked_select(mask) + .mul_(beta2) + .addcmul_(grad, grad, value=1 - beta2) + ) + state["exp_avg_sq"].masked_scatter_(mask, exp_avg_sq) - denom = exp_avg_sq.sqrt_().add_(group['eps']) + denom = exp_avg_sq.sqrt_().add_(group["eps"]) bias_correction1 = 1 - beta1 ** state_step bias_correction2 = 1 - beta2 ** state_step - step_size = bias_correction2.sqrt_().div_(bias_correction1).mul_(group['lr']) + step_size = bias_correction2.sqrt_().div_(bias_correction1).mul_(group["lr"]) step = exp_avg.div_(denom) if time_dim is not None: diff --git a/pyro/optim/horovod.py b/pyro/optim/horovod.py index 6c7a6d6812..8a3c1f2aaa 100644 --- a/pyro/optim/horovod.py +++ b/pyro/optim/horovod.py @@ -29,11 +29,13 @@ class HorovodOptimizer(PyroOptim): :param \*\*horovod_kwargs: Extra parameters passed to :func:`horovod.torch.DistributedOptimizer`. """ + def __init__(self, pyro_optim: PyroOptim, **horovod_kwargs): param_name = pyro.get_param_store().param_name def optim_constructor(params, **pt_kwargs) -> Optimizer: import horovod.torch as hvd # type: ignore + pt_optim = pyro_optim.pt_optim_constructor(params, **pt_kwargs) # type: ignore named_parameters = [(param_name(p), p) for p in params] hvd_optim = hvd.DistributedOptimizer( @@ -43,7 +45,9 @@ def optim_constructor(params, **pt_kwargs) -> Optimizer: ) return hvd_optim # type: ignore - super().__init__(optim_constructor, pyro_optim.pt_optim_args, pyro_optim.pt_clip_args) + super().__init__( + optim_constructor, pyro_optim.pt_optim_args, pyro_optim.pt_clip_args + ) def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None: # Sort by name to ensure deterministic processing order. 
diff --git a/pyro/optim/lr_scheduler.py b/pyro/optim/lr_scheduler.py index 538483632d..744c877892 100644 --- a/pyro/optim/lr_scheduler.py +++ b/pyro/optim/lr_scheduler.py @@ -29,21 +29,28 @@ class PyroLRScheduler(PyroOptim): svi.step(minibatch) scheduler.step() """ - def __init__(self, scheduler_constructor, optim_args: Union[Dict], - clip_args: Optional[Union[Dict]] = None): + + def __init__( + self, + scheduler_constructor, + optim_args: Union[Dict], + clip_args: Optional[Union[Dict]] = None, + ): # pytorch scheduler self.pt_scheduler_constructor = scheduler_constructor # torch optimizer - pt_optim_constructor = optim_args.pop('optimizer') + pt_optim_constructor = optim_args.pop("optimizer") # kwargs for the torch optimizer - optim_kwargs = optim_args.pop('optim_args') + optim_kwargs = optim_args.pop("optim_args") self.kwargs = optim_args super().__init__(pt_optim_constructor, optim_kwargs, clip_args) def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None: super().__call__(params, *args, **kwargs) - def _get_optim(self, params: Union[Tensor, Iterable[Tensor], Iterable[Dict[Any, Any]]]): + def _get_optim( + self, params: Union[Tensor, Iterable[Tensor], Iterable[Dict[Any, Any]]] + ): optim = super()._get_optim(params) return self.pt_scheduler_constructor(optim, **self.kwargs) diff --git a/pyro/optim/multi.py b/pyro/optim/multi.py index 20c9eb659e..9f76b15423 100644 --- a/pyro/optim/multi.py +++ b/pyro/optim/multi.py @@ -31,6 +31,7 @@ class MultiOptimizer: if site['type'] == 'param'} optim.step(loss, params) """ + def step(self, loss: torch.Tensor, params: Dict) -> None: """ Performs an in-place optimization step on parameters given a @@ -72,9 +73,12 @@ class PyroMultiOptimizer(MultiOptimizer): Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objects in a :class:`MultiOptimizer` interface. """ + def __init__(self, optim: PyroOptim) -> None: if not isinstance(optim, PyroOptim): - raise TypeError('Expected a PyroOptim object but got a {}'.format(type(optim))) + raise TypeError( + "Expected a PyroOptim object but got a {}".format(type(optim)) + ) self.optim = optim def step(self, loss: torch.Tensor, params: Dict) -> None: @@ -90,6 +94,7 @@ class TorchMultiOptimizer(PyroMultiOptimizer): Facade to wrap :class:`~torch.optim.Optimizer` objects in a :class:`MultiOptimizer` interface. """ + def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict): optim = PyroOptim(optim_constructor, optim_args) super().__init__(optim) @@ -107,6 +112,7 @@ class MixedMultiOptimizer(MultiOptimizer): partition up all desired parameters to optimize. :raises ValueError: if any name is optimized by multiple optimizers. 
""" + def __init__(self, parts: List) -> None: optim_dict: Dict = {} self.parts = [] @@ -115,8 +121,10 @@ def __init__(self, parts: List) -> None: optim = PyroMultiOptimizer(optim) for name in names_part: if name in optim_dict: - raise ValueError("Attempted to optimize parameter '{}' by two different optimizers: " - "{} vs {}" .format(name, optim_dict[name], optim)) + raise ValueError( + "Attempted to optimize parameter '{}' by two different optimizers: " + "{} vs {}".format(name, optim_dict[name], optim) + ) optim_dict[name] = optim self.parts.append((names_part, optim)) @@ -128,7 +136,8 @@ def get_step(self, loss: torch.Tensor, params: Dict) -> Dict: updated_values = {} for names_part, optim in self.parts: updated_values.update( - optim.get_step(loss, {name: params[name] for name in names_part})) + optim.get_step(loss, {name: params[name] for name in names_part}) + ) return updated_values @@ -146,6 +155,7 @@ class Newton(MultiOptimizer): region. Missing names will use unregularized Newton update, equivalent to infinite trust radius. """ + def __init__(self, trust_radii: Dict = {}): self.trust_radii = trust_radii diff --git a/pyro/optim/optim.py b/pyro/optim/optim.py index 23e8579e80..950566c791 100644 --- a/pyro/optim/optim.py +++ b/pyro/optim/optim.py @@ -40,21 +40,27 @@ class PyroOptim: :param clip_args: a dictionary of clip_norm and/or clip_value args or a callable that returns such dictionaries """ - def __init__(self, optim_constructor: Union[Callable, Optimizer, Type[Optimizer]], - optim_args: Union[Dict, Callable[..., Dict]], - clip_args: Optional[Union[Dict, Callable[..., Dict]]] = None): + + def __init__( + self, + optim_constructor: Union[Callable, Optimizer, Type[Optimizer]], + optim_args: Union[Dict, Callable[..., Dict]], + clip_args: Optional[Union[Dict, Callable[..., Dict]]] = None, + ): self.pt_optim_constructor = optim_constructor # must be callable or dict assert callable(optim_args) or isinstance( - optim_args, dict), "optim_args must be function that returns defaults or a defaults dictionary" + optim_args, dict + ), "optim_args must be function that returns defaults or a defaults dictionary" if clip_args is None: clip_args = {} # must be callable or dict assert callable(clip_args) or isinstance( - clip_args, dict), "clip_args must be function that returns defaults or a defaults dictionary" + clip_args, dict + ), "clip_args must be function that returns defaults or a defaults dictionary" # hold our args to be called/used self.pt_optim_args = optim_args @@ -93,8 +99,11 @@ def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None: if self.grad_clip[p] is not None: self.grad_clip[p](p) - if isinstance(self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler) or \ - isinstance(self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau): + if isinstance( + self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler + ) or isinstance( + self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau + ): # if optim object was a scheduler, perform an optimizer step self.optim_objs[p].optimizer.step(*args, **kwargs) else: @@ -159,7 +168,9 @@ def _get_optim_args(self, param: Union[Iterable[Tensor], Iterable[Dict]]): opt_dict = self.pt_optim_args(module_name, stripped_param_name) # must be dictionary - assert isinstance(opt_dict, dict), "per-param optim arg must return defaults dictionary" + assert isinstance( + opt_dict, dict + ), "per-param optim arg must return defaults dictionary" return opt_dict else: return self.pt_optim_args @@ -189,14 +200,19 @@ def 
_get_grad_clip_args(self, param: str) -> Dict: clip_dict = self.pt_clip_args(module_name, stripped_param_name) # must be dictionary - assert isinstance(clip_dict, dict), "per-param clip arg must return defaults dictionary" + assert isinstance( + clip_dict, dict + ), "per-param clip arg must return defaults dictionary" return clip_dict else: return self.pt_clip_args @staticmethod - def _clip_grad(params: Union[Tensor, Iterable[Tensor]], clip_norm: Optional[Union[int, float]] = None, - clip_value: Optional[Union[int, float]] = None) -> None: + def _clip_grad( + params: Union[Tensor, Iterable[Tensor]], + clip_norm: Optional[Union[int, float]] = None, + clip_value: Optional[Union[int, float]] = None, + ) -> None: if clip_norm is not None: clip_grad_norm_(params, clip_norm) if clip_value is not None: diff --git a/pyro/optim/pytorch_optimizers.py b/pyro/optim/pytorch_optimizers.py index bdf709f0ce..57dd4b0282 100644 --- a/pyro/optim/pytorch_optimizers.py +++ b/pyro/optim/pytorch_optimizers.py @@ -19,9 +19,15 @@ # XXX LBFGS is not supported for SVI yet continue - _PyroOptim = (lambda _Optim: lambda optim_args, clip_args=None: PyroOptim(_Optim, optim_args, clip_args))(_Optim) + _PyroOptim = ( + lambda _Optim: lambda optim_args, clip_args=None: PyroOptim( + _Optim, optim_args, clip_args + ) + )(_Optim) _PyroOptim.__name__ = _name - _PyroOptim.__doc__ = 'Wraps :class:`torch.optim.{}` with :class:`~pyro.optim.optim.PyroOptim`.'.format(_name) + _PyroOptim.__doc__ = "Wraps :class:`torch.optim.{}` with :class:`~pyro.optim.optim.PyroOptim`.".format( + _name + ) locals()[_name] = _PyroOptim __all__.append(_name) @@ -31,17 +37,24 @@ for _name, _Optim in torch.optim.lr_scheduler.__dict__.items(): if not isinstance(_Optim, type): continue - if not issubclass(_Optim, torch.optim.lr_scheduler._LRScheduler) and _name != 'ReduceLROnPlateau': + if ( + not issubclass(_Optim, torch.optim.lr_scheduler._LRScheduler) + and _name != "ReduceLROnPlateau" + ): continue if _Optim is torch.optim.Optimizer: continue _PyroOptim = ( - lambda _Optim: lambda optim_args, clip_args=None: PyroLRScheduler(_Optim, optim_args, clip_args) + lambda _Optim: lambda optim_args, clip_args=None: PyroLRScheduler( + _Optim, optim_args, clip_args + ) )(_Optim) _PyroOptim.__name__ = _name - _PyroOptim.__doc__ = 'Wraps :class:`torch.optim.{}` with '.format(_name) +\ - ':class:`~pyro.optim.lr_scheduler.PyroLRScheduler`.' + _PyroOptim.__doc__ = ( + "Wraps :class:`torch.optim.{}` with ".format(_name) + + ":class:`~pyro.optim.lr_scheduler.PyroLRScheduler`." 
+ ) locals()[_name] = _PyroOptim __all__.append(_name) diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index 3b31046a3c..7142e34a85 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -178,17 +178,23 @@ def named_parameters(self): return self._params.items() def get_all_param_names(self): - warnings.warn("ParamStore.get_all_param_names() is deprecated; use .keys() instead.", - DeprecationWarning) + warnings.warn( + "ParamStore.get_all_param_names() is deprecated; use .keys() instead.", + DeprecationWarning, + ) return self.keys() def replace_param(self, param_name, new_param, old_param): - warnings.warn("ParamStore.replace_param() is deprecated; use .__setitem__() instead.", - DeprecationWarning) + warnings.warn( + "ParamStore.replace_param() is deprecated; use .__setitem__() instead.", + DeprecationWarning, + ) assert self._params[param_name] is old_param.unconstrained() self[param_name] = new_param - def get_param(self, name, init_tensor=None, constraint=constraints.real, event_dim=None): + def get_param( + self, name, init_tensor=None, constraint=constraints.real, event_dim=None + ): """ Get parameter from its name. If it does not yet exist in the ParamStore, it will be created and stored. @@ -234,8 +240,8 @@ def get_state(self): Get the ParamStore state. """ state = { - 'params': self._params, - 'constraints': self._constraints, + "params": self._params, + "constraints": self._constraints, } return state @@ -244,14 +250,15 @@ def set_state(self, state): Set the ParamStore state using state from a previous get_state() call """ assert isinstance(state, dict), "malformed ParamStore state" - assert set(state.keys()) == set(['params', 'constraints']), \ - "malformed ParamStore keys {}".format(state.keys()) + assert set(state.keys()) == set( + ["params", "constraints"] + ), "malformed ParamStore keys {}".format(state.keys()) - for param_name, param in state['params'].items(): + for param_name, param in state["params"].items(): self._params[param_name] = param self._param_to_name[param] = param_name - for param_name, constraint in state['constraints'].items(): + for param_name, constraint in state["constraints"].items(): if isinstance(constraint, type(constraints.real)): # Work around lack of hash & equality comparison on constraints. constraint = constraints.real diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index e8f9b77845..e15f0a885b 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -13,13 +13,14 @@ def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg): else: msg_type = msg["type"] - is_not_exposed = (msg["name"] not in expose) and \ - (msg_type not in expose_types) + is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types) # decision rule for hiding: - if (msg["name"] in hide) or \ - (msg_type in hide_types) or \ - (is_not_exposed and hide_all): # noqa: E129 + if ( + (msg["name"] in hide) + or (msg_type in hide_types) + or (is_not_exposed and hide_all) + ): # noqa: E129 return True # otherwise expose @@ -30,8 +31,9 @@ def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg): def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types): # first, some sanity checks: # hide_all and expose_all intersect? 
- assert (hide_all is False and expose_all is False) or \ - (hide_all != expose_all), "cannot hide and expose a site" + assert (hide_all is False and expose_all is False) or ( + hide_all != expose_all + ), "cannot hide and expose a site" # hide and expose intersect? if hide is None: @@ -44,8 +46,7 @@ def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose else: hide_all = True - assert set(hide).isdisjoint(set(expose)), \ - "cannot hide and expose a site" + assert set(hide).isdisjoint(set(expose)), "cannot hide and expose a site" # hide_types and expose_types intersect? if hide_types is None: @@ -58,8 +59,9 @@ def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose else: hide_all = True - assert set(hide_types).isdisjoint(set(expose_types)), \ - "cannot hide and expose a site type" + assert set(hide_types).isdisjoint( + set(expose_types) + ), "cannot hide and expose a site type" return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all) @@ -113,10 +115,17 @@ class BlockMessenger(Messenger): :returns: stochastic function decorated with a :class:`~pyro.poutine.block_messenger.BlockMessenger` """ - def __init__(self, hide_fn=None, expose_fn=None, - hide_all=True, expose_all=False, - hide=None, expose=None, - hide_types=None, expose_types=None): + def __init__( + self, + hide_fn=None, + expose_fn=None, + hide_all=True, + expose_all=False, + hide=None, + expose=None, + hide_types=None, + expose_types=None, + ): super().__init__() if not (hide_fn is None or expose_fn is None): raise ValueError("Only specify one of hide_fn or expose_fn") @@ -125,9 +134,9 @@ def __init__(self, hide_fn=None, expose_fn=None, elif expose_fn is not None: self.hide_fn = lambda msg: not expose_fn(msg) else: - self.hide_fn = _make_default_hide_fn(hide_all, expose_all, - hide, expose, - hide_types, expose_types) + self.hide_fn = _make_default_hide_fn( + hide_all, expose_all, hide, expose, hide_types, expose_types + ) def _process_message(self, msg): msg["stop"] = bool(self.hide_fn(msg)) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 2ff2a57ae9..f12ec13f48 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -48,22 +48,37 @@ def _pyro_sample(msg): dist = msg["fn"] actual_batch_shape = getattr(dist, "batch_shape", None) if actual_batch_shape is not None: - target_batch_shape = [None if size == 1 else size - for size in actual_batch_shape] + target_batch_shape = [ + None if size == 1 else size for size in actual_batch_shape + ] for f in msg["cond_indep_stack"]: if f.dim is None or f.size == -1: continue assert f.dim < 0 - target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape - if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size: - raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format( - f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim])) + target_batch_shape = [None] * ( + -f.dim - len(target_batch_shape) + ) + target_batch_shape + if ( + target_batch_shape[f.dim] is not None + and target_batch_shape[f.dim] != f.size + ): + raise ValueError( + "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format( + f.name, + msg["name"], + f.dim, + f.size, + target_batch_shape[f.dim], + ) + ) target_batch_shape[f.dim] = f.size # Starting from the right, if expected size is None at an index, # set it to the actual size if it exists, else 1. 
for i in range(-len(target_batch_shape) + 1, 1): if target_batch_shape[i] is None: - target_batch_shape[i] = actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 + target_batch_shape[i] = ( + actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 + ) msg["fn"] = dist.expand(target_batch_shape) if msg["fn"].has_rsample != dist.has_rsample: msg["fn"].has_rsample = dist.has_rsample # copy custom attribute diff --git a/pyro/poutine/collapse_messenger.py b/pyro/poutine/collapse_messenger.py index 25a6c4872a..826a7f957e 100644 --- a/pyro/poutine/collapse_messenger.py +++ b/pyro/poutine/collapse_messenger.py @@ -88,12 +88,14 @@ class CollapseMessenger(TraceMessenger): context, you should manually declare ``max_plate_nesting`` to your inference algorithm (e.g. ``Trace_ELBO(max_plate_nesting=1)``). """ + _coerce = None def __init__(self, *args, **kwargs): if CollapseMessenger._coerce is None: import funsor from funsor.distribution import CoerceDistributionToFunsor + funsor.set_backend("torch") CollapseMessenger._coerce = CoerceDistributionToFunsor("torch") self._block = False @@ -149,8 +151,9 @@ def _pyro_barrier(self, msg): msg["value"] = value def __enter__(self): - self.preserved_plates = {h.dim: h.name for h in _PYRO_STACK - if isinstance(h, pyro.plate)} + self.preserved_plates = { + h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate) + } COERCIONS.append(self._coerce) return super().__enter__() @@ -159,8 +162,7 @@ def __exit__(self, *args): assert _coerce is self._coerce super().__exit__(*args) - if any(site["type"] == "sample" - for site in self.trace.nodes.values()): + if any(site["type"] == "sample" for site in self.trace.nodes.values()): name, log_prob, _, _ = self._get_log_prob() pyro.factor(name, log_prob.data) @@ -173,8 +175,9 @@ def _get_log_prob(self): if not site["is_observed"]: reduced_vars.append(name) log_prob_terms.append(site["fn"](value=site["value"])) - plates |= frozenset(f.name for f in site["cond_indep_stack"] - if f.vectorized) + plates |= frozenset( + f.name for f in site["cond_indep_stack"] if f.vectorized + ) name = reduced_vars[0] reduced_vars = frozenset(reduced_vars) assert log_prob_terms, "nothing to collapse" diff --git a/pyro/poutine/condition_messenger.py b/pyro/poutine/condition_messenger.py index 5b08951d1f..01d4809bd0 100644 --- a/pyro/poutine/condition_messenger.py +++ b/pyro/poutine/condition_messenger.py @@ -30,6 +30,7 @@ class ConditionMessenger(Messenger): :param data: a dict or a :class:`~pyro.poutine.Trace` :returns: stochastic function decorated with a :class:`~pyro.poutine.condition_messenger.ConditionMessenger` """ + def __init__(self, data): """ :param data: a dict or a Trace diff --git a/pyro/poutine/do_messenger.py b/pyro/poutine/do_messenger.py index 1b053f0f17..b89c52698c 100644 --- a/pyro/poutine/do_messenger.py +++ b/pyro/poutine/do_messenger.py @@ -47,22 +47,26 @@ class DoMessenger(Messenger): :param data: a ``dict`` mapping sample site names to interventions :returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger` """ + def __init__(self, data): super().__init__() self.data = data self._intervener_id = str(id(self)) def _pyro_sample(self, msg): - if msg.get('_intervener_id', None) != self._intervener_id and \ - self.data.get(msg['name']) is not None: + if ( + msg.get("_intervener_id", None) != self._intervener_id + and self.data.get(msg["name"]) is not None + ): - if msg.get('_intervener_id', None) is not None: + if msg.get("_intervener_id", None) is not None: warnings.warn( 
"Attempting to intervene on variable {} multiple times," - "this is almost certainly incorrect behavior".format(msg['name']), - RuntimeWarning) + "this is almost certainly incorrect behavior".format(msg["name"]), + RuntimeWarning, + ) - msg['_intervener_id'] = self._intervener_id + msg["_intervener_id"] = self._intervener_id # split node, avoid reapplying self recursively to new node new_msg = msg.copy() @@ -70,15 +74,18 @@ def _pyro_sample(self, msg): apply_stack(new_msg) # apply intervention - intervention = self.data[msg['name']] - msg['name'] = msg['name'] + "__CF" # mangle old name + intervention = self.data[msg["name"]] + msg["name"] = msg["name"] + "__CF" # mangle old name if isinstance(intervention, (numbers.Number, torch.Tensor)): - msg['value'] = intervention - msg['is_observed'] = True - msg['stop'] = True + msg["value"] = intervention + msg["is_observed"] = True + msg["stop"] = True else: raise NotImplementedError( - "Interventions of type {} not implemented (yet)".format(type(intervention))) + "Interventions of type {} not implemented (yet)".format( + type(intervention) + ) + ) return None diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index cb1c14162b..234c0764b8 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -36,10 +36,16 @@ def _tmc_mixture_sample(msg): index = [Ellipsis] + [slice(None)] * (len(thin_sample.shape) - 1) squashed_dims = [] - for squashed_dim, squashed_size in zip(range(1, len(thin_sample.shape)), thin_sample.shape[1:]): - if squashed_size > 1 and (target_shape[squashed_dim] == 1 or squashed_dim == 0): + for squashed_dim, squashed_size in zip( + range(1, len(thin_sample.shape)), thin_sample.shape[1:] + ): + if squashed_size > 1 and ( + target_shape[squashed_dim] == 1 or squashed_dim == 0 + ): # uniformly sample one ancestor per upstream particle population - ancestor_dist = Categorical(logits=torch.zeros((squashed_size,), device=thin_sample.device)) + ancestor_dist = Categorical( + logits=torch.zeros((squashed_size,), device=thin_sample.device) + ) ancestor_index = ancestor_dist.sample(sample_shape=(num_samples,)) index[squashed_dim] = ancestor_index squashed_dims.append(squashed_dim) @@ -76,8 +82,12 @@ def _tmc_diagonal_sample(msg): index = [Ellipsis] + [slice(None)] * (len(thin_sample.shape) - 1) squashed_dims = [] - for squashed_dim, squashed_size in zip(range(1, len(thin_sample.shape)), thin_sample.shape[1:]): - if squashed_size > 1 and (target_shape[squashed_dim] == 1 or squashed_dim == 0): + for squashed_dim, squashed_size in zip( + range(1, len(thin_sample.shape)), thin_sample.shape[1:] + ): + if squashed_size > 1 and ( + target_shape[squashed_dim] == 1 or squashed_dim == 0 + ): # diagonal approximation: identify particle indices across populations ancestor_index = torch.arange(squashed_size, device=thin_sample.device) index[squashed_dim] = ancestor_index @@ -122,8 +132,11 @@ class EnumMessenger(Messenger): dimension and all dimensions left may be used internally by Pyro. This should be a negative integer or None. 
""" + def __init__(self, first_available_dim=None): - assert first_available_dim is None or first_available_dim < 0, first_available_dim + assert ( + first_available_dim is None or first_available_dim < 0 + ), first_available_dim self.first_available_dim = first_available_dim super().__init__() @@ -149,7 +162,9 @@ def _pyro_sample(self, msg): param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id if scope is not None: for name, depth in scope.items(): - if self._markov_depths[name] == depth: # hide sites whose markov context has exited + if ( + self._markov_depths[name] == depth + ): # hide sites whose markov context has exited param_dims.update(self._value_dims[name]) self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] self._param_dims[msg["name"]] = param_dims @@ -161,9 +176,11 @@ def _pyro_sample(self, msg): actual_dim = -1 - len(msg["fn"].batch_shape) # the leftmost dim of log_prob # Move actual_dim to a safe target_dim. - target_dim, id_ = _ENUM_ALLOCATOR.allocate(None if scope is None else param_dims) + target_dim, id_ = _ENUM_ALLOCATOR.allocate( + None if scope is None else param_dims + ) event_dim = msg["fn"].event_dim - categorical_support = getattr(value, '_pyro_categorical_support', None) + categorical_support = getattr(value, "_pyro_categorical_support", None) if categorical_support is not None: # Preserve categorical supports to speed up Categorical.log_prob(). # See pyro/distributions/torch.py for details. @@ -171,8 +188,9 @@ def _pyro_sample(self, msg): value = value.reshape(value.shape[:1] + (1,) * (-1 - target_dim)) value._pyro_categorical_support = categorical_support elif actual_dim < target_dim: - assert value.size(target_dim - event_dim) == 1, \ - 'pyro.markov dim conflict at dim {}'.format(actual_dim) + assert ( + value.size(target_dim - event_dim) == 1 + ), "pyro.markov dim conflict at dim {}".format(actual_dim) value = value.transpose(target_dim - event_dim, actual_dim - event_dim) while value.dim() and value.size(0) == 1: value = value.squeeze(0) @@ -181,8 +199,11 @@ def _pyro_sample(self, msg): value = value.reshape(value.shape[:1] + (1,) * diff + value.shape[1:]) # Compute dims passed downstream through the value. 
- value_dims = {dim: param_dims[dim] for dim in range(event_dim - value.dim(), 0) - if value.size(dim - event_dim) > 1 and dim in param_dims} + value_dims = { + dim: param_dims[dim] + for dim in range(event_dim - value.dim(), 0) + if value.size(dim - event_dim) > 1 and dim in param_dims + } value_dims[target_dim] = id_ msg["infer"]["_enumerate_dim"] = target_dim @@ -200,9 +221,12 @@ def _pyro_post_sample(self, msg): value = msg["value"] if value is None: return - shape = value.data.shape[:value.dim() - msg["fn"].event_dim] + shape = value.data.shape[: value.dim() - msg["fn"].event_dim] dim_to_id = msg["infer"].setdefault("_dim_to_id", {}) dim_to_id.update(self._param_dims.get(msg["name"], {})) with ignore_jit_warnings(): - self._value_dims[msg["name"]] = {dim: id_ for dim, id_ in dim_to_id.items() - if len(shape) >= -dim and shape[dim] > 1} + self._value_dims[msg["name"]] = { + dim: id_ + for dim, id_ in dim_to_id.items() + if len(shape) >= -dim and shape[dim] > 1 + } diff --git a/pyro/poutine/escape_messenger.py b/pyro/poutine/escape_messenger.py index 08df98db72..fd0bf6a5e6 100644 --- a/pyro/poutine/escape_messenger.py +++ b/pyro/poutine/escape_messenger.py @@ -9,6 +9,7 @@ class EscapeMessenger(Messenger): """ Messenger that does a nonlocal exit by raising a util.NonlocalExit exception """ + def __init__(self, escape_fn): """ :param escape_fn: function that takes a msg as input and returns True @@ -35,5 +36,6 @@ def _pyro_sample(self, msg): def cont(m): raise NonlocalExit(m) + msg["continuation"] = cont return None diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 3f821a78ad..01cdc60334 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -101,25 +101,42 @@ def _make_handler(msngr_cls): - _re1 = re.compile('(.)([A-Z][a-z]+)') - _re2 = re.compile('([a-z0-9])([A-Z])') + _re1 = re.compile("(.)([A-Z][a-z]+)") + _re2 = re.compile("([a-z0-9])([A-Z])") def handler(fn=None, *args, **kwargs): - if fn is not None and not (callable(fn) or isinstance(fn, collections.abc.Iterable)): + if fn is not None and not ( + callable(fn) or isinstance(fn, collections.abc.Iterable) + ): raise ValueError( - "{} is not callable, did you mean to pass it as a keyword arg?".format(fn)) + "{} is not callable, did you mean to pass it as a keyword arg?".format( + fn + ) + ) msngr = msngr_cls(*args, **kwargs) - return functools.update_wrapper(msngr(fn), fn, updated=()) if fn is not None else msngr + return ( + functools.update_wrapper(msngr(fn), fn, updated=()) + if fn is not None + else msngr + ) # handler names from messenger names: strip Messenger suffix, convert CamelCase to snake_case handler_name = _re2.sub( - r'\1_\2', _re1.sub(r'\1_\2', msngr_cls.__name__.split("Messenger")[0])).lower() - handler.__doc__ = """Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format( - handler_name + "_messenger", msngr_cls.__name__) + (msngr_cls.__doc__ if msngr_cls.__doc__ else "") + r"\1_\2", _re1.sub(r"\1_\2", msngr_cls.__name__.split("Messenger")[0]) + ).lower() + handler.__doc__ = ( + """Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format( + handler_name + "_messenger", msngr_cls.__name__ + ) + + (msngr_cls.__doc__ if msngr_cls.__doc__ else "") + ) handler.__name__ = handler_name return handler_name, handler +trace = None # flake8 +escape = None # flake8 + for _msngr_cls in _msngrs: _handler_name, _handler = _make_handler(_msngr_cls) _handler.__module__ = __name__ @@ -130,8 +147,15 @@ def handler(fn=None, *args, **kwargs): # Begin composite operations 
######################################### -def queue(fn=None, queue=None, max_tries=None, - extend_fn=None, escape_fn=None, num_samples=None): + +def queue( + fn=None, + queue=None, + max_tries=None, + extend_fn=None, + escape_fn=None, + num_samples=None, +): """ Used in sequential enumeration over discrete variables. @@ -165,22 +189,28 @@ def wrapper(wrapped): def _fn(*args, **kwargs): for i in range(max_tries): - assert not queue.empty(), \ - "trying to get() from an empty queue will deadlock" + assert ( + not queue.empty() + ), "trying to get() from an empty queue will deadlock" next_trace = queue.get() try: - ftr = trace(escape(replay(wrapped, trace=next_trace), # noqa: F821 - escape_fn=functools.partial(escape_fn, - next_trace))) + ftr = trace( + escape( + replay(wrapped, trace=next_trace), # noqa: F821 + escape_fn=functools.partial(escape_fn, next_trace), + ) + ) return ftr(*args, **kwargs) except NonlocalExit as site_container: site_container.reset_stack() - for tr in extend_fn(ftr.trace.copy(), site_container.site, - num_samples=num_samples): + for tr in extend_fn( + ftr.trace.copy(), site_container.site, num_samples=num_samples + ): queue.put(tr) raise ValueError("max tries ({}) exceeded".format(str(max_tries))) + return _fn return wrapper(fn) if fn is not None else wrapper @@ -214,6 +244,8 @@ def markov(fn=None, history=1, keep=False, dim=None, name=None): return MarkovMessenger(history=history, keep=keep, dim=dim, name=name) if not callable(fn): # Used as a generator - return MarkovMessenger(history=history, keep=keep, dim=dim, name=name).generator(iterable=fn) + return MarkovMessenger( + history=history, keep=keep, dim=dim, name=name + ).generator(iterable=fn) # Used as a decorator with bound args return MarkovMessenger(history=history, keep=keep, dim=dim, name=name)(fn) diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 0b65987932..ead1c7d613 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -12,14 +12,18 @@ from .runtime import _DIM_ALLOCATOR -class CondIndepStackFrame(namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])): +class CondIndepStackFrame( + namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"]) +): @property def vectorized(self): return self.dim is not None def _key(self): with ignore_jit_warnings(["Converting a tensor to a Python number"]): - size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size + size = ( + self.size.item() if isinstance(self.size, torch.Tensor) else self.size + ) return self.name, self.dim, size, self.counter def __eq__(self, other): @@ -54,6 +58,7 @@ class IndepMessenger(Messenger): xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320])) """ + def __init__(self, name=None, size=None, dim=None, device=None): if not torch._C._get_tracing_state() and size == 0: raise ZeroDivisionError("size cannot be zero") @@ -94,7 +99,8 @@ def __iter__(self): if self._vectorized is True or self.dim is not None: raise ValueError( "cannot use plate {} as both vectorized and non-vectorized" - "independence context".format(self.name)) + "independence context".format(self.name) + ) self._vectorized = False self.dim = None diff --git a/pyro/poutine/infer_config_messenger.py b/pyro/poutine/infer_config_messenger.py index bad91e303a..507ee4f40f 100644 --- a/pyro/poutine/infer_config_messenger.py +++ b/pyro/poutine/infer_config_messenger.py @@ -14,6 +14,7 @@ class InferConfigMessenger(Messenger): :param config_fn: a callable 
taking a site and returning an infer dict :returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger` """ + def __init__(self, config_fn): """ :param config_fn: a callable taking a site and returning an infer dict diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index dcf3222da8..3b72f66534 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -65,8 +65,10 @@ def __exit__(self, *args, **kwargs): if extra: warnings.warn( "pyro.module prior did not find params ['{}']. " - "Did you instead mean one of ['{}']?" - .format("', '".join(extra), "', '".join(self._param_misses))) + "Did you instead mean one of ['{}']?".format( + "', '".join(extra), "', '".join(self._param_misses) + ) + ) return super().__exit__(*args, **kwargs) def _pyro_sample(self, msg): @@ -87,7 +89,7 @@ def _pyro_param(self, msg): if param_name in self.prior.keys(): msg["fn"] = self.prior[param_name] msg["args"] = msg["args"][1:] - if isinstance(msg['fn'], Distribution): + if isinstance(msg["fn"], Distribution): msg["args"] = () msg["kwargs"] = {} msg["infer"] = {} @@ -115,7 +117,7 @@ def _pyro_param(self, msg): if name in self._samples_cache: # Multiple pyro.param statements with the same # name. Block the site and fix the value. - msg['value'] = self._samples_cache[name]['value'] + msg["value"] = self._samples_cache[name]["value"] msg["is_observed"] = True msg["stop"] = True else: diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index a57de54499..1d68c9e06a 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -26,6 +26,7 @@ class MarkovMessenger(ReentrantMessenger): :func:`pyro.markov` sites between models and guides. Interface stub, behavior not yet implemented. """ + def __init__(self, history=1, keep=False, dim=None, name=None): assert history >= 0 self.history = history @@ -34,10 +35,12 @@ def __init__(self, history=1, keep=False, dim=None, name=None): self.name = name if dim is not None: raise NotImplementedError( - "vectorized markov not yet implemented, try setting dim to None") + "vectorized markov not yet implemented, try setting dim to None" + ) if name is not None: raise NotImplementedError( - "vectorized markov not yet implemented, try setting name to None") + "vectorized markov not yet implemented, try setting name to None" + ) self._iterable = None self._pos = -1 self._stack = [] @@ -74,7 +77,9 @@ def _pyro_sample(self, msg): # This accounting can be done by users of these fields, # e.g. EnumMessenger. 
infer = msg["infer"] - scope = infer.setdefault("_markov_scope", Counter()) # site name -> markov depth + scope = infer.setdefault( + "_markov_scope", Counter() + ) # site name -> markov depth for pos in range(max(0, self._pos - self.history), self._pos + 1): scope.update(self._stack[pos]) infer["_markov_depth"] = 1 + infer.get("_markov_depth", 0) diff --git a/pyro/poutine/mask_messenger.py b/pyro/poutine/mask_messenger.py index 4d7323946c..35d9375827 100644 --- a/pyro/poutine/mask_messenger.py +++ b/pyro/poutine/mask_messenger.py @@ -16,13 +16,18 @@ class MaskMessenger(Messenger): (1 includes a site, 0 excludes a site) :returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.MaskMessenger` """ + def __init__(self, mask): if isinstance(mask, torch.Tensor): if mask.dtype != torch.bool: - raise ValueError('Expected mask to be a BoolTensor but got {}'.format(type(mask))) + raise ValueError( + "Expected mask to be a BoolTensor but got {}".format(type(mask)) + ) elif mask not in (True, False): - raise ValueError('Expected mask to be a boolean but got {}'.format(type(mask))) + raise ValueError( + "Expected mask to be a boolean but got {}".format(type(mask)) + ) super().__init__() self.mask = mask diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 174f64060b..7b9a259c0a 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -17,6 +17,7 @@ class _bound_partial(partial): Converts a (possibly) bound method into a partial function to support class methods as arguments to handlers. """ + def __get__(self, instance, owner): if instance is None: return self @@ -44,7 +45,10 @@ def __init__(self): def __call__(self, fn): if not callable(fn): raise ValueError( - "{} is not callable, did you mean to pass it as a keyword arg?".format(fn)) + "{} is not callable, did you mean to pass it as a keyword arg?".format( + fn + ) + ) wraps = _bound_partial(partial(_context_wrap, self, fn)) return wraps diff --git a/pyro/poutine/plate_messenger.py b/pyro/poutine/plate_messenger.py index 2e5c938ecf..12c380c8a1 100644 --- a/pyro/poutine/plate_messenger.py +++ b/pyro/poutine/plate_messenger.py @@ -13,6 +13,7 @@ class PlateMessenger(SubsampleMessenger): Swiss army knife of broadcasting amazingness: combines shape inference, independence annotation, and subsampling """ + def _process_message(self, msg): super()._process_message(msg) return BroadcastMessenger._pyro_sample(msg) @@ -72,7 +73,9 @@ def predicate(messenger): with block_messengers(predicate) as matches: if strict and len(matches) != 1: - raise ValueError(f"block_plate matched {len(matches)} messengers. " - "Try either removing the block_plate or " - "setting strict=False.") + raise ValueError( + f"block_plate matched {len(matches)} messengers. " + "Try either removing the block_plate or " + "setting strict=False." + ) yield diff --git a/pyro/poutine/reparam_messenger.py b/pyro/poutine/reparam_messenger.py index c984ddc9cb..4583ac62a0 100644 --- a/pyro/poutine/reparam_messenger.py +++ b/pyro/poutine/reparam_messenger.py @@ -37,6 +37,7 @@ class ReparamMessenger(Messenger): :class:`~pyro.infer.reparam.reparam.Reparameterizer` or None. :type config: dict or callable """ + def __init__(self, config): super().__init__() assert isinstance(config, dict) or callable(config) @@ -80,12 +81,14 @@ def _pyro_sample(self, msg): # Pass args_kwargs to the reparam via a side channel. 
reparam.args_kwargs = self._args_kwargs try: - new_msg = reparam.apply({ - "name": msg["name"], - "fn": msg["fn"], - "value": msg["value"], - "is_observed": msg["is_observed"], - }) + new_msg = reparam.apply( + { + "name": msg["name"], + "fn": msg["fn"], + "value": msg["value"], + "is_observed": msg["is_observed"], + } + ) finally: reparam.args_kwargs = None @@ -118,6 +121,7 @@ class ReparamHandler(object): """ Reparameterization poutine. """ + def __init__(self, msngr, fn): self.msngr = msngr self.fn = fn diff --git a/pyro/poutine/replay_messenger.py b/pyro/poutine/replay_messenger.py index 248ff48f0e..548c971473 100644 --- a/pyro/poutine/replay_messenger.py +++ b/pyro/poutine/replay_messenger.py @@ -61,8 +61,7 @@ def _pyro_sample(self, msg): guide_msg = self.trace.nodes[name] if msg["is_observed"]: return None - if guide_msg["type"] != "sample" or \ - guide_msg["is_observed"]: + if guide_msg["type"] != "sample" or guide_msg["is_observed"]: raise RuntimeError("site {} must be sampled in trace".format(name)) msg["done"] = True msg["value"] = guide_msg["value"] @@ -72,8 +71,9 @@ def _pyro_sample(self, msg): def _pyro_param(self, msg): name = msg["name"] if self.params is not None and name in self.params: - assert hasattr(self.params[name], "unconstrained"), \ - "param {} must be constrained value".format(name) + assert hasattr( + self.params[name], "unconstrained" + ), "param {} must be constrained value".format(name) msg["done"] = True msg["value"] = self.params[name] return None diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 62d6672b19..e5a980c895 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -22,6 +22,7 @@ class _DimAllocator: Note that dimensions are indexed from the right, e.g. -1, -2. """ + def __init__(self): self._stack = [] # in reverse orientation of log_prob.shape @@ -39,15 +40,26 @@ def allocate(self, name, dim): while -dim <= len(self._stack) and self._stack[-1 - dim] is not None: dim -= 1 elif dim >= 0: - raise ValueError('Expected dim < 0 to index from the right, actual {}'.format(dim)) + raise ValueError( + "Expected dim < 0 to index from the right, actual {}".format(dim) + ) # Allocate the requested dimension. while dim < -len(self._stack): self._stack.append(None) if self._stack[-1 - dim] is not None: - raise ValueError('\n'.join([ - 'at plates "{}" and "{}", collide at dim={}'.format(name, self._stack[-1 - dim], dim), - '\nTry moving the dim of one plate to the left, e.g. dim={}'.format(dim - 1)])) + raise ValueError( + "\n".join( + [ + 'at plates "{}" and "{}", collide at dim={}'.format( + name, self._stack[-1 - dim], dim + ), + "\nTry moving the dim of one plate to the left, e.g. dim={}".format( + dim - 1 + ), + ] + ) + ) self._stack[-1 - dim] = name return dim @@ -74,6 +86,7 @@ class _EnumAllocator: Note that dimensions are indexed from the right, e.g. -1, -2. Note that ids are simply nonnegative integers here. 
""" + def set_first_available_dim(self, first_available_dim): """ Set the first available dim, which should be to the left of all @@ -105,8 +118,10 @@ def allocate(self, scope_dims=None): self.next_available_id += 1 dim = self.next_available_dim - if dim == -float('inf'): - raise ValueError("max_plate_nesting must be set to a finite value for parallel enumeration") + if dim == -float("inf"): + raise ValueError( + "max_plate_nesting must be set to a finite value for parallel enumeration" + ) if scope_dims is None: # allocate a new global dimension self.next_available_dim -= 1 @@ -129,6 +144,7 @@ class NonlocalExit(Exception): Used by poutine.EscapeMessenger to return site information. """ + def __init__(self, site, *args, **kwargs): """ :param site: message at a pyro site constructor. @@ -265,6 +281,7 @@ def _fn(*args, **kwargs): # apply the stack and return its return value apply_stack(msg) return msg["value"] + _fn._is_effectful = True return _fn diff --git a/pyro/poutine/scale_messenger.py b/pyro/poutine/scale_messenger.py index f3ad720945..d041e12936 100644 --- a/pyro/poutine/scale_messenger.py +++ b/pyro/poutine/scale_messenger.py @@ -32,11 +32,14 @@ class ScaleMessenger(Messenger): :param scale: a positive scaling factor :returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.ScaleMessenger` """ + def __init__(self, scale): if isinstance(scale, torch.Tensor): if is_validation_enabled() and not (scale > 0).all(): - raise ValueError("Expected scale > 0 but got {}. ".format(scale) + - "Consider using poutine.mask() instead of poutine.scale().") + raise ValueError( + "Expected scale > 0 but got {}. ".format(scale) + + "Consider using poutine.mask() instead of poutine.scale()." + ) elif not (scale > 0): raise ValueError("Expected scale > 0 but got {}".format(scale)) super().__init__() diff --git a/pyro/poutine/seed_messenger.py b/pyro/poutine/seed_messenger.py index 48c4c2f564..c0cac35cf1 100644 --- a/pyro/poutine/seed_messenger.py +++ b/pyro/poutine/seed_messenger.py @@ -17,6 +17,7 @@ class SeedMessenger(Messenger): :param fn: a stochastic function (callable containing Pyro primitive calls). :param int rng_seed: rng seed. """ + def __init__(self, rng_seed): assert isinstance(rng_seed, int) self.rng_seed = rng_seed diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 4ee31cfdd5..751448fcec 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -32,8 +32,11 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None): self.use_cuda = use_cuda if self.use_cuda is not None: if self.use_cuda ^ (device != "cpu"): - raise ValueError("Incompatible arg values use_cuda={}, device={}." - .format(use_cuda, device)) + raise ValueError( + "Incompatible arg values use_cuda={}, device={}.".format( + use_cuda, device + ) + ) with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): self.device = torch.Tensor().device if not device else device @@ -49,13 +52,15 @@ def sample(self, sample_shape=torch.Size()): if subsample_size is None or subsample_size >= self.size: result = torch.arange(self.size, device=self.device) else: - result = torch.randperm(self.size, device=self.device)[:subsample_size].clone() + result = torch.randperm(self.size, device=self.device)[ + :subsample_size + ].clone() return result.cuda() if self.use_cuda else result def log_prob(self, x): # This is zero so that plate can provide an unbiased estimate of # the non-subsampled log_prob. 
- result = torch.tensor(0., device=self.device) + result = torch.tensor(0.0, device=self.device) return result.cuda() if self.use_cuda else result @@ -64,8 +69,16 @@ class SubsampleMessenger(IndepMessenger): Extension of IndepMessenger that includes subsampling. """ - def __init__(self, name, size=None, subsample_size=None, subsample=None, dim=None, - use_cuda=None, device=None): + def __init__( + self, + name, + size=None, + subsample_size=None, + subsample=None, + dim=None, + use_cuda=None, + device=None, + ): super().__init__(name, size, dim, device) self.subsample_size = subsample_size self._indices = subsample @@ -73,11 +86,18 @@ def __init__(self, name, size=None, subsample_size=None, subsample=None, dim=Non self.device = device self.size, self.subsample_size, self._indices = self._subsample( - self.name, self.size, self.subsample_size, - self._indices, self.use_cuda, self.device) + self.name, + self.size, + self.subsample_size, + self._indices, + self.use_cuda, + self.device, + ) @staticmethod - def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=None, device=None): + def _subsample( + name, size=None, subsample_size=None, subsample=None, use_cuda=None, device=None + ): """ Helper function for plate. See its docstrings for details. """ @@ -101,19 +121,25 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No "cond_indep_stack": (), "done": False, "stop": False, - "continuation": None + "continuation": None, } apply_stack(msg) subsample = msg["value"] with ignore_jit_warnings(): if subsample_size is None: - subsample_size = subsample.size(0) if isinstance(subsample, torch.Tensor) \ + subsample_size = ( + subsample.size(0) + if isinstance(subsample, torch.Tensor) else len(subsample) + ) elif subsample is not None and subsample_size != len(subsample): - raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + raise ValueError( + "subsample_size does not match len(subsample), {} vs {}.".format( + subsample_size, len(subsample) + ) + + " Did you accidentally use different subsample_size in the model and guide?" + ) return size, subsample_size, subsample @@ -122,10 +148,14 @@ def _reset(self): super()._reset() def _process_message(self, msg): - frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size, self.counter) + frame = CondIndepStackFrame( + self.name, self.dim, self.subsample_size, self.counter + ) frame.full_size = self.size # Used for param initialization. 
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] - if isinstance(self.size, torch.Tensor) or isinstance(self.subsample_size, torch.Tensor): + if isinstance(self.size, torch.Tensor) or isinstance( + self.subsample_size, torch.Tensor + ): if not isinstance(msg["scale"], torch.Tensor): with ignore_jit_warnings(): msg["scale"] = torch.tensor(msg["scale"]) @@ -141,12 +171,18 @@ def _postprocess_message(self, msg): if len(shape) >= -dim and shape[dim] != 1: if is_validation_enabled() and shape[dim] != self.size: if msg["type"] == "param": - statement = "pyro.param({}, ..., event_dim={})".format(msg["name"], event_dim) + statement = "pyro.param({}, ..., event_dim={})".format( + msg["name"], event_dim + ) else: - statement = "pyro.subsample(..., event_dim={})".format(event_dim) + statement = "pyro.subsample(..., event_dim={})".format( + event_dim + ) raise ValueError( - "Inside pyro.plate({}, {}, dim={}) invalid shape of {}: {}" - .format(self.name, self.size, self.dim, statement, shape)) + "Inside pyro.plate({}, {}, dim={}) invalid shape of {}: {}".format( + self.name, self.size, self.dim, statement, shape + ) + ) # Subsample parameters with known batch semantics. if self.subsample_size < self.size: value = msg["value"] diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index 4baa90d0e9..38bf826a96 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -24,8 +24,13 @@ def identify_dense_edges(trace): if past_name == name: break past_node_independent = False - for query, target in zip(node["cond_indep_stack"], past_node["cond_indep_stack"]): - if query.name == target.name and query.counter != target.counter: + for query, target in zip( + node["cond_indep_stack"], past_node["cond_indep_stack"] + ): + if ( + query.name == target.name + and query.counter != target.counter + ): past_node_independent = True break if not past_node_independent: @@ -110,10 +115,13 @@ def get_trace(self): def _reset(self): tr = Trace(graph_type=self.graph_type) if "_INPUT" in self.trace.nodes: - tr.add_node("_INPUT", - name="_INPUT", type="input", - args=self.trace.nodes["_INPUT"]["args"], - kwargs=self.trace.nodes["_INPUT"]["kwargs"]) + tr.add_node( + "_INPUT", + name="_INPUT", + type="input", + args=self.trace.nodes["_INPUT"]["args"], + kwargs=self.trace.nodes["_INPUT"]["kwargs"], + ) self.trace = tr super()._reset() @@ -141,6 +149,7 @@ class TraceHandler: We can also use this for visualization. 
""" + def __init__(self, msngr, fn): self.fn = fn self.msngr = msngr @@ -158,9 +167,9 @@ def __call__(self, *args, **kwargs): and returns self.fn's return value """ with self.msngr: - self.msngr.trace.add_node("_INPUT", - name="_INPUT", type="args", - args=args, kwargs=kwargs) + self.msngr.trace.add_node( + "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs + ) try: ret = self.fn(*args, **kwargs) except (ValueError, RuntimeError) as e: @@ -169,7 +178,9 @@ def __call__(self, *args, **kwargs): exc = exc_type(u"{}\n{}".format(exc_value, shapes)) exc = exc.with_traceback(traceback) raise exc from e - self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret) + self.msngr.trace.add_node( + "_RETURN", name="_RETURN", type="return", value=ret + ) return ret @property diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 176c7a772f..f8604cf985 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -71,8 +71,9 @@ class Trace: """ def __init__(self, graph_type="flat"): - assert graph_type in ("flat", "dense"), \ - "{} not a valid graph type".format(graph_type) + assert graph_type in ("flat", "dense"), "{} not a valid graph type".format( + graph_type + ) self.graph_type = graph_type self.nodes = OrderedDict() self._succ = OrderedDict() @@ -104,12 +105,16 @@ def add_node(self, site_name, **kwargs): """ if site_name in self: site = self.nodes[site_name] - if site['type'] != kwargs['type']: + if site["type"] != kwargs["type"]: # Cannot sample or observe after a param statement. - raise RuntimeError("{} is already in the trace as a {}".format(site_name, site['type'])) - elif kwargs['type'] != "param": + raise RuntimeError( + "{} is already in the trace as a {}".format(site_name, site["type"]) + ) + elif kwargs["type"] != "param": # Cannot sample after a previous sample statement. - raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name)) + raise RuntimeError( + "Multiple {} sites named '{}'".format(kwargs["type"], site_name) + ) # XXX should copy in case site gets mutated, or dont bother? 
self.nodes[site_name] = kwargs @@ -188,17 +193,26 @@ def log_prob_sum(self, site_filter=lambda name, site: True): log_p = site["log_prob_sum"] else: try: - log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"]) + log_p = site["fn"].log_prob( + site["value"], *site["args"], **site["kwargs"] + ) except ValueError as e: _, exc_value, traceback = sys.exc_info() shapes = self.format_shapes(last_site=site["name"]) - raise ValueError("Error while computing log_prob_sum at site '{}':\n{}\n{}\n" - .format(name, exc_value, shapes)).with_traceback(traceback) from e + raise ValueError( + "Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format( + name, exc_value, shapes + ) + ).with_traceback(traceback) from e log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum() site["log_prob_sum"] = log_p if is_validation_enabled(): warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name)) - warn_if_inf(log_p, "log_prob_sum at site '{}'".format(name), allow_neginf=True) + warn_if_inf( + log_p, + "log_prob_sum at site '{}'".format(name), + allow_neginf=True, + ) result = result + log_p return result @@ -213,20 +227,31 @@ def compute_log_prob(self, site_filter=lambda name, site: True): if site["type"] == "sample" and site_filter(name, site): if "log_prob" not in site: try: - log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"]) + log_p = site["fn"].log_prob( + site["value"], *site["args"], **site["kwargs"] + ) except ValueError as e: _, exc_value, traceback = sys.exc_info() shapes = self.format_shapes(last_site=site["name"]) - raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}" - .format(name, exc_value, shapes)).with_traceback(traceback) from e + raise ValueError( + "Error while computing log_prob at site '{}':\n{}\n{}".format( + name, exc_value, shapes + ) + ).with_traceback(traceback) from e site["unscaled_log_prob"] = log_p log_p = scale_and_mask(log_p, site["scale"], site["mask"]) site["log_prob"] = log_p site["log_prob_sum"] = log_p.sum() if is_validation_enabled(): - warn_if_nan(site["log_prob_sum"], "log_prob_sum at site '{}'".format(name)) - warn_if_inf(site["log_prob_sum"], "log_prob_sum at site '{}'".format(name), - allow_neginf=True) + warn_if_nan( + site["log_prob_sum"], + "log_prob_sum at site '{}'".format(name), + ) + warn_if_inf( + site["log_prob_sum"], + "log_prob_sum at site '{}'".format(name), + allow_neginf=True, + ) def compute_score_parts(self): """ @@ -240,20 +265,31 @@ def compute_score_parts(self): # Note that ScoreParts overloads the multiplication operator # to correctly scale each of its three parts. 
try: - value = site["fn"].score_parts(site["value"], *site["args"], **site["kwargs"]) + value = site["fn"].score_parts( + site["value"], *site["args"], **site["kwargs"] + ) except ValueError as e: _, exc_value, traceback = sys.exc_info() shapes = self.format_shapes(last_site=site["name"]) - raise ValueError("Error while computing score_parts at site '{}':\n{}\n{}" - .format(name, exc_value, shapes)).with_traceback(traceback) from e + raise ValueError( + "Error while computing score_parts at site '{}':\n{}\n{}".format( + name, exc_value, shapes + ) + ).with_traceback(traceback) from e site["unscaled_log_prob"] = value.log_prob value = value.scale_and_mask(site["scale"], site["mask"]) site["score_parts"] = value site["log_prob"] = value.log_prob site["log_prob_sum"] = value.log_prob.sum() if is_validation_enabled(): - warn_if_nan(site["log_prob_sum"], "log_prob_sum at site '{}'".format(name)) - warn_if_inf(site["log_prob_sum"], "log_prob_sum at site '{}'".format(name), allow_neginf=True) + warn_if_nan( + site["log_prob_sum"], "log_prob_sum at site '{}'".format(name) + ) + warn_if_inf( + site["log_prob_sum"], + "log_prob_sum at site '{}'".format(name), + allow_neginf=True, + ) def detach_(self): """ @@ -268,26 +304,29 @@ def observation_nodes(self): """ :return: a list of names of observe sites """ - return [name for name, node in self.nodes.items() - if node["type"] == "sample" and - node["is_observed"]] + return [ + name + for name, node in self.nodes.items() + if node["type"] == "sample" and node["is_observed"] + ] @property def param_nodes(self): """ :return: a list of names of param sites """ - return [name for name, node in self.nodes.items() - if node["type"] == "param"] + return [name for name, node in self.nodes.items() if node["type"] == "param"] @property def stochastic_nodes(self): """ :return: a list of names of sample sites """ - return [name for name, node in self.nodes.items() - if node["type"] == "sample" and - not node["is_observed"]] + return [ + name + for name, node in self.nodes.items() + if node["type"] == "sample" and not node["is_observed"] + ] @property def reparameterized_nodes(self): @@ -295,10 +334,13 @@ def reparameterized_nodes(self): :return: a list of names of sample sites whose stochastic functions are reparameterizable primitive distributions """ - return [name for name, node in self.nodes.items() - if node["type"] == "sample" and - not node["is_observed"] and - getattr(node["fn"], "has_rsample", False)] + return [ + name + for name, node in self.nodes.items() + if node["type"] == "sample" + and not node["is_observed"] + and getattr(node["fn"], "has_rsample", False) + ] @property def nonreparam_stochastic_nodes(self): @@ -369,19 +411,28 @@ def pack_tensors(self, plate_to_symbol=None): log_prob = pack(log_prob, dim_to_symbol) score_function = pack(score_function, dim_to_symbol) entropy_term = pack(entropy_term, dim_to_symbol) - packed["score_parts"] = ScoreParts(log_prob, score_function, entropy_term) + packed["score_parts"] = ScoreParts( + log_prob, score_function, entropy_term + ) packed["log_prob"] = log_prob - packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol) + packed["unscaled_log_prob"] = pack( + site["unscaled_log_prob"], dim_to_symbol + ) elif "log_prob" in site: packed["log_prob"] = pack(site["log_prob"], dim_to_symbol) - packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol) + packed["unscaled_log_prob"] = pack( + site["unscaled_log_prob"], dim_to_symbol + ) except ValueError as e: _, exc_value, 
traceback = sys.exc_info() shapes = self.format_shapes(last_site=site["name"]) - raise ValueError("Error while packing tensors at site '{}':\n {}\n{}" - .format(site["name"], exc_value, shapes)).with_traceback(traceback) from e + raise ValueError( + "Error while packing tensors at site '{}':\n {}\n{}".format( + site["name"], exc_value, shapes + ) + ).with_traceback(traceback) from e - def format_shapes(self, title='Trace Shapes:', last_site=None): + def format_shapes(self, title="Trace Shapes:", last_site=None): """ Returns a string showing a table of the shapes of all sites in the trace. @@ -390,34 +441,46 @@ def format_shapes(self, title='Trace Shapes:', last_site=None): return title rows = [[title]] - rows.append(['Param Sites:']) + rows.append(["Param Sites:"]) for name, site in self.nodes.items(): if site["type"] == "param": rows.append([name, None] + [str(size) for size in site["value"].shape]) if name == last_site: break - rows.append(['Sample Sites:']) + rows.append(["Sample Sites:"]) for name, site in self.nodes.items(): if site["type"] == "sample": # param shape batch_shape = getattr(site["fn"], "batch_shape", ()) event_shape = getattr(site["fn"], "event_shape", ()) - rows.append([name + " dist", None] + [str(size) for size in batch_shape] + - ["|", None] + [str(size) for size in event_shape]) + rows.append( + [name + " dist", None] + + [str(size) for size in batch_shape] + + ["|", None] + + [str(size) for size in event_shape] + ) # value shape event_dim = len(event_shape) shape = getattr(site["value"], "shape", ()) - batch_shape = shape[:len(shape) - event_dim] - event_shape = shape[len(shape) - event_dim:] - rows.append(["value", None] + [str(size) for size in batch_shape] + - ["|", None] + [str(size) for size in event_shape]) + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] + rows.append( + ["value", None] + + [str(size) for size in batch_shape] + + ["|", None] + + [str(size) for size in event_shape] + ) # log_prob shape if "log_prob" in site: batch_shape = getattr(site["log_prob"], "shape", ()) - rows.append(["log_prob", None] + [str(size) for size in batch_shape] + ["|", None]) + rows.append( + ["log_prob", None] + + [str(size) for size in batch_shape] + + ["|", None] + ) if name == last_site: break @@ -450,10 +513,12 @@ def _format_table(rows): j += 1 else: cols[j].append(cell) - cols = [[""] * (width - len(col)) + col - if direction == 'r' else - col + [""] * (width - len(col)) - for width, col, direction in zip(column_widths, cols, 'rrl')] + cols = [ + [""] * (width - len(col)) + col + if direction == "r" + else col + [""] * (width - len(col)) + for width, col, direction in zip(column_widths, cols, "rrl") + ] rows[i] = sum(cols, []) # compute cell widths @@ -463,6 +528,7 @@ def _format_table(rows): cell_widths[j] = max(cell_widths[j], len(cell)) # justify cells - return "\n".join(" ".join(cell.rjust(width) - for cell, width in zip(row, cell_widths)) - for row in rows) + return "\n".join( + " ".join(cell.rjust(width) for cell, width in zip(row, cell_widths)) + for row in rows + ) diff --git a/pyro/poutine/uncondition_messenger.py b/pyro/poutine/uncondition_messenger.py index f09eaa77ea..8db8a56ae9 100644 --- a/pyro/poutine/uncondition_messenger.py +++ b/pyro/poutine/uncondition_messenger.py @@ -9,6 +9,7 @@ class UnconditionMessenger(Messenger): Messenger to force the value of observed nodes to be sampled from their distribution, ignoring observations. 
""" + def __init__(self): super().__init__() diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index 4ef4c9c412..e90c2a5917 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -101,10 +101,12 @@ def discrete_escape(trace, msg): Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for integrating out discrete variables for variance reduction. """ - return (msg["type"] == "sample") and \ - (not msg["is_observed"]) and \ - (msg["name"] not in trace) and \ - (getattr(msg["fn"], "has_enumerate_support", False)) + return ( + (msg["type"] == "sample") + and (not msg["is_observed"]) + and (msg["name"] not in trace) + and (getattr(msg["fn"], "has_enumerate_support", False)) + ) def all_escape(trace, msg): @@ -118,6 +120,8 @@ def all_escape(trace, msg): Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for approximately integrating out variables for variance reduction. """ - return (msg["type"] == "sample") and \ - (not msg["is_observed"]) and \ - (msg["name"] not in trace) + return ( + (msg["type"] == "sample") + and (not msg["is_observed"]) + and (msg["name"] not in trace) + ) diff --git a/pyro/primitives.py b/pyro/primitives.py index aef5722f00..a96e75e44d 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -93,7 +93,7 @@ def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs): except RuntimeError as e: if "must match the size of tensor" in str(e): shape = torch.broadcast_shapes(observed.shape, unobserved.shape) - batch_shape = shape[:len(shape) - fn.event_dim] + batch_shape = shape[: len(shape) - fn.event_dim] raise ValueError( f"Invalid obs_mask shape {tuple(obs_mask.shape)}; should be " f"broadcastable to batch_shape = {tuple(batch_shape)}" @@ -135,8 +135,10 @@ def sample(name, fn, *args, **kwargs): is_observed = infer.pop("is_observed", obs is not None) if not am_i_wrapped(): if obs is not None and not infer.get("_deterministic"): - warnings.warn("trying to observe a value outside of inference at " + name, - RuntimeWarning) + warnings.warn( + "trying to observe a value outside of inference at " + name, + RuntimeWarning, + ) return obs return fn(*args, **kwargs) # if stack not empty, apply everything in the stack? @@ -156,7 +158,7 @@ def sample(name, fn, *args, **kwargs): "cond_indep_stack": (), "done": False, "stop": False, - "continuation": None + "continuation": None, } # apply the stack and return its return value apply_stack(msg) @@ -201,8 +203,12 @@ def deterministic(name, value, event_dim=None): :param int event_dim: Optional event dimension, defaults to `value.ndim`. """ event_dim = value.ndim if event_dim is None else event_dim - return sample(name, dist.Delta(value, event_dim=event_dim).mask(False), - obs=value, infer={"_deterministic": True}) + return sample( + name, + dist.Delta(value, event_dim=event_dim).mask(False), + obs=value, + infer={"_deterministic": True}, + ) @effectful(type="subsample") @@ -344,18 +350,23 @@ class plate(PlateMessenger): See `SVI Part II `_ for an extended discussion. 
""" + pass class iarange(plate): def __init__(self, *args, **kwargs): - warnings.warn("pyro.iarange is deprecated; use pyro.plate instead", DeprecationWarning) + warnings.warn( + "pyro.iarange is deprecated; use pyro.plate instead", DeprecationWarning + ) super().__init__(*args, **kwargs) class irange(SubsampleMessenger): def __init__(self, *args, **kwargs): - warnings.warn("pyro.irange is deprecated; use pyro.plate instead", DeprecationWarning) + warnings.warn( + "pyro.irange is deprecated; use pyro.plate instead", DeprecationWarning + ) super().__init__(*args, **kwargs) @@ -404,12 +415,15 @@ def module(name, nn_module, update_module_params=False): :returns: torch.nn.Module """ assert hasattr(nn_module, "parameters"), "module has no parameters" - assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name, since contains %s" %\ - _MODULE_NAMESPACE_DIVIDER + assert _MODULE_NAMESPACE_DIVIDER not in name, ( + "improper module name, since contains %s" % _MODULE_NAMESPACE_DIVIDER + ) if isclass(nn_module): - raise NotImplementedError("pyro.module does not support class constructors for " + - "the argument nn_module") + raise NotImplementedError( + "pyro.module does not support class constructors for " + + "the argument nn_module" + ) target_state_dict = OrderedDict() @@ -423,15 +437,17 @@ def module(name, nn_module, update_module_params=False): if param_value._cdata != returned_param._cdata: target_state_dict[param_name] = returned_param elif nn_module.training: - warnings.warn(f"{param_name} was not registered in the param store " - "because requires_grad=False. You can silence this " - "warning by calling my_module.train(False)") + warnings.warn( + f"{param_name} was not registered in the param store " + "because requires_grad=False. You can silence this " + "warning by calling my_module.train(False)" + ) if target_state_dict and update_module_params: # WARNING: this is very dangerous. better method? for _name, _param in nn_module.named_parameters(): is_param = False - name_arr = _name.rsplit('.', 1) + name_arr = _name.rsplit(".", 1) if len(name_arr) > 1: mod_name, param_name = name_arr[0], name_arr[1] else: @@ -439,7 +455,9 @@ def module(name, nn_module, update_module_params=False): mod_name = _name if _name in target_state_dict.keys(): if not is_param: - deep_getattr(nn_module, mod_name)._parameters[param_name] = target_state_dict[_name] + deep_getattr(nn_module, mod_name)._parameters[ + param_name + ] = target_state_dict[_name] else: nn_module._parameters[mod_name] = target_state_dict[_name] @@ -467,9 +485,12 @@ def random_module(name, nn_module, prior, *args, **kwargs): as keys and respective distributions/stochastic functions as values. :returns: a callable which returns a sampled module """ - warnings.warn("The `random_module` primitive is deprecated, and will be removed " - "in a future release. Use `pyro.nn.Module` to create Bayesian " - "modules from `torch.nn.Module` instances.", FutureWarning) + warnings.warn( + "The `random_module` primitive is deprecated, and will be removed " + "in a future release. Use `pyro.nn.Module` to create Bayesian " + "modules from `torch.nn.Module` instances.", + FutureWarning, + ) assert hasattr(nn_module, "parameters"), "Module is not a NN module." 
# register params in param store @@ -479,6 +500,7 @@ def _fn(): nn_copy = copy.deepcopy(nn_module) # update_module_params must be True or the lifted module will not update local params return lifted_fn(name, nn_copy, update_module_params=True, *args, **kwargs) + return _fn diff --git a/pyro/util.py b/pyro/util.py index fd6299b8e6..1acc5ebb1e 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -30,15 +30,20 @@ def set_rng_seed(rng_seed): def get_rng_state(): - return {'torch': torch.get_rng_state(), 'random': random.getstate(), 'numpy': np.random.get_state()} + return { + "torch": torch.get_rng_state(), + "random": random.getstate(), + "numpy": np.random.get_state(), + } def set_rng_state(state): - torch.set_rng_state(state['torch']) - random.setstate(state['random']) - if 'numpy' in state: + torch.set_rng_state(state["torch"]) + random.setstate(state["random"]) + if "numpy" in state: import numpy as np - np.random.set_state(state['numpy']) + + np.random.set_state(state["numpy"]) def torch_isnan(x): @@ -75,17 +80,26 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None): lineno = frame.f_lineno if torch.is_tensor(value) and value.requires_grad: - value.register_hook(lambda x: warn_if_nan(x, "backward " + msg, filename=filename, lineno=lineno)) + value.register_hook( + lambda x: warn_if_nan( + x, "backward " + msg, filename=filename, lineno=lineno + ) + ) if torch_isnan(value): - warnings.warn_explicit("Encountered NaN{}".format(': ' + msg if msg else '.'), - UserWarning, filename, lineno) + warnings.warn_explicit( + "Encountered NaN{}".format(": " + msg if msg else "."), + UserWarning, + filename, + lineno, + ) return value -def warn_if_inf(value, msg="", allow_posinf=False, allow_neginf=False, *, - filename=None, lineno=None): +def warn_if_inf( + value, msg="", allow_posinf=False, allow_neginf=False, *, filename=None, lineno=None +): """ A convenient function to warn if a Tensor or its grad contains any inf, also works with numbers. 
@@ -101,18 +115,39 @@ def warn_if_inf(value, msg="", allow_posinf=False, allow_neginf=False, *, lineno = frame.f_lineno if torch.is_tensor(value) and value.requires_grad: - value.register_hook(lambda x: warn_if_inf(x, "backward " + msg, - allow_posinf, allow_neginf, - filename=filename, lineno=lineno)) - - if (not allow_posinf) and (value == math.inf if isinstance(value, numbers.Number) - else (value == math.inf).any()): - warnings.warn_explicit("Encountered +inf{}".format(': ' + msg if msg else '.'), - UserWarning, filename, lineno) - if (not allow_neginf) and (value == -math.inf if isinstance(value, numbers.Number) - else (value == -math.inf).any()): - warnings.warn_explicit("Encountered -inf{}".format(': ' + msg if msg else '.'), - UserWarning, filename, lineno) + value.register_hook( + lambda x: warn_if_inf( + x, + "backward " + msg, + allow_posinf, + allow_neginf, + filename=filename, + lineno=lineno, + ) + ) + + if (not allow_posinf) and ( + value == math.inf + if isinstance(value, numbers.Number) + else (value == math.inf).any() + ): + warnings.warn_explicit( + "Encountered +inf{}".format(": " + msg if msg else "."), + UserWarning, + filename, + lineno, + ) + if (not allow_neginf) and ( + value == -math.inf + if isinstance(value, numbers.Number) + else (value == -math.inf).any() + ): + warnings.warn_explicit( + "Encountered -inf{}".format(": " + msg if msg else "."), + UserWarning, + filename, + lineno, + ) return value @@ -136,8 +171,10 @@ def save_visualization(trace, graph_output): trace = pyro.poutine.trace(model, graph_type="dense").get_trace() save_visualization(trace, 'output') """ - warnings.warn("`save_visualization` function is deprecated and will be removed in " - "a future version.") + warnings.warn( + "`save_visualization` function is deprecated and will be removed in " + "a future version." + ) import graphviz @@ -146,17 +183,17 @@ def save_visualization(trace, graph_output): for label, node in trace.nodes.items(): if site_is_subsample(node): continue - shape = 'ellipse' + shape = "ellipse" if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes: - fillcolor = 'salmon' + fillcolor = "salmon" elif label in trace.reparameterized_nodes: - fillcolor = 'lightgrey;.5:salmon' + fillcolor = "lightgrey;.5:salmon" elif label in trace.observation_nodes: - fillcolor = 'darkolivegreen3' + fillcolor = "darkolivegreen3" else: # only visualize RVs continue - g.node(label, label=label, shape=shape, style='filled', fillcolor=fillcolor) + g.node(label, label=label, shape=shape, style="filled", fillcolor=fillcolor) for label1, label2 in trace.edges: if site_is_subsample(trace.nodes[label1]): @@ -191,7 +228,11 @@ def check_traces_match(trace1, trace2): shape1 = site1["fn"].shape(*site1["args"], **site1["kwargs"]) shape2 = site2["fn"].shape(*site2["args"], **site2["kwargs"]) if shape1 != shape2: - raise ValueError("Site dims disagree at site '{}': {} vs {}".format(name, shape1, shape2)) + raise ValueError( + "Site dims disagree at site '{}': {} vs {}".format( + name, shape1, shape2 + ) + ) def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf): @@ -211,66 +252,115 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf and guide agree on sample shape. """ # Check ordinary sample sites. 
- guide_vars = set(name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" - if type(site["fn"]).__name__ != "_Subsample") - aux_vars = set(name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" - if site["infer"].get("is_auxiliary")) - model_vars = set(name for name, site in model_trace.nodes.items() - if site["type"] == "sample" and not site["is_observed"] - if type(site["fn"]).__name__ != "_Subsample") - enum_vars = set(name for name, site in model_trace.nodes.items() - if site["type"] == "sample" and not site["is_observed"] - if type(site["fn"]).__name__ != "_Subsample" - if site["infer"].get("_enumerate_dim") is not None - if name not in guide_vars) + guide_vars = set( + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" + if type(site["fn"]).__name__ != "_Subsample" + ) + aux_vars = set( + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" + if site["infer"].get("is_auxiliary") + ) + model_vars = set( + name + for name, site in model_trace.nodes.items() + if site["type"] == "sample" and not site["is_observed"] + if type(site["fn"]).__name__ != "_Subsample" + ) + enum_vars = set( + name + for name, site in model_trace.nodes.items() + if site["type"] == "sample" and not site["is_observed"] + if type(site["fn"]).__name__ != "_Subsample" + if site["infer"].get("_enumerate_dim") is not None + if name not in guide_vars + ) if aux_vars & model_vars: - warnings.warn("Found auxiliary vars in the model: {}".format(aux_vars & model_vars)) + warnings.warn( + "Found auxiliary vars in the model: {}".format(aux_vars & model_vars) + ) if not (guide_vars <= model_vars | aux_vars): - warnings.warn("Found non-auxiliary vars in guide but not model, " - "consider marking these infer={{'is_auxiliary': True}}:\n{}".format( - guide_vars - aux_vars - model_vars)) + warnings.warn( + "Found non-auxiliary vars in guide but not model, " + "consider marking these infer={{'is_auxiliary': True}}:\n{}".format( + guide_vars - aux_vars - model_vars + ) + ) if not (model_vars <= guide_vars | enum_vars): - warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars)) + warnings.warn( + "Found vars in model but not guide: {}".format( + model_vars - guide_vars - enum_vars + ) + ) # Check shapes agree. for name in model_vars & guide_vars: model_site = model_trace.nodes[name] guide_site = guide_trace.nodes[name] - if hasattr(model_site["fn"], "event_dim") and hasattr(guide_site["fn"], "event_dim"): + if hasattr(model_site["fn"], "event_dim") and hasattr( + guide_site["fn"], "event_dim" + ): if model_site["fn"].event_dim != guide_site["fn"].event_dim: - raise ValueError("Model and guide event_dims disagree at site '{}': {} vs {}".format( - name, model_site["fn"].event_dim, guide_site["fn"].event_dim)) + raise ValueError( + "Model and guide event_dims disagree at site '{}': {} vs {}".format( + name, model_site["fn"].event_dim, guide_site["fn"].event_dim + ) + ) if hasattr(model_site["fn"], "shape") and hasattr(guide_site["fn"], "shape"): - model_shape = model_site["fn"].shape(*model_site["args"], **model_site["kwargs"]) - guide_shape = guide_site["fn"].shape(*guide_site["args"], **guide_site["kwargs"]) + model_shape = model_site["fn"].shape( + *model_site["args"], **model_site["kwargs"] + ) + guide_shape = guide_site["fn"].shape( + *guide_site["args"], **guide_site["kwargs"] + ) if model_shape == guide_shape: continue # Allow broadcasting outside of max_plate_nesting. 
if len(model_shape) > max_plate_nesting: - model_shape = model_shape[len(model_shape) - max_plate_nesting - model_site["fn"].event_dim:] + model_shape = model_shape[ + len(model_shape) - max_plate_nesting - model_site["fn"].event_dim : + ] if len(guide_shape) > max_plate_nesting: - guide_shape = guide_shape[len(guide_shape) - max_plate_nesting - guide_site["fn"].event_dim:] + guide_shape = guide_shape[ + len(guide_shape) - max_plate_nesting - guide_site["fn"].event_dim : + ] if model_shape == guide_shape: continue - for model_size, guide_size in zip_longest(reversed(model_shape), reversed(guide_shape), fillvalue=1): + for model_size, guide_size in zip_longest( + reversed(model_shape), reversed(guide_shape), fillvalue=1 + ): if model_size != guide_size: - raise ValueError("Model and guide shapes disagree at site '{}': {} vs {}".format( - name, model_shape, guide_shape)) + raise ValueError( + "Model and guide shapes disagree at site '{}': {} vs {}".format( + name, model_shape, guide_shape + ) + ) # Check subsample sites introduced by plate. - model_vars = set(name for name, site in model_trace.nodes.items() - if site["type"] == "sample" and not site["is_observed"] - if type(site["fn"]).__name__ == "_Subsample") - guide_vars = set(name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" - if type(site["fn"]).__name__ == "_Subsample") + model_vars = set( + name + for name, site in model_trace.nodes.items() + if site["type"] == "sample" and not site["is_observed"] + if type(site["fn"]).__name__ == "_Subsample" + ) + guide_vars = set( + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" + if type(site["fn"]).__name__ == "_Subsample" + ) if not (guide_vars <= model_vars): - warnings.warn("Found plate statements in guide but not model: {}".format(guide_vars - model_vars)) + warnings.warn( + "Found plate statements in guide but not model: {}".format( + guide_vars - model_vars + ) + ) def check_site_shape(site, max_plate_nesting): @@ -283,42 +373,73 @@ def check_site_shape(site, max_plate_nesting): # Use the specified plate dimension, which counts from the right. assert f.dim < 0 if len(expected_shape) < -f.dim: - expected_shape = [None] * (-f.dim - len(expected_shape)) + expected_shape + expected_shape = [None] * ( + -f.dim - len(expected_shape) + ) + expected_shape if expected_shape[f.dim] is not None: - raise ValueError('\n '.join([ - 'at site "{}" within plate("{}", dim={}), dim collision'.format(site["name"], f.name, f.dim), - 'Try setting dim arg in other plates.'])) + raise ValueError( + "\n ".join( + [ + 'at site "{}" within plate("{}", dim={}), dim collision'.format( + site["name"], f.name, f.dim + ), + "Try setting dim arg in other plates.", + ] + ) + ) expected_shape[f.dim] = f.size expected_shape = [-1 if e is None else e for e in expected_shape] # Check for plate stack overflow. if len(expected_shape) > max_plate_nesting: - raise ValueError('\n '.join([ - 'at site "{}", plate stack overflow'.format(site["name"]), - 'Try increasing max_plate_nesting to at least {}'.format(len(expected_shape))])) + raise ValueError( + "\n ".join( + [ + 'at site "{}", plate stack overflow'.format(site["name"]), + "Try increasing max_plate_nesting to at least {}".format( + len(expected_shape) + ), + ] + ) + ) # Ignore dimensions left of max_plate_nesting. 
if max_plate_nesting < len(actual_shape): - actual_shape = actual_shape[len(actual_shape) - max_plate_nesting:] + actual_shape = actual_shape[len(actual_shape) - max_plate_nesting :] # Check for incorrect plate placement on the right of max_plate_nesting. - for actual_size, expected_size in zip_longest(reversed(actual_shape), reversed(expected_shape), fillvalue=1): + for actual_size, expected_size in zip_longest( + reversed(actual_shape), reversed(expected_shape), fillvalue=1 + ): if expected_size != -1 and expected_size != actual_size: - raise ValueError('\n '.join([ - 'at site "{}", invalid log_prob shape'.format(site["name"]), - 'Expected {}, actual {}'.format(expected_shape, actual_shape), - 'Try one of the following fixes:', - '- enclose the batched tensor in a with pyro.plate(...): context', - '- .to_event(...) the distribution being sampled', - '- .permute() data dimensions'])) + raise ValueError( + "\n ".join( + [ + 'at site "{}", invalid log_prob shape'.format(site["name"]), + "Expected {}, actual {}".format(expected_shape, actual_shape), + "Try one of the following fixes:", + "- enclose the batched tensor in a with pyro.plate(...): context", + "- .to_event(...) the distribution being sampled", + "- .permute() data dimensions", + ] + ) + ) # Check parallel dimensions on the left of max_plate_nesting. enum_dim = site["infer"].get("_enumerate_dim") if enum_dim is not None: - if len(site["fn"].batch_shape) >= -enum_dim and site["fn"].batch_shape[enum_dim] != 1: - raise ValueError('\n '.join([ - 'Enumeration dim conflict at site "{}"'.format(site["name"]), - 'Try increasing pyro.markov history size'])) + if ( + len(site["fn"].batch_shape) >= -enum_dim + and site["fn"].batch_shape[enum_dim] != 1 + ): + raise ValueError( + "\n ".join( + [ + 'Enumeration dim conflict at site "{}"'.format(site["name"]), + "Try increasing pyro.markov history size", + ] + ) + ) def _are_independent(counters1, counters2): @@ -341,30 +462,50 @@ def check_traceenum_requirements(model_trace, guide_trace): this function aims to warn only in cases where models can be easily rewitten to be obviously correct. """ - enumerated_sites = set(name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and site["infer"].get("enumerate")) - for role, trace in [('model', model_trace), ('guide', guide_trace)]: + enumerated_sites = set( + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and site["infer"].get("enumerate") + ) + for role, trace in [("model", model_trace), ("guide", guide_trace)]: plate_counters = {} # for sequential plates only enumerated_contexts = defaultdict(set) for name, site in trace.nodes.items(): if site["type"] != "sample": continue - plate_counter = {f.name: f.counter for f in site["cond_indep_stack"] if not f.vectorized} + plate_counter = { + f.name: f.counter for f in site["cond_indep_stack"] if not f.vectorized + } context = frozenset(f for f in site["cond_indep_stack"] if f.vectorized) # Check that sites outside each independence context precede enumerated sites inside that context. 
for enumerated_context, names in enumerated_contexts.items(): if not (context < enumerated_context): continue - names = sorted(n for n in names if not _are_independent(plate_counter, plate_counters[n])) + names = sorted( + n + for n in names + if not _are_independent(plate_counter, plate_counters[n]) + ) if not names: continue diff = sorted(f.name for f in enumerated_context - context) - warnings.warn('\n '.join([ - 'at {} site "{}", possibly invalid dependency.'.format(role, name), - 'Expected site "{}" to precede sites "{}"'.format(name, '", "'.join(sorted(names))), - 'to avoid breaking independence of plates "{}"'.format('", "'.join(diff)), - ]), RuntimeWarning) + warnings.warn( + "\n ".join( + [ + 'at {} site "{}", possibly invalid dependency.'.format( + role, name + ), + 'Expected site "{}" to precede sites "{}"'.format( + name, '", "'.join(sorted(names)) + ), + 'to avoid breaking independence of plates "{}"'.format( + '", "'.join(diff) + ), + ] + ), + RuntimeWarning, + ) plate_counters[name] = plate_counter if name in enumerated_sites: @@ -372,13 +513,21 @@ def check_traceenum_requirements(model_trace, guide_trace): def check_if_enumerated(guide_trace): - enumerated_sites = [name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and site["infer"].get("enumerate")] + enumerated_sites = [ + name + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and site["infer"].get("enumerate") + ] if enumerated_sites: - warnings.warn('\n'.join([ - 'Found sample sites configured for enumeration:' - ', '.join(enumerated_sites), - 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) + warnings.warn( + "\n".join( + [ + "Found sample sites configured for enumeration:" + ", ".join(enumerated_sites), + "If you want to enumerate sites, you need to use TraceEnum_ELBO instead.", + ] + ) + ) @contextmanager @@ -398,16 +547,13 @@ def ignore_jit_warnings(filter=None): with warnings.catch_warnings(): if filter is None: - warnings.filterwarnings("ignore", - category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) else: for msg in filter: category = torch.jit.TracerWarning if isinstance(msg, tuple): msg, category = msg - warnings.filterwarnings("ignore", - category=category, - message=msg) + warnings.filterwarnings("ignore", category=category, message=msg) yield @@ -427,6 +573,7 @@ class optional: """ Optionally wrap inside `context_manager` if condition is `True`. 
""" + def __init__(self, context_manager, condition): self.context_manager = context_manager self.condition = condition @@ -447,7 +594,7 @@ class ExperimentalWarning(UserWarning): @contextmanager def ignore_experimental_warning(): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=ExperimentalWarning) + warnings.filterwarnings("ignore", category=ExperimentalWarning) yield diff --git a/scripts/update_headers.py b/scripts/update_headers.py index 4ac35288cf..48ccf2dfb8 100644 --- a/scripts/update_headers.py +++ b/scripts/update_headers.py @@ -69,10 +69,9 @@ with open(filename, "w") as f: f.write("".join(lines)) - print("updated {}".format(filename[len(root) + 1:])) + print("updated {}".format(filename[len(root) + 1 :])) if dirty: - print("The following files need license headers:\n{}" - .format("\n".join(dirty))) + print("The following files need license headers:\n{}".format("\n".join(dirty))) print("Please run 'make license'") sys.exit(1) diff --git a/scripts/update_version.py b/scripts/update_version.py index 7cccdc8873..82361b9d11 100644 --- a/scripts/update_version.py +++ b/scripts/update_version.py @@ -17,8 +17,7 @@ filenames = [] for path in ["examples", "tutorial/source"]: for ext in ["*.py", "*.ipynb"]: - filenames.extend(glob.glob(os.path.join(root, path, "**", ext), - recursive=True)) + filenames.extend(glob.glob(os.path.join(root, path, "**", ext), recursive=True)) filenames.sort() # Update version string. diff --git a/setup.cfg b/setup.cfg index 942055b096..4e98b0f2fc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] max-line-length = 120 exclude = docs/src, build, dist, .ipynb_checkpoints -extend-ignore = E721,E741 +extend-ignore = E721,E741,E203 [isort] profile = black diff --git a/setup.py b/setup.py index df957d9086..a09ec11dc0 100644 --- a/setup.py +++ b/setup.py @@ -15,27 +15,35 @@ """ # Find pyro version. 
-for line in open(os.path.join(PROJECT_PATH, 'pyro', '__init__.py')): - if line.startswith('version_prefix = '): +for line in open(os.path.join(PROJECT_PATH, "pyro", "__init__.py")): + if line.startswith("version_prefix = "): version = line.strip().split()[2][1:-1] # Append current commit sha to version -commit_sha = '' +commit_sha = "" try: - current_tag = subprocess.check_output(['git', 'tag', '--points-at', 'HEAD'], - cwd=PROJECT_PATH).decode('ascii').strip() + current_tag = ( + subprocess.check_output(["git", "tag", "--points-at", "HEAD"], cwd=PROJECT_PATH) + .decode("ascii") + .strip() + ) # only add sha if HEAD does not point to the release tag if not current_tag == version: - commit_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], - cwd=PROJECT_PATH).decode('ascii').strip() + commit_sha = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], cwd=PROJECT_PATH + ) + .decode("ascii") + .strip() + ) # catch all exception to be safe except Exception: pass # probably not a git repo # Write version to _version.py if commit_sha: - version += '+{}'.format(commit_sha) -with open(os.path.join(PROJECT_PATH, 'pyro', '_version.py'), 'w') as f: + version += "+{}".format(commit_sha) +with open(os.path.join(PROJECT_PATH, "pyro", "_version.py"), "w") as f: f.write(VERSION.format(version)) @@ -45,98 +53,103 @@ # $ twine upload --repository-url https://test.pypi.org/legacy/ dist/* # test version # $ twine upload dist/* try: - long_description = open('README.md', encoding='utf-8').read() + long_description = open("README.md", encoding="utf-8").read() except Exception as e: - sys.stderr.write('Failed to read README.md: {}\n'.format(e)) + sys.stderr.write("Failed to read README.md: {}\n".format(e)) sys.stderr.flush() - long_description = '' + long_description = "" # Remove badges since they will always be obsolete. # This assumes the first 12 lines contain badge info. 
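The setup.py hunk above only reformats the existing dev-version logic: if HEAD does not point at the release tag, the short commit sha is appended to the version string. A condensed sketch of that flow, assuming a release version of "1.7.0" purely for illustration:

    import subprocess

    version = "1.7.0"  # assumed value, normally read from pyro/__init__.py
    try:
        tag = (
            subprocess.check_output(["git", "tag", "--points-at", "HEAD"])
            .decode("ascii")
            .strip()
        )
        if tag != version:
            sha = (
                subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
                .decode("ascii")
                .strip()
            )
            version += "+{}".format(sha)
    except Exception:
        pass  # not a git checkout; keep the plain version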
-long_description = '\n'.join([str(line) for line in long_description.split('\n')[12:]]) +long_description = "\n".join([str(line) for line in long_description.split("\n")[12:]]) # examples/tutorials EXTRAS_REQUIRE = [ - 'jupyter>=1.0.0', - 'graphviz>=0.8', - 'matplotlib>=1.3', - 'torchvision>=0.10.0', - 'visdom>=0.1.4', - 'pandas', - 'pillow==8.2.0', # https://github.com/pytorch/pytorch/issues/61125 - 'scikit-learn', - 'seaborn', - 'wget', - 'lap', + "jupyter>=1.0.0", + "graphviz>=0.8", + "matplotlib>=1.3", + "torchvision>=0.10.0", + "visdom>=0.1.4", + "pandas", + "pillow==8.2.0", # https://github.com/pytorch/pytorch/issues/61125 + "scikit-learn", + "seaborn", + "wget", + "lap", # 'biopython>=1.54', # Requires Python 3.6 # 'scanpy>=1.4', # Requires HDF5 # 'scvi>=0.6', # Requires loopy and other fragile packages ] setup( - name='pyro-ppl', + name="pyro-ppl", version=version, - description='A Python library for probabilistic modeling and inference', + description="A Python library for probabilistic modeling and inference", long_description=long_description, - long_description_content_type='text/markdown', - packages=find_packages(include=['pyro', 'pyro.*']), + long_description_content_type="text/markdown", + packages=find_packages(include=["pyro", "pyro.*"]), package_data={"pyro.distributions": ["*.cpp"]}, author="Uber AI Labs", - url='http://pyro.ai', + url="http://pyro.ai", install_requires=[ # if you add any additional libraries, please also # add them to `docs/requirements.txt` # numpy is necessary for some functionality of PyTorch - 'numpy>=1.7', - 'opt_einsum>=2.3.2', - 'pyro-api>=0.1.1', - 'torch>=1.9.0', - 'tqdm>=4.36', + "numpy>=1.7", + "opt_einsum>=2.3.2", + "pyro-api>=0.1.1", + "torch>=1.9.0", + "tqdm>=4.36", ], extras_require={ - 'extras': EXTRAS_REQUIRE, - 'test': EXTRAS_REQUIRE + [ - 'nbval', - 'pytest>=5.0', - 'pytest-cov', - 'scipy>=1.1', + "extras": EXTRAS_REQUIRE, + "test": EXTRAS_REQUIRE + + [ + "black>=21.4b0", + "flake8", + "nbval", + "pytest>=5.0", + "pytest-cov", + "scipy>=1.1", ], - 'profile': ['prettytable', 'pytest-benchmark', 'snakeviz'], - 'dev': EXTRAS_REQUIRE + [ - 'flake8', - 'isort>=5.0', - 'mypy>=0.812', - 'nbformat', - 'nbsphinx>=0.3.2', - 'nbstripout', - 'nbval', - 'ninja', - 'pypandoc', - 'pytest>=5.0', - 'pytest-xdist', - 'scipy>=1.1', - 'sphinx', - 'sphinx_rtd_theme', - 'yapf', + "profile": ["prettytable", "pytest-benchmark", "snakeviz"], + "dev": EXTRAS_REQUIRE + + [ + "black>=21.4b0", + "flake8", + "isort>=5.0", + "mypy>=0.812", + "nbformat", + "nbsphinx>=0.3.2", + "nbstripout", + "nbval", + "ninja", + "pypandoc", + "pytest>=5.0", + "pytest-xdist", + "scipy>=1.1", + "sphinx", + "sphinx_rtd_theme", + "yapf", ], - 'horovod': ['horovod[pytorch]>=0.19'], - 'funsor': [ + "horovod": ["horovod[pytorch]>=0.19"], + "funsor": [ # This must be a released version when Pyro is released. 
- 'funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@383e7a6d05c9d5de9646d23698891e10c4cba927', + "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@383e7a6d05c9d5de9646d23698891e10c4cba927", ], }, - python_requires='>=3.6', - keywords='machine learning statistics probabilistic programming bayesian modeling pytorch', - license='Apache 2.0', + python_requires=">=3.6", + keywords="machine learning statistics probabilistic programming bayesian modeling pytorch", + license="Apache 2.0", classifiers=[ - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: POSIX :: Linux', - 'Operating System :: MacOS :: MacOS X', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], # yapf ) diff --git a/tests/__init__.py b/tests/__init__.py index 4056718ce7..04f37403cb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,4 +6,4 @@ # create log handler for tests level = logging.INFO if "CI" in os.environ else logging.DEBUG -logging.basicConfig(format='%(levelname).1s \t %(message)s', level=level) +logging.basicConfig(format="%(levelname).1s \t %(message)s", level=level) diff --git a/tests/common.py b/tests/common.py index 481820c9db..9403e23bfd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -25,8 +25,8 @@ """ TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) -RESOURCE_DIR = os.path.join(TESTS_DIR, 'resources') -EXAMPLES_DIR = os.path.join(os.path.dirname(TESTS_DIR), 'examples') +RESOURCE_DIR = os.path.join(TESTS_DIR, "resources") +EXAMPLES_DIR = os.path.join(os.path.dirname(TESTS_DIR), "examples") TEST_FAILURE_RATE = 2e-5 # For all goodness-of-fit tests. @@ -69,31 +69,32 @@ def TemporaryDirectory(): shutil.rmtree(path) -requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), - reason="cuda is not available") +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="cuda is not available" +) try: import horovod except ImportError: horovod = None -requires_horovod = pytest.mark.skipif(horovod is None, - reason="horovod is not available") +requires_horovod = pytest.mark.skipif( + horovod is None, reason="horovod is not available" +) try: import funsor except ImportError: funsor = None -requires_funsor = pytest.mark.skipif(funsor is None, - reason="funsor is not available") +requires_funsor = pytest.mark.skipif(funsor is None, reason="funsor is not available") def get_cpu_type(t): - assert t.__module__ == 'torch.cuda' + assert t.__module__ == "torch.cuda" return getattr(torch, t.__class__.__name__) def get_gpu_type(t): - assert t.__module__ == 'torch' + assert t.__module__ == "torch" return getattr(torch.cuda, t.__name__) @@ -104,14 +105,14 @@ def tensors_default_to(host): :param str host: Either "cuda" or "cpu". 
""" - assert host in ('cpu', 'cuda'), host - old_module, name = torch.Tensor().type().rsplit('.', 1) - new_module = 'torch.cuda' if host == 'cuda' else 'torch' - torch.set_default_tensor_type('{}.{}'.format(new_module, name)) + assert host in ("cpu", "cuda"), host + old_module, name = torch.Tensor().type().rsplit(".", 1) + new_module = "torch.cuda" if host == "cuda" else "torch" + torch.set_default_tensor_type("{}.{}".format(new_module, name)) try: yield finally: - torch.set_default_tensor_type('{}.{}'.format(old_module, name)) + torch.set_default_tensor_type("{}.{}".format(old_module, name)) @contextlib.contextmanager @@ -149,7 +150,7 @@ def is_iterable(obj): return False -def assert_tensors_equal(a, b, prec=0., msg=''): +def assert_tensors_equal(a, b, prec=0.0, msg=""): assert a.size() == b.size(), msg if isinstance(prec, numbers.Number) and prec == 0: assert (a == b).all(), msg @@ -169,7 +170,7 @@ def assert_tensors_equal(a, b, prec=0., msg=''): assert (diff <= prec).all(), msg else: max_err = diff.max().item() - assert (max_err <= prec), msg + assert max_err <= prec, msg def _safe_coalesce(t): @@ -197,16 +198,17 @@ def _safe_coalesce(t): return tg -def assert_close(actual, expected, atol=1e-7, rtol=0, msg=''): +def assert_close(actual, expected, atol=1e-7, rtol=0, msg=""): if not msg: - msg = '{} vs {}'.format(actual, expected) + msg = "{} vs {}".format(actual, expected) if isinstance(actual, numbers.Number) and isinstance(expected, numbers.Number): assert actual == approx(expected, abs=atol, rel=rtol), msg # Placing this as a second check allows for coercing of numeric types above; # this can be moved up to harden type checks. elif type(actual) != type(expected): - raise AssertionError("cannot compare {} and {}".format(type(actual), - type(expected))) + raise AssertionError( + "cannot compare {} and {}".format(type(actual), type(expected)) + ) elif torch.is_tensor(actual) and torch.is_tensor(expected): prec = atol + rtol * abs(expected) if rtol > 0 else atol assert actual.is_sparse == expected.is_sparse, msg @@ -218,37 +220,45 @@ def assert_close(actual, expected, atol=1e-7, rtol=0, msg=''): else: assert_tensors_equal(actual, expected, prec, msg) elif type(actual) == np.ndarray and type(expected) == np.ndarray: - assert_allclose(actual, expected, atol=atol, rtol=rtol, equal_nan=True, err_msg=msg) + assert_allclose( + actual, expected, atol=atol, rtol=rtol, equal_nan=True, err_msg=msg + ) elif isinstance(actual, numbers.Number) and isinstance(y, numbers.Number): assert actual == approx(expected, abs=atol, rel=rtol), msg elif isinstance(actual, dict): assert set(actual.keys()) == set(expected.keys()) for key, x_val in actual.items(): - assert_close(x_val, expected[key], atol=atol, rtol=rtol, - msg='At key {}: {} vs {}'.format(repr(key), x_val, expected[key])) + assert_close( + x_val, + expected[key], + atol=atol, + rtol=rtol, + msg="At key {}: {} vs {}".format(repr(key), x_val, expected[key]), + ) elif isinstance(actual, str): assert actual == expected, msg elif is_iterable(actual) and is_iterable(expected): assert len(actual) == len(expected), msg for xi, yi in zip(actual, expected): - assert_close(xi, yi, atol=atol, rtol=rtol, msg='{} vs {}'.format(xi, yi)) + assert_close(xi, yi, atol=atol, rtol=rtol, msg="{} vs {}".format(xi, yi)) else: assert actual == expected, msg # TODO: Remove `prec` arg, and move usages to assert_close -def assert_equal(actual, expected, prec=1e-5, msg=''): - if prec > 0.: +def assert_equal(actual, expected, prec=1e-5, msg=""): + if prec > 0.0: return 
assert_close(actual, expected, atol=prec, msg=msg) if not msg: - msg = '{} vs {}'.format(actual, expected) + msg = "{} vs {}".format(actual, expected) if isinstance(actual, numbers.Number) and isinstance(expected, numbers.Number): assert actual == expected, msg # Placing this as a second check allows for coercing of numeric types above; # this can be moved up to harden type checks. elif type(actual) != type(expected): - raise AssertionError("cannot compare {} and {}".format(type(actual), - type(expected))) + raise AssertionError( + "cannot compare {} and {}".format(type(actual), type(expected)) + ) elif torch.is_tensor(actual) and torch.is_tensor(expected): assert actual.is_sparse == expected.is_sparse, msg if actual.is_sparse: @@ -263,21 +273,27 @@ def assert_equal(actual, expected, prec=1e-5, msg=''): elif isinstance(actual, dict): assert set(actual.keys()) == set(expected.keys()) for key, x_val in actual.items(): - assert_equal(x_val, expected[key], prec=0., - msg='At key{}: {} vs {}'.format(key, x_val, expected[key])) + assert_equal( + x_val, + expected[key], + prec=0.0, + msg="At key{}: {} vs {}".format(key, x_val, expected[key]), + ) elif isinstance(actual, str): assert actual == expected, msg elif is_iterable(actual) and is_iterable(expected): assert len(actual) == len(expected), msg for xi, yi in zip(actual, expected): - assert_equal(xi, yi, prec=0., msg='{} vs {}'.format(xi, yi)) + assert_equal(xi, yi, prec=0.0, msg="{} vs {}".format(xi, yi)) else: assert actual == expected, msg -def assert_not_equal(x, y, prec=1e-5, msg=''): +def assert_not_equal(x, y, prec=1e-5, msg=""): try: assert_equal(x, y, prec) except AssertionError: return - raise AssertionError("{} \nValues are equal: x={}, y={}, prec={}".format(msg, x, y, prec)) + raise AssertionError( + "{} \nValues are equal: x={}, y={}, prec={}".format(msg, x, y, prec) + ) diff --git a/tests/conftest.py b/tests/conftest.py index 699cca55c2..4103356fe3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,16 +9,19 @@ import pyro -torch.set_default_tensor_type(os.environ.get('PYRO_TENSOR_TYPE', 'torch.DoubleTensor')) +torch.set_default_tensor_type(os.environ.get("PYRO_TENSOR_TYPE", "torch.DoubleTensor")) def pytest_configure(config): - config.addinivalue_line("markers", - "init(rng_seed): initialize the RNG using the seed provided.") - config.addinivalue_line("markers", - "stage(NAME): mark test to run when testing stage matches NAME.") - config.addinivalue_line("markers", - "disable_validation: disable all validation on this test.") + config.addinivalue_line( + "markers", "init(rng_seed): initialize the RNG using the seed provided." + ) + config.addinivalue_line( + "markers", "stage(NAME): mark test to run when testing stage matches NAME." + ) + config.addinivalue_line( + "markers", "disable_validation: disable all validation on this test." 
+ ) def pytest_runtest_setup(item): @@ -34,16 +37,20 @@ def pytest_runtest_setup(item): def pytest_addoption(parser): - parser.addoption("--stage", - action="append", - metavar="NAME", - default=[], - help="Only run tests matching the stage NAME.") - - parser.addoption("--lax", - action="store_true", - default=False, - help="Ignore AssertionError when running tests.") + parser.addoption( + "--stage", + action="append", + metavar="NAME", + default=[], + help="Only run tests matching the stage NAME.", + ) + + parser.addoption( + "--lax", + action="store_true", + default=False, + help="Ignore AssertionError when running tests.", + ) def _get_highest_specificity_marker(stage_marker): @@ -91,7 +98,11 @@ def pytest_collection_modifyitems(config, items): stage_marker = item.get_closest_marker("stage") if not stage_marker: selected_items.append(item) - warnings.warn("No stage associated with the test {}. Will run on each stage invocation.".format(item.name)) + warnings.warn( + "No stage associated with the test {}. Will run on each stage invocation.".format( + item.name + ) + ) continue item_stage_markers = _get_highest_specificity_marker(stage_marker) if test_stages.isdisjoint(item_stage_markers): diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py index 3059349a5d..1a722f6cf6 100644 --- a/tests/contrib/autoguide/test_inference.py +++ b/tests/contrib/autoguide/test_inference.py @@ -40,9 +40,13 @@ def compute_target(self, N): self.target_auto_diag_cov[-1] = 1.0 / self.lambda_posts[-1].item() for n in range(N - 1, 0, -1): self.target_auto_mus[n] += self.target_mus[n].item() - self.target_auto_mus[n] += self.target_kappas[n].item() * self.target_auto_mus[n + 1] + self.target_auto_mus[n] += ( + self.target_kappas[n].item() * self.target_auto_mus[n + 1] + ) self.target_auto_diag_cov[n] += 1.0 / self.lambda_posts[n].item() - self.target_auto_diag_cov[n] += (self.target_kappas[n].item() ** 2) * self.target_auto_diag_cov[n + 1] + self.target_auto_diag_cov[n] += ( + self.target_kappas[n].item() ** 2 + ) * self.target_auto_diag_cov[n + 1] def test_multivariatate_normal_auto(self): self.do_test_auto(3, reparameterized=True, n_steps=10001) @@ -54,13 +58,19 @@ def do_test_auto(self, N, reparameterized, n_steps): self.setup_chain(N) self.compute_target(N) self.guide = AutoMultivariateNormal(self.model) - logger.debug("target auto_loc: {}" - .format(self.target_auto_mus[1:].detach().cpu().numpy())) - logger.debug("target auto_diag_cov: {}" - .format(self.target_auto_diag_cov[1:].detach().cpu().numpy())) + logger.debug( + "target auto_loc: {}".format( + self.target_auto_mus[1:].detach().cpu().numpy() + ) + ) + logger.debug( + "target auto_diag_cov: {}".format( + self.target_auto_diag_cov[1:].detach().cpu().numpy() + ) + ) # TODO speed up with parallel num_particles > 1 - adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)}) + adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)}) svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO()) for k in range(n_steps): @@ -68,22 +78,43 @@ def do_test_auto(self, N, reparameterized, n_steps): assert np.isfinite(loss), loss if k % 1000 == 0 and k > 0 or k == n_steps - 1: - logger.debug("[step {}] guide mean parameter: {}" - .format(k, self.guide.loc.detach().cpu().numpy())) + logger.debug( + "[step {}] guide mean parameter: {}".format( + k, self.guide.loc.detach().cpu().numpy() + ) + ) L = self.guide.scale_tril diag_cov = torch.mm(L, L.t()).diag() - logger.debug("[step {}] auto_diag_cov: {}" - .format(k, 
diag_cov.detach().cpu().numpy())) - - assert_equal(self.guide.loc.detach(), self.target_auto_mus[1:], prec=0.05, - msg="guide mean off") - assert_equal(diag_cov, self.target_auto_diag_cov[1:], prec=0.07, - msg="guide covariance off") - - -@pytest.mark.parametrize('auto_class', [AutoDiagonalNormal, AutoMultivariateNormal, - AutoLowRankMultivariateNormal, AutoLaplaceApproximation]) -@pytest.mark.parametrize('Elbo', [Trace_ELBO, TraceMeanField_ELBO]) + logger.debug( + "[step {}] auto_diag_cov: {}".format( + k, diag_cov.detach().cpu().numpy() + ) + ) + + assert_equal( + self.guide.loc.detach(), + self.target_auto_mus[1:], + prec=0.05, + msg="guide mean off", + ) + assert_equal( + diag_cov, + self.target_auto_diag_cov[1:], + prec=0.07, + msg="guide covariance off", + ) + + +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) +@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceMeanField_ELBO]) def test_auto_diagonal_gaussians(auto_class, Elbo): n_steps = 3001 @@ -95,8 +126,9 @@ def model(): guide = auto_class(model, rank=1) else: guide = auto_class(model) - adam = optim.ClippedAdam({"lr": .01, "betas": (0.95, 0.999), - "lrd": 0.1 ** (1 / n_steps)}) + adam = optim.ClippedAdam( + {"lr": 0.01, "betas": (0.95, 0.999), "lrd": 0.1 ** (1 / n_steps)} + ) svi = SVI(model, guide, adam, loss=Elbo()) for k in range(n_steps): @@ -109,21 +141,44 @@ def model(): loc, scale = guide._loc_scale() expected_loc = torch.tensor([-0.2, 0.2]) - assert_equal(loc.detach(), expected_loc, prec=0.05, - msg="\n".join(["Incorrect guide loc. Expected:", - str(expected_loc.cpu().numpy()), - "Actual:", - str(loc.detach().cpu().numpy())])) + assert_equal( + loc.detach(), + expected_loc, + prec=0.05, + msg="\n".join( + [ + "Incorrect guide loc. Expected:", + str(expected_loc.cpu().numpy()), + "Actual:", + str(loc.detach().cpu().numpy()), + ] + ), + ) expected_scale = torch.tensor([1.2, 0.7]) - assert_equal(scale.detach(), expected_scale, prec=0.08, - msg="\n".join(["Incorrect guide scale. Expected:", - str(expected_scale.cpu().numpy()), - "Actual:", - str(scale.detach().cpu().numpy())])) - - -@pytest.mark.parametrize('auto_class', [AutoDiagonalNormal, AutoMultivariateNormal, - AutoLowRankMultivariateNormal, AutoLaplaceApproximation]) + assert_equal( + scale.detach(), + expected_scale, + prec=0.08, + msg="\n".join( + [ + "Incorrect guide scale. 
Expected:", + str(expected_scale.cpu().numpy()), + "Actual:", + str(scale.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) def test_auto_transform(auto_class): n_steps = 3500 @@ -134,7 +189,7 @@ def model(): guide = auto_class(model, rank=1) else: guide = auto_class(model) - adam = optim.Adam({"lr": .001, "betas": (0.90, 0.999)}) + adam = optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}) svi = SVI(model, guide, adam, loss=Trace_ELBO()) for k in range(n_steps): @@ -145,21 +200,24 @@ def model(): guide = guide.laplace_approximation() loc, scale = guide._loc_scale() - assert_equal(loc.detach(), torch.tensor([0.2]), prec=0.04, - msg="guide mean off") - assert_equal(scale.detach(), torch.tensor([0.7]), prec=0.04, - msg="guide covariance off") - - -@pytest.mark.parametrize('auto_class', [ - AutoDiagonalNormal, - AutoIAFNormal, - AutoMultivariateNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - lambda m: AutoNormalizingFlow(m, partial(iterated, 2, block_autoregressive)), -]) -@pytest.mark.parametrize('Elbo', [Trace_ELBO, TraceMeanField_ELBO]) + assert_equal(loc.detach(), torch.tensor([0.2]), prec=0.04, msg="guide mean off") + assert_equal( + scale.detach(), torch.tensor([0.7]), prec=0.04, msg="guide covariance off" + ) + + +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoIAFNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + lambda m: AutoNormalizingFlow(m, partial(iterated, 2, block_autoregressive)), + ], +) +@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceMeanField_ELBO]) def test_auto_dirichlet(auto_class, Elbo): num_steps = 2000 prior = torch.tensor([0.5, 1.0, 1.5, 3.0]) @@ -172,7 +230,7 @@ def model(data): pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data) guide = auto_class(model) - svi = SVI(model, guide, optim.Adam({"lr": .003}), loss=Elbo()) + svi = SVI(model, guide, optim.Adam({"lr": 0.003}), loss=Elbo()) for _ in range(num_steps): loss = svi.step(data) @@ -184,6 +242,14 @@ def model(data): else: loc = guide.loc actual_mean = biject_to(constraints.simplex)(loc) - assert_equal(actual_mean, expected_mean, prec=0.2, msg=''.join([ - '\nexpected {}'.format(expected_mean.detach().cpu().numpy()), - '\n actual {}'.format(actual_mean.detach().cpu().numpy())])) + assert_equal( + actual_mean, + expected_mean, + prec=0.2, + msg="".join( + [ + "\nexpected {}".format(expected_mean.detach().cpu().numpy()), + "\n actual {}".format(actual_mean.detach().cpu().numpy()), + ] + ), + ) diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/contrib/autoguide/test_mean_field_entropy.py index 2f5cd163db..0558ef4604 100644 --- a/tests/contrib/autoguide/test_mean_field_entropy.py +++ b/tests/contrib/autoguide/test_mean_field_entropy.py @@ -13,8 +13,8 @@ def mean_field_guide(batch_tensor, design): # A batched variable - w_p = pyro.param("w_p", 0.2*torch.ones(batch_tensor.shape)) - u_p = pyro.param("u_p", 0.5*torch.ones(batch_tensor.shape)) + w_p = pyro.param("w_p", 0.2 * torch.ones(batch_tensor.shape)) + u_p = pyro.param("u_p", 0.5 * torch.ones(batch_tensor.shape)) pyro.sample("w", dist.Bernoulli(w_p)) pyro.sample("u", dist.Bernoulli(u_p)) @@ -23,9 +23,16 @@ def h(p): return -(sc.xlogy(p, p) + sc.xlog1py(1 - p, -p)) -@pytest.mark.parametrize("guide,args,expected_entropy", [ - (mean_field_guide, (torch.Tensor([0.]), None), 
torch.Tensor([h(0.2) + h(0.5)])), - (mean_field_guide, (torch.eye(2), None), (h(0.2) + h(0.5))*torch.ones(2, 2)) -]) +@pytest.mark.parametrize( + "guide,args,expected_entropy", + [ + ( + mean_field_guide, + (torch.Tensor([0.0]), None), + torch.Tensor([h(0.2) + h(0.5)]), + ), + (mean_field_guide, (torch.eye(2), None), (h(0.2) + h(0.5)) * torch.ones(2, 2)), + ], +) def test_guide_entropy(guide, args, expected_entropy): assert_equal(mean_field_entropy(guide, args), expected_entropy) diff --git a/tests/contrib/autoname/test_named.py b/tests/contrib/autoname/test_named.py index 8d02f7c07d..7ef9b22ba3 100644 --- a/tests/contrib/autoname/test_named.py +++ b/tests/contrib/autoname/test_named.py @@ -10,13 +10,23 @@ def get_sample_names(tr): - return set([name for name, site in tr.nodes.items() - if site["type"] == "sample" and not site["is_observed"]]) + return set( + [ + name + for name, site in tr.nodes.items() + if site["type"] == "sample" and not site["is_observed"] + ] + ) def get_observe_names(tr): - return set([name for name, site in tr.nodes.items() - if site["type"] == "sample" and site["is_observed"]]) + return set( + [ + name + for name, site in tr.nodes.items() + if site["type"] == "sample" and site["is_observed"] + ] + ) def get_param_names(tr): diff --git a/tests/contrib/autoname/test_scoping.py b/tests/contrib/autoname/test_scoping.py index aa7e44bae6..b0d445927d 100644 --- a/tests/contrib/autoname/test_scoping.py +++ b/tests/contrib/autoname/test_scoping.py @@ -14,7 +14,6 @@ def test_multi_nested(): - @scope def model1(r=True): model2() @@ -29,25 +28,25 @@ def model1(r=True): def model2(): return pyro.sample("y", dist.Normal(0.0, 1.0)) - true_samples = ["model1/model2/y", - "model1/model2__1/y", - "model1/inter/model2/y", - "model1/inter/model1/model2/y", - "model1/inter/model1/model2__1/y", - "model1/inter/model1/inter/model2/y", - "model1/inter/model1/model2__2/y", - "model1/model2__2/y"] + true_samples = [ + "model1/model2/y", + "model1/model2__1/y", + "model1/inter/model2/y", + "model1/inter/model1/model2/y", + "model1/inter/model1/model2__1/y", + "model1/inter/model1/inter/model2/y", + "model1/inter/model1/model2__2/y", + "model1/model2__2/y", + ] tr = poutine.trace(name_count(model1)).get_trace(r=True) - samples = [name for name, node in tr.nodes.items() - if node["type"] == "sample"] + samples = [name for name, node in tr.nodes.items() if node["type"] == "sample"] logger.debug(samples) assert true_samples == samples def test_recur_multi(): - @scope(inner=True) def model1(r=True): model2() @@ -61,23 +60,23 @@ def model1(r=True): def model2(): return pyro.sample("y", dist.Normal(0.0, 1.0)) - true_samples = ["model1/model2/y", - "model1/inter/model2/y", - "model1/inter/model1/model2/y", - "model1/inter/model1/inter/model2/y", - "model1/inter/model1/model2/y__1", - "model1/model2/y__1"] + true_samples = [ + "model1/model2/y", + "model1/inter/model2/y", + "model1/inter/model1/model2/y", + "model1/inter/model1/inter/model2/y", + "model1/inter/model1/model2/y__1", + "model1/model2/y__1", + ] tr = poutine.trace(name_count(model1)).get_trace() - samples = [name for name, node in tr.nodes.items() - if node["type"] == "sample"] + samples = [name for name, node in tr.nodes.items() if node["type"] == "sample"] logger.debug(samples) assert true_samples == samples def test_only_withs(): - def model1(): with scope(prefix="a"): with scope(prefix="b"): @@ -91,14 +90,13 @@ def model1(): def test_mutual_recur(): - @scope def model1(n): pyro.sample("a", dist.Bernoulli(0.5)) if n <= 0: return 
else: - return model2(n-1) + return model2(n - 1) @scope def model2(n): @@ -108,14 +106,14 @@ def model2(n): else: model1(n) - names = set(["_INPUT", "_RETURN", - "model2/b", "model2/model1/a", "model2/model1/model2/b"]) + names = set( + ["_INPUT", "_RETURN", "model2/b", "model2/model1/a", "model2/model1/model2/b"] + ) tr_names = set([name for name in poutine.trace(name_count(model2)).get_trace(1)]) assert names == tr_names def test_simple_recur(): - @scope def geometric(p): x = pyro.sample("x", dist.Bernoulli(p)) @@ -134,7 +132,6 @@ def geometric(p): def test_basic_scope(): - @scope def f1(): return pyro.sample("x", dist.Bernoulli(0.5)) @@ -153,7 +150,6 @@ def f2(): def test_nested_traces(): - @scope def f1(): return pyro.sample("x", dist.Bernoulli(0.5)) @@ -167,8 +163,9 @@ def f2(): expected_names = ["f2/f1/x", "f2/f1__1/x", "f2/f1__2/x", "f2/y"] tr2 = poutine.trace(name_count(name_count(f2))).get_trace() - actual_names = [name for name, node in tr2.nodes.items() - if node['type'] == "sample"] + actual_names = [ + name for name, node in tr2.nodes.items() if node["type"] == "sample" + ] assert expected_names == actual_names @@ -183,7 +180,8 @@ def model(): expected_names = ["a", "model/b"] tr = poutine.trace(model).get_trace() - actual_names = [name for name, node in tr.nodes.items() - if node['type'] in ('param', 'sample')] + actual_names = [ + name for name, node in tr.nodes.items() if node["type"] in ("param", "sample") + ] assert expected_names == actual_names diff --git a/tests/contrib/bnn/test_hidden_layer.py b/tests/contrib/bnn/test_hidden_layer.py index c688572d0f..9f72e54f86 100644 --- a/tests/contrib/bnn/test_hidden_layer.py +++ b/tests/contrib/bnn/test_hidden_layer.py @@ -12,16 +12,30 @@ @pytest.mark.parametrize("non_linearity", [F.relu]) @pytest.mark.parametrize("include_hidden_bias", [False, True]) -def test_hidden_layer_rsample(non_linearity, include_hidden_bias, B=2, D=3, H=4, N=900000): +def test_hidden_layer_rsample( + non_linearity, include_hidden_bias, B=2, D=3, H=4, N=900000 +): X = torch.randn(B, D) A_mean = torch.rand(D, H) A_scale = 0.3 * torch.exp(0.3 * torch.rand(D, H)) # test naive weight space sampling against sampling in pre-activation space - dist1 = HiddenLayer(X=X, A_mean=A_mean, A_scale=A_scale, non_linearity=non_linearity, - include_hidden_bias=include_hidden_bias, weight_space_sampling=True) - dist2 = HiddenLayer(X=X, A_mean=A_mean, A_scale=A_scale, non_linearity=non_linearity, - include_hidden_bias=include_hidden_bias, weight_space_sampling=False) + dist1 = HiddenLayer( + X=X, + A_mean=A_mean, + A_scale=A_scale, + non_linearity=non_linearity, + include_hidden_bias=include_hidden_bias, + weight_space_sampling=True, + ) + dist2 = HiddenLayer( + X=X, + A_mean=A_mean, + A_scale=A_scale, + non_linearity=non_linearity, + include_hidden_bias=include_hidden_bias, + weight_space_sampling=False, + ) out1 = dist1.rsample(sample_shape=(N,)) out1_mean, out1_var = out1.mean(0), out1.var(0) @@ -39,8 +53,13 @@ def test_hidden_layer_log_prob(non_linearity, include_hidden_bias, B=2, D=3, H=2 X = torch.randn(B, D) A_mean = torch.rand(D, H) A_scale = 0.3 * torch.exp(0.3 * torch.rand(D, H)) - dist = HiddenLayer(X=X, A_mean=A_mean, A_scale=A_scale, - non_linearity=non_linearity, include_hidden_bias=include_hidden_bias) + dist = HiddenLayer( + X=X, + A_mean=A_mean, + A_scale=A_scale, + non_linearity=non_linearity, + include_hidden_bias=include_hidden_bias, + ) A_dist = Normal(A_mean, A_scale) A_prior = Normal(torch.zeros(D, H), torch.ones(D, H)) diff --git 
a/tests/contrib/cevae/test_cevae.py b/tests/contrib/cevae/test_cevae.py index 70924cf869..bcbb192ff7 100644 --- a/tests/contrib/cevae/test_cevae.py +++ b/tests/contrib/cevae/test_cevae.py @@ -12,8 +12,7 @@ from pyro.contrib.cevae import CEVAE, DistributionNet from tests.common import assert_close -DIST_NETS = [cls.__name__.lower()[:-3] - for cls in DistributionNet.__subclasses__()] +DIST_NETS = [cls.__name__.lower()[:-3] for cls in DistributionNet.__subclasses__()] def generate_data(num_data, feature_dim): @@ -44,7 +43,9 @@ def test_serialization(jit, feature_dim, outcome_dist): x, t, y = generate_data(num_data=32, feature_dim=feature_dim) if outcome_dist == "exponential": y.clamp_(min=1e-20) - cevae = CEVAE(feature_dim, outcome_dist=outcome_dist, num_samples=1000, hidden_dim=32) + cevae = CEVAE( + feature_dim, outcome_dist=outcome_dist, num_samples=1000, hidden_dim=32 + ) cevae.fit(x, t, y, num_epochs=4, batch_size=8) pyro.set_rng_seed(0) expected_ite = cevae.ite(x) diff --git a/tests/contrib/easyguide/test_easyguide.py b/tests/contrib/easyguide/test_easyguide.py index ffb29031a9..4166cfc5a1 100644 --- a/tests/contrib/easyguide/test_easyguide.py +++ b/tests/contrib/easyguide/test_easyguide.py @@ -24,12 +24,12 @@ def model(batch, subsample, full_size): result = [None] * num_time_steps drift = pyro.sample("drift", dist.LogNormal(-1, 0.5)) with pyro.plate("data", full_size, subsample=subsample): - z = 0. + z = 0.0 for t in range(num_time_steps): - z = pyro.sample("state_{}".format(t), - dist.Normal(z, drift)) - result[t] = pyro.sample("obs_{}".format(t), dist.Bernoulli(logits=z), - obs=batch[t]) + z = pyro.sample("state_{}".format(t), dist.Normal(z, drift)) + result[t] = pyro.sample( + "obs_{}".format(t), dist.Bernoulli(logits=z), obs=batch[t] + ) return torch.stack(result) @@ -56,7 +56,6 @@ def check_guide(guide): @pytest.mark.parametrize("init_fn", [None, init_to_mean, init_to_median]) def test_delta_smoke(init_fn): - @easy_guide(model) def guide(self, batch, subsample, full_size): self.map_estimate("drift") @@ -106,16 +105,23 @@ def test_subsample_smoke(init_fn): def guide(self, batch, subsample, full_size): self.map_estimate("drift") group = self.group(match="state_[0-9]*") - cov_diag = pyro.param("state_cov_diag", - lambda: torch.full(group.event_shape, 0.01), - constraint=constraints.positive) - cov_factor = pyro.param("state_cov_factor", - lambda: torch.randn(group.event_shape + (rank,)) * 0.01) + cov_diag = pyro.param( + "state_cov_diag", + lambda: torch.full(group.event_shape, 0.01), + constraint=constraints.positive, + ) + cov_factor = pyro.param( + "state_cov_factor", lambda: torch.randn(group.event_shape + (rank,)) * 0.01 + ) with self.plate("data", full_size, subsample=subsample): - loc = pyro.param("state_loc", - lambda: torch.full((full_size,) + group.event_shape, 0.5), - event_dim=1) - group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)) + loc = pyro.param( + "state_loc", + lambda: torch.full((full_size,) + group.event_shape, 0.5), + event_dim=1, + ) + group.sample( + "states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) + ) if init_fn is not None: guide.init = init_fn @@ -133,20 +139,27 @@ def guide(self, batch, subsample, full_size): self.map_estimate("drift") group = self.group(match="state_[0-9]*") - cov_diag = pyro.param("state_cov_diag", - lambda: torch.full(group.event_shape, 0.01), - constraint=constraints.positive) - cov_factor = pyro.param("state_cov_factor", - lambda: torch.randn(group.event_shape + (rank,)) * 0.01) + 
cov_diag = pyro.param( + "state_cov_diag", + lambda: torch.full(group.event_shape, 0.01), + constraint=constraints.positive, + ) + cov_factor = pyro.param( + "state_cov_factor", lambda: torch.randn(group.event_shape + (rank,)) * 0.01 + ) if not hasattr(self, "nn"): - self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel()) + self.nn = torch.nn.Linear( + group.event_shape.numel(), group.event_shape.numel() + ) self.nn.weight.data.fill_(1.0 / num_time_steps) self.nn.bias.data.fill_(-0.5) pyro.module("state_nn", self.nn) with self.plate("data", full_size, subsample=subsample): loc = self.nn(batch.t()) - group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)) + group.sample( + "states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) + ) if init_fn is not None: guide.init = init_fn @@ -155,7 +168,6 @@ def guide(self, batch, subsample, full_size): def test_overlapping_plates_ok(): - def model(batch, subsample, full_size): # This is ok because the shared plate is left of the nonshared plate. with pyro.plate("shared", full_size, subsample=subsample, dim=-2): @@ -163,20 +175,23 @@ def model(batch, subsample, full_size): with pyro.plate("nonshared", 2, dim=-1): y = pyro.sample("y", dist.Normal(0, 1)) xy = x + y.sum(-1, keepdim=True) - return pyro.sample("z", dist.Normal(xy, 1), - obs=batch) + return pyro.sample("z", dist.Normal(xy, 1), obs=batch) @easy_guide(model) def guide(self, batch, subsample, full_size): with self.plate("shared", full_size, subsample=subsample, dim=-2): group = self.group(match="x|y") - loc = pyro.param("guide_loc", - torch.zeros((full_size, 1) + group.event_shape), - event_dim=1) - scale = pyro.param("guide_scale", - torch.ones((full_size, 1) + group.event_shape), - constraint=constraints.positive, - event_dim=1) + loc = pyro.param( + "guide_loc", + torch.zeros((full_size, 1) + group.event_shape), + event_dim=1, + ) + scale = pyro.param( + "guide_scale", + torch.ones((full_size, 1) + group.event_shape), + constraint=constraints.positive, + event_dim=1, + ) group.sample("xy", dist.Normal(loc, scale).to_event(1)) # Generate data. @@ -198,7 +213,6 @@ def guide(self, batch, subsample, full_size): def test_overlapping_plates_error(): - def model(batch, subsample, full_size): # This is an error because the shared plate is right of the nonshared plate. with pyro.plate("shared", full_size, subsample=subsample, dim=-1): @@ -206,20 +220,21 @@ def model(batch, subsample, full_size): with pyro.plate("nonshared", 2, dim=-2): y = pyro.sample("y", dist.Normal(0, 1)) xy = x + y.sum(-2) - return pyro.sample("z", dist.Normal(xy, 1), - obs=batch) + return pyro.sample("z", dist.Normal(xy, 1), obs=batch) @easy_guide(model) def guide(self, batch, subsample, full_size): with self.plate("shared", full_size, subsample=subsample, dim=-1): group = self.group(match="x|y") - loc = pyro.param("guide_loc", - torch.zeros((full_size,) + group.event_shape), - event_dim=1) - scale = pyro.param("guide_scale", - torch.ones((full_size,) + group.event_shape), - constraint=constraints.positive, - event_dim=1) + loc = pyro.param( + "guide_loc", torch.zeros((full_size,) + group.event_shape), event_dim=1 + ) + scale = pyro.param( + "guide_scale", + torch.ones((full_size,) + group.event_shape), + constraint=constraints.positive, + event_dim=1, + ) group.sample("xy", dist.Normal(loc, scale).to_event(1)) # Generate data. 
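The easyguide hunks above mostly rewrap pyro.param calls; the underlying pattern is unchanged. A stripped-down sketch of that pattern, with a toy one-step model that is an assumption and not from the patch: map-estimate a global site, then fit a factored normal over the group of sites matched by a regex.

    import torch
    from torch.distributions import constraints
    import pyro
    import pyro.distributions as dist
    from pyro.contrib.easyguide import easy_guide
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model(data):
        drift = pyro.sample("drift", dist.LogNormal(-1.0, 0.5))
        z = pyro.sample("state_0", dist.Normal(0.0, drift))
        pyro.sample("obs_0", dist.Normal(z, 1.0), obs=data)

    @easy_guide(model)
    def guide(self, data):
        self.map_estimate("drift")
        group = self.group(match="state_[0-9]*")
        loc = pyro.param("state_loc", torch.zeros(group.event_shape))
        scale = pyro.param(
            "state_scale",
            torch.ones(group.event_shape),
            constraint=constraints.positive,
        )
        group.sample("states", dist.Normal(loc, scale).to_event(1))

    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
    svi.step(torch.tensor(0.5))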
diff --git a/tests/contrib/epidemiology/test_distributions.py b/tests/contrib/epidemiology/test_distributions.py index 5acab8cead..b65180fc8c 100644 --- a/tests/contrib/epidemiology/test_distributions.py +++ b/tests/contrib/epidemiology/test_distributions.py @@ -32,86 +32,108 @@ def assert_dist_close(d1, d2): assert (p1 - p2).abs().max() / max_prob < 0.05 -@pytest.mark.parametrize("R0,I", [ - (1., 1), - (1., 10), - (10., 1), - (5., 5), -]) +@pytest.mark.parametrize( + "R0,I", + [ + (1.0, 1), + (1.0, 10), + (10.0, 1), + (5.0, 5), + ], +) def test_binomial_vs_poisson(R0, I): R0 = torch.tensor(R0) I = torch.tensor(I) d1 = infection_dist(individual_rate=R0, num_infectious=I) - d2 = infection_dist(individual_rate=R0, num_infectious=I, - num_susceptible=1000., population=1000.) + d2 = infection_dist( + individual_rate=R0, num_infectious=I, num_susceptible=1000.0, population=1000.0 + ) assert isinstance(d1, dist.Poisson) assert isinstance(d2, dist.Binomial) assert_dist_close(d1, d2) -@pytest.mark.parametrize("R0,I,k", [ - (1., 1., 0.5), - (1., 1., 1.), - (1., 1., 2.), - (1., 10., 0.5), - (1., 10., 1.), - (1., 10., 2.), - (10., 1., 0.5), - (10., 1., 1.), - (10., 1., 2.), - (5., 5, 0.5), - (5., 5, 1.), - (5., 5, 2.), -]) +@pytest.mark.parametrize( + "R0,I,k", + [ + (1.0, 1.0, 0.5), + (1.0, 1.0, 1.0), + (1.0, 1.0, 2.0), + (1.0, 10.0, 0.5), + (1.0, 10.0, 1.0), + (1.0, 10.0, 2.0), + (10.0, 1.0, 0.5), + (10.0, 1.0, 1.0), + (10.0, 1.0, 2.0), + (5.0, 5, 0.5), + (5.0, 5, 1.0), + (5.0, 5, 2.0), + ], +) def test_beta_binomial_vs_negative_binomial(R0, I, k): R0 = torch.tensor(R0) I = torch.tensor(I) d1 = infection_dist(individual_rate=R0, num_infectious=I, concentration=k) - d2 = infection_dist(individual_rate=R0, num_infectious=I, concentration=k, - num_susceptible=1000., population=1000.) + d2 = infection_dist( + individual_rate=R0, + num_infectious=I, + concentration=k, + num_susceptible=1000.0, + population=1000.0, + ) assert isinstance(d1, dist.NegativeBinomial) assert isinstance(d2, dist.BetaBinomial) assert_dist_close(d1, d2) -@pytest.mark.parametrize("R0,I", [ - (1., 1.), - (1., 10.), - (10., 1.), - (5., 5.), -]) +@pytest.mark.parametrize( + "R0,I", + [ + (1.0, 1.0), + (1.0, 10.0), + (10.0, 1.0), + (5.0, 5.0), + ], +) def test_beta_binomial_vs_binomial(R0, I): R0 = torch.tensor(R0) I = torch.tensor(I) - d1 = infection_dist(individual_rate=R0, num_infectious=I, - num_susceptible=20., population=30.) - d2 = infection_dist(individual_rate=R0, num_infectious=I, - num_susceptible=20., population=30., - concentration=200.) + d1 = infection_dist( + individual_rate=R0, num_infectious=I, num_susceptible=20.0, population=30.0 + ) + d2 = infection_dist( + individual_rate=R0, + num_infectious=I, + num_susceptible=20.0, + population=30.0, + concentration=200.0, + ) assert isinstance(d1, dist.Binomial) assert isinstance(d2, dist.BetaBinomial) assert_dist_close(d1, d2) -@pytest.mark.parametrize("R0,I", [ - (1., 1.), - (1., 10.), - (10., 1.), - (5., 5.), -]) +@pytest.mark.parametrize( + "R0,I", + [ + (1.0, 1.0), + (1.0, 10.0), + (10.0, 1.0), + (5.0, 5.0), + ], +) def test_negative_binomial_vs_poisson(R0, I): R0 = torch.tensor(R0) I = torch.tensor(I) d1 = infection_dist(individual_rate=R0, num_infectious=I) - d2 = infection_dist(individual_rate=R0, num_infectious=I, - concentration=200.) 
+ d2 = infection_dist(individual_rate=R0, num_infectious=I, concentration=200.0) assert isinstance(d1, dist.Poisson) assert isinstance(d2, dist.NegativeBinomial) @@ -140,12 +162,12 @@ def test_overdispersed_asymptote(probs, overdispersion): # Check binomial_dist converges in distribution to LogitNormal. d1 = binomial_dist(total_count, probs) d2 = dist.TransformedDistribution( - dist.Normal(math.log(probs / (1 - probs)), overdispersion), - SigmoidTransform()) + dist.Normal(math.log(probs / (1 - probs)), overdispersion), SigmoidTransform() + ) # CRPS is equivalent to the Cramer-von Mises test. # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion - k = torch.arange(0., total_count + 1.) + k = torch.arange(0.0, total_count + 1.0) cdf1 = d1.log_prob(k).exp().cumsum(-1) cdf2 = d2.cdf(k / total_count) crps = (cdf1 - cdf2).pow(2).mean() @@ -153,17 +175,18 @@ def test_overdispersed_asymptote(probs, overdispersion): @pytest.mark.parametrize("total_count", [1, 2, 5, 10, 20, 50]) -@pytest.mark.parametrize("concentration1", [0.2, 1.0, 5.]) -@pytest.mark.parametrize("concentration0", [0.2, 1.0, 5.]) +@pytest.mark.parametrize("concentration1", [0.2, 1.0, 5.0]) +@pytest.mark.parametrize("concentration0", [0.2, 1.0, 5.0]) def test_beta_binomial(concentration1, concentration0, total_count): # For small overdispersion, beta_binomial_dist is close to BetaBinomial. d1 = dist.BetaBinomial(concentration1, concentration0, total_count) - d2 = beta_binomial_dist(concentration1, concentration0, total_count, - overdispersion=0.01) + d2 = beta_binomial_dist( + concentration1, concentration0, total_count, overdispersion=0.01 + ) # CRPS is equivalent to the Cramer-von Mises test. # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion - k = torch.arange(0., total_count + 1.) + k = torch.arange(0.0, total_count + 1.0) cdf1 = d1.log_prob(k).exp().cumsum(-1) cdf2 = d2.log_prob(k).exp().cumsum(-1) crps = (cdf1 - cdf2).pow(2).mean() @@ -175,16 +198,17 @@ def test_beta_binomial(concentration1, concentration0, total_count): @pytest.mark.parametrize("probs", [0.1, 0.2, 0.5, 0.8, 0.9]) def test_overdispersed_beta_binomial(probs, total_count, overdispersion): # For high concentraion, beta_binomial_dist is close to binomial_dist. - concentration = 100. # very little uncertainty + concentration = 100.0 # very little uncertainty concentration1 = concentration * probs concentration0 = concentration * (1 - probs) d1 = binomial_dist(total_count, probs, overdispersion=overdispersion) - d2 = beta_binomial_dist(concentration1, concentration0, total_count, - overdispersion=overdispersion) + d2 = beta_binomial_dist( + concentration1, concentration0, total_count, overdispersion=overdispersion + ) # CRPS is equivalent to the Cramer-von Mises test. # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion - k = torch.arange(0., total_count + 1.) 
+ k = torch.arange(0.0, total_count + 1.0) cdf1 = d1.log_prob(k).exp().cumsum(-1) cdf2 = d2.log_prob(k).exp().cumsum(-1) crps = (cdf1 - cdf2).pow(2).mean() @@ -241,13 +265,15 @@ def test_relaxed_overdispersed_beta_binomial(overdispersion): concentration1 = torch.logspace(-1, 2, 8).unsqueeze(-1) concentration0 = concentration1.unsqueeze(-1) - d1 = beta_binomial_dist(concentration1, concentration0, total_count, - overdispersion=overdispersion) + d1 = beta_binomial_dist( + concentration1, concentration0, total_count, overdispersion=overdispersion + ) assert isinstance(d1, dist.ExtendedBetaBinomial) with set_relaxed_distributions(): - d2 = beta_binomial_dist(concentration1, concentration0, total_count, - overdispersion=overdispersion) + d2 = beta_binomial_dist( + concentration1, concentration0, total_count, overdispersion=overdispersion + ) assert isinstance(d2, dist.Normal) assert_close(d2.mean, d1.mean) assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE)) diff --git a/tests/contrib/epidemiology/test_models.py b/tests/contrib/epidemiology/test_models.py index 068f775ac1..987461ef3e 100644 --- a/tests/contrib/epidemiology/test_models.py +++ b/tests/contrib/epidemiology/test_models.py @@ -31,30 +31,34 @@ @pytest.mark.filterwarnings("ignore:num_chains") @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("svi", {"guide_rank": None}), - ("svi", {"guide_rank": 2}), - ("svi", {"guide_rank": "full"}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"haar_full_mass": 0}), - ("mcmc", {"haar_full_mass": 2}), - ("mcmc", {"num_quant_bins": 2}), - ("mcmc", {"num_quant_bins": 4}), - ("mcmc", {"num_quant_bins": 8}), - ("mcmc", {"num_quant_bins": 12}), - ("mcmc", {"num_quant_bins": 16}), - ("mcmc", {"num_quant_bins": 2, "haar": False}), - ("mcmc", {"arrowhead_mass": True}), - ("mcmc", {"jit_compile": True}), - ("mcmc", {"jit_compile": True, "haar_full_mass": 0}), - ("mcmc", {"jit_compile": True, "num_quant_bins": 2}), - ("mcmc", {"num_chains": 2, "mp_context": "spawn"}), - ("mcmc", {"num_chains": 2, "mp_context": "spawn", "num_quant_bins": 2}), - ("mcmc", {"num_chains": 2, "mp_context": "spawn", "jit_compile": True}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("svi", {"guide_rank": None}), + ("svi", {"guide_rank": 2}), + ("svi", {"guide_rank": "full"}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"haar_full_mass": 0}), + ("mcmc", {"haar_full_mass": 2}), + ("mcmc", {"num_quant_bins": 2}), + ("mcmc", {"num_quant_bins": 4}), + ("mcmc", {"num_quant_bins": 8}), + ("mcmc", {"num_quant_bins": 12}), + ("mcmc", {"num_quant_bins": 16}), + ("mcmc", {"num_quant_bins": 2, "haar": False}), + ("mcmc", {"arrowhead_mass": True}), + ("mcmc", {"jit_compile": True}), + ("mcmc", {"jit_compile": True, "haar_full_mass": 0}), + ("mcmc", {"jit_compile": True, "num_quant_bins": 2}), + ("mcmc", {"num_chains": 2, "mp_context": "spawn"}), + ("mcmc", {"num_chains": 2, "mp_context": "spawn", "num_quant_bins": 2}), + ("mcmc", {"num_chains": 2, "mp_context": "spawn", "jit_compile": True}), + ], + ids=str, +) def test_simple_sir_smoke(duration, forecast, options, algo): population = 100 recovery_time = 7.0 @@ -72,7 +76,9 @@ def test_simple_sir_smoke(duration, forecast, options, algo): model = SimpleSIRModel(population, recovery_time, data) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, 
num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -85,22 +91,27 @@ def test_simple_sir_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"haar_full_mass": 0}), - ("mcmc", {"num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"haar_full_mass": 0}), + ("mcmc", {"num_quant_bins": 2}), + ], + ids=str, +) def test_simple_seir_smoke(duration, forecast, options, algo): population = 100 incubation_time = 2.0 recovery_time = 7.0 # Generate data. - model = SimpleSEIRModel(population, incubation_time, recovery_time, - [None] * duration) + model = SimpleSEIRModel( + population, incubation_time, recovery_time, [None] * duration + ) assert model.full_mass == [("R0", "rho")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] @@ -112,7 +123,9 @@ def test_simple_seir_smoke(duration, forecast, options, algo): model = SimpleSEIRModel(population, incubation_time, recovery_time, data) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -125,11 +138,15 @@ def test_simple_seir_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("mcmc", {}), - ("mcmc", {"haar_full_mass": 0}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("mcmc", {}), + ("mcmc", {"haar_full_mass": 0}), + ], + ids=str, +) def test_simple_seird_smoke(duration, forecast, options, algo): population = 100 incubation_time = 2.0 @@ -137,8 +154,9 @@ def test_simple_seird_smoke(duration, forecast, options, algo): mortality_rate = 0.1 # Generate data. - model = SimpleSEIRDModel(population, incubation_time, recovery_time, - mortality_rate, [None] * duration) + model = SimpleSEIRDModel( + population, incubation_time, recovery_time, mortality_rate, [None] * duration + ) assert model.full_mass == [("R0", "rho")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] @@ -147,11 +165,14 @@ def test_simple_seird_smoke(duration, forecast, options, algo): assert data.sum() > 0, "failed to generate positive data" # Infer. 
- model = SimpleSEIRDModel(population, incubation_time, recovery_time, - mortality_rate, data) + model = SimpleSEIRDModel( + population, incubation_time, recovery_time, mortality_rate, data + ) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -165,11 +186,15 @@ def test_simple_seird_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [3]) @pytest.mark.parametrize("forecast", [7]) -@pytest.mark.parametrize("options", [ - {}, - {"haar": False}, - {"num_quant_bins": 2}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + {}, + {"haar": False}, + {"num_quant_bins": 2}, + ], + ids=str, +) def test_overdispersed_sir_smoke(duration, forecast, options): population = 100 recovery_time = 7.0 @@ -196,19 +221,24 @@ def test_overdispersed_sir_smoke(duration, forecast, options): @pytest.mark.parametrize("duration", [3]) @pytest.mark.parametrize("forecast", [7]) -@pytest.mark.parametrize("options", [ - {}, - {"haar": False}, - {"num_quant_bins": 2}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + {}, + {"haar": False}, + {"num_quant_bins": 2}, + ], + ids=str, +) def test_overdispersed_seir_smoke(duration, forecast, options): population = 100 incubation_time = 2.0 recovery_time = 7.0 # Generate data. - model = OverdispersedSEIRModel(population, incubation_time, recovery_time, - [None] * duration) + model = OverdispersedSEIRModel( + population, incubation_time, recovery_time, [None] * duration + ) assert model.full_mass == [("R0", "rho", "od")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] @@ -219,8 +249,7 @@ def test_overdispersed_seir_smoke(duration, forecast, options): # Infer. model = OverdispersedSEIRModel(population, incubation_time, recovery_time, data) num_samples = 5 - model.fit_mcmc(warmup_steps=2, num_samples=num_samples, max_tree_depth=2, - **options) + model.fit_mcmc(warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options) # Predict and forecast. samples = model.predict(forecast=forecast) @@ -231,12 +260,16 @@ def test_overdispersed_seir_smoke(duration, forecast, options): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("options", [ - {}, - {"haar": False}, - {"haar_full_mass": 0}, - {"num_quant_bins": 2}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + {}, + {"haar": False}, + {"haar_full_mass": 0}, + {"num_quant_bins": 2}, + ], + ids=str, +) def test_superspreading_sir_smoke(duration, forecast, options): population = 100 recovery_time = 7.0 @@ -263,12 +296,16 @@ def test_superspreading_sir_smoke(duration, forecast, options): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("options", [ - {}, - {"haar": False}, - {"haar_full_mass": 0}, - {"num_quant_bins": 2}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + {}, + {"haar": False}, + {"haar_full_mass": 0}, + {"num_quant_bins": 2}, + ], + ids=str, +) def test_superspreading_seir_smoke(duration, forecast, options): population = 100 incubation_time = 2.0 @@ -276,7 +313,8 @@ def test_superspreading_seir_smoke(duration, forecast, options): # Generate data. 
model = SuperspreadingSEIRModel( - population, incubation_time, recovery_time, [None] * duration) + population, incubation_time, recovery_time, [None] * duration + ) assert model.full_mass == [("R0", "k", "rho")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"] @@ -285,11 +323,9 @@ def test_superspreading_seir_smoke(duration, forecast, options): assert data.sum() > 0, "failed to generate positive data" # Infer. - model = SuperspreadingSEIRModel( - population, incubation_time, recovery_time, data) + model = SuperspreadingSEIRModel(population, incubation_time, recovery_time, data) num_samples = 5 - model.fit_mcmc(warmup_steps=2, num_samples=num_samples, max_tree_depth=2, - **options) + model.fit_mcmc(warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options) # Predict and forecast. samples = model.predict(forecast=forecast) @@ -300,12 +336,16 @@ def test_superspreading_seir_smoke(duration, forecast, options): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"num_quant_bins": 2}), + ], + ids=str, +) def test_coalescent_likelihood_smoke(duration, forecast, options, algo): population = 100 incubation_time = 2.0 @@ -313,7 +353,8 @@ def test_coalescent_likelihood_smoke(duration, forecast, options, algo): # Generate data. model = SuperspreadingSEIRModel( - population, incubation_time, recovery_time, [None] * duration) + population, incubation_time, recovery_time, [None] * duration + ) for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"] if data.sum(): @@ -325,12 +366,18 @@ def test_coalescent_likelihood_smoke(duration, forecast, options, algo): # Infer. 
model = SuperspreadingSEIRModel( - population, incubation_time, recovery_time, data, - leaf_times=leaf_times, coal_times=coal_times) + population, + incubation_time, + recovery_time, + data, + leaf_times=leaf_times, + coal_times=coal_times, + ) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=2, num_samples=num_samples, max_tree_depth=2, - **options) + model.fit_mcmc( + warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -343,13 +390,17 @@ def test_coalescent_likelihood_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"num_quant_bins": 2}), + ], + ids=str, +) def test_heterogeneous_sir_smoke(duration, forecast, options, algo): population = 100 recovery_time = 7.0 @@ -377,13 +428,17 @@ def test_heterogeneous_sir_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [4, 12]) @pytest.mark.parametrize("forecast", [7]) -@pytest.mark.parametrize("options", [ - xfail_param({}, reason="Delta is incompatible with relaxed inference"), - {"num_quant_bins": 2}, - {"num_quant_bins": 2, "haar": False}, - {"num_quant_bins": 2, "haar_full_mass": 0}, - {"num_quant_bins": 4}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + xfail_param({}, reason="Delta is incompatible with relaxed inference"), + {"num_quant_bins": 2}, + {"num_quant_bins": 2, "haar": False}, + {"num_quant_bins": 2, "haar_full_mass": 0}, + {"num_quant_bins": 4}, + ], + ids=str, +) def test_sparse_smoke(duration, forecast, options): population = 100 recovery_time = 7.0 @@ -421,12 +476,16 @@ def test_sparse_smoke(duration, forecast, options): @pytest.mark.parametrize("pre_obs_window", [6]) @pytest.mark.parametrize("duration", [8]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("options", [ - {}, - {"haar": False}, - {"haar_full_mass": 0}, - {"num_quant_bins": 2}, -], ids=str) +@pytest.mark.parametrize( + "options", + [ + {}, + {"haar": False}, + {"haar_full_mass": 0}, + {"num_quant_bins": 2}, + ], + ids=str, +) def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): population = 100 recovery_time = 7.0 @@ -467,23 +526,28 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"haar_full_mass": 0}), - ("mcmc", {"num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"haar_full_mass": 0}), + ("mcmc", {"num_quant_bins": 2}), + ], + ids=str, +) def test_regional_smoke(duration, forecast, options, algo): num_regions = 6 coupling = torch.eye(num_regions).clamp(min=0.1) - population = torch.tensor([2., 3., 4., 10., 100., 1000.]) + population = torch.tensor([2.0, 3.0, 4.0, 10.0, 100.0, 1000.0]) recovery_time = 7.0 # Generate data. 
- model = RegionalSIRModel(population, coupling, recovery_time, - data=[None] * duration) + model = RegionalSIRModel( + population, coupling, recovery_time, data=[None] * duration + ) assert model.full_mass == [("R0", "rho_c1", "rho_c0", "rho")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] @@ -496,7 +560,9 @@ def test_regional_smoke(duration, forecast, options, algo): model = RegionalSIRModel(population, coupling, recovery_time, data) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -514,29 +580,33 @@ def finalize(self, params, prev, curr): I = curr["I"] I_mean = I.mean(dim=[-1, -2], keepdim=True).expand_as(I) with self.region_plate, self.time_plate: - pyro.sample("likelihood", dist.Normal(I_mean, 1.), - obs=I) + pyro.sample("likelihood", dist.Normal(I_mean, 1.0), obs=I) @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"haar_full_mass": 0}), - ("mcmc", {"num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"haar_full_mass": 0}), + ("mcmc", {"num_quant_bins": 2}), + ], + ids=str, +) def test_regional_finalize_smoke(duration, forecast, options, algo): num_regions = 6 coupling = torch.eye(num_regions).clamp(min=0.1) - population = torch.tensor([2., 3., 4., 10., 100., 1000.]) + population = torch.tensor([2.0, 3.0, 4.0, 10.0, 100.0, 1000.0]) recovery_time = 7.0 # Generate data. 
- model = RegionalSIRModelWithFinalize(population, coupling, recovery_time, - data=[None] * duration) + model = RegionalSIRModelWithFinalize( + population, coupling, recovery_time, data=[None] * duration + ) assert model.full_mass == [("R0", "rho_c1", "rho_c0", "rho")] for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] @@ -549,7 +619,9 @@ def test_regional_finalize_smoke(duration, forecast, options, algo): model = RegionalSIRModelWithFinalize(population, coupling, recovery_time, data) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) @@ -561,26 +633,31 @@ def test_regional_finalize_smoke(duration, forecast, options, algo): @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -@pytest.mark.parametrize("algo,options", [ - ("svi", {}), - ("svi", {"haar": False}), - ("mcmc", {}), - ("mcmc", {"haar": False}), - ("mcmc", {"haar_full_mass": 0}), - ("mcmc", {"num_quant_bins": 2}), - ("mcmc", {"jit_compile": True}), - ("mcmc", {"jit_compile": True, "haar": False}), - ("mcmc", {"jit_compile": True, "num_quant_bins": 2}), -], ids=str) +@pytest.mark.parametrize( + "algo,options", + [ + ("svi", {}), + ("svi", {"haar": False}), + ("mcmc", {}), + ("mcmc", {"haar": False}), + ("mcmc", {"haar_full_mass": 0}), + ("mcmc", {"num_quant_bins": 2}), + ("mcmc", {"jit_compile": True}), + ("mcmc", {"jit_compile": True, "haar": False}), + ("mcmc", {"jit_compile": True, "num_quant_bins": 2}), + ], + ids=str, +) def test_hetero_regional_smoke(duration, forecast, options, algo): num_regions = 6 coupling = torch.eye(num_regions).clamp(min=0.1) - population = torch.tensor([2., 3., 4., 10., 100., 1000.]) + population = torch.tensor([2.0, 3.0, 4.0, 10.0, 100.0, 1000.0]) recovery_time = 7.0 # Generate data. 
- model = HeterogeneousRegionalSIRModel(population, coupling, recovery_time, - data=[None] * duration) + model = HeterogeneousRegionalSIRModel( + population, coupling, recovery_time, data=[None] * duration + ) assert model.full_mass == [("R0", "R_drift", "rho0", "rho_drift")] for attempt in range(100): data = model.generate({"R0": 1.5})["obs"] @@ -593,7 +670,9 @@ def test_hetero_regional_smoke(duration, forecast, options, algo): model = HeterogeneousRegionalSIRModel(population, coupling, recovery_time, data) num_samples = 5 if algo == "mcmc": - model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + model.fit_mcmc( + warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options + ) else: model.fit_svi(num_steps=2, num_samples=num_samples, **options) diff --git a/tests/contrib/epidemiology/test_util.py b/tests/contrib/epidemiology/test_util.py index bc602620c6..df1b90b4a7 100644 --- a/tests/contrib/epidemiology/test_util.py +++ b/tests/contrib/epidemiology/test_util.py @@ -8,8 +8,8 @@ from tests.common import assert_equal -@pytest.mark.parametrize("min", [None, 0., (), (2,)], ids=str) -@pytest.mark.parametrize("max", [None, 1., (), (2,)], ids=str) +@pytest.mark.parametrize("min", [None, 0.0, (), (2,)], ids=str) +@pytest.mark.parametrize("max", [None, 1.0, (), (2,)], ids=str) @pytest.mark.parametrize("shape", [(2,), (3, 2)], ids=str) def test_clamp(shape, min, max): tensor = torch.randn(shape) diff --git a/tests/contrib/forecast/test_evaluate.py b/tests/contrib/forecast/test_evaluate.py index 9394aa3015..c5c982f4c8 100644 --- a/tests/contrib/forecast/test_evaluate.py +++ b/tests/contrib/forecast/test_evaluate.py @@ -36,24 +36,32 @@ def model(self, zero_data, covariates): ] -@pytest.mark.parametrize("train_window,min_train_window,test_window,min_test_window,stride", WINDOWS) +@pytest.mark.parametrize( + "train_window,min_train_window,test_window,min_test_window,stride", WINDOWS +) @pytest.mark.parametrize("warm_start", [False, True], ids=["cold", "warm"]) -def test_simple(train_window, min_train_window, test_window, min_test_window, stride, warm_start): +def test_simple( + train_window, min_train_window, test_window, min_test_window, stride, warm_start +): duration = 30 obs_dim = 2 covariates = torch.zeros(duration, 0) data = torch.randn(duration, obs_dim) + 4 forecaster_options = {"num_steps": 2, "warm_start": warm_start} - expect_error = (warm_start and train_window is not None) + expect_error = warm_start and train_window is not None with optional(pytest.raises(ValueError), expect_error): - windows = backtest(data, covariates, Model, - train_window=train_window, - min_train_window=min_train_window, - test_window=test_window, - min_test_window=min_test_window, - stride=stride, - forecaster_options=forecaster_options) + windows = backtest( + data, + covariates, + Model, + train_window=train_window, + min_train_window=min_train_window, + test_window=test_window, + min_test_window=min_test_window, + stride=stride, + forecaster_options=forecaster_options, + ) if not expect_error: assert any(window["t0"] == 0 for window in windows) if stride == 1: @@ -66,9 +74,13 @@ def test_simple(train_window, min_train_window, test_window, min_test_window, st assert 0 < window[name] < math.inf -@pytest.mark.parametrize("train_window,min_train_window,test_window,min_test_window,stride", WINDOWS) +@pytest.mark.parametrize( + "train_window,min_train_window,test_window,min_test_window,stride", WINDOWS +) @pytest.mark.parametrize("engine", ["svi", "hmc"]) -def 
test_poisson(train_window, min_train_window, test_window, min_test_window, stride, engine): +def test_poisson( + train_window, min_train_window, test_window, min_test_window, stride, engine +): duration = 30 obs_dim = 2 covariates = torch.zeros(duration, 0) @@ -90,15 +102,19 @@ def transform(pred, truth): forecaster_fn = HMCForecaster forecaster_options = {"num_warmup": 1, "num_samples": 1} - windows = backtest(data, covariates, Model, - forecaster_fn=forecaster_fn, - transform=transform, - train_window=train_window, - min_train_window=min_train_window, - test_window=test_window, - min_test_window=min_test_window, - stride=stride, - forecaster_options=forecaster_options) + windows = backtest( + data, + covariates, + Model, + forecaster_fn=forecaster_fn, + transform=transform, + train_window=train_window, + min_train_window=min_train_window, + test_window=test_window, + min_test_window=min_test_window, + stride=stride, + forecaster_options=forecaster_options, + ) assert any(window["t0"] == 0 for window in windows) if stride == 1: @@ -123,7 +139,11 @@ def forecaster_options(t0, t1, t2): else: return {"num_steps": 0, "warm_start": True} - backtest(data, covariates, Model, - min_train_window=min_train_window, - test_window=10, - forecaster_options=forecaster_options) + backtest( + data, + covariates, + Model, + min_train_window=min_train_window, + test_window=10, + forecaster_options=forecaster_options, + ) diff --git a/tests/contrib/forecast/test_forecaster.py b/tests/contrib/forecast/test_forecaster.py index 091f7dbf04..261dd3f682 100644 --- a/tests/contrib/forecast/test_forecaster.py +++ b/tests/contrib/forecast/test_forecaster.py @@ -107,35 +107,66 @@ def model(self, zero_data, covariates): @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) @pytest.mark.parametrize("cov_dim", [0, 1, 6]) @pytest.mark.parametrize("obs_dim", [1, 2]) -@pytest.mark.parametrize("time_reparam, dct_gradients", [ - (None, False), - ("haar", False), - ("dct", False), - (None, True), -]) +@pytest.mark.parametrize( + "time_reparam, dct_gradients", + [ + (None, False), + ("haar", False), + ("dct", False), + (None, True), + ], +) @pytest.mark.parametrize("Model", [Model0, Model1, Model2, Model3, Model4]) @pytest.mark.parametrize("engine", ["svi", "hmc"]) -def test_smoke(Model, batch_shape, t_obs, t_forecast, obs_dim, cov_dim, time_reparam, dct_gradients, engine): +def test_smoke( + Model, + batch_shape, + t_obs, + t_forecast, + obs_dim, + cov_dim, + time_reparam, + dct_gradients, + engine, +): model = Model() data = torch.randn(batch_shape + (t_obs, obs_dim)) covariates = torch.randn(batch_shape + (t_obs + t_forecast, cov_dim)) if engine == "svi": - forecaster = Forecaster(model, data, covariates[..., :t_obs, :], - num_steps=2, log_every=1, time_reparam=time_reparam, - dct_gradients=dct_gradients) + forecaster = Forecaster( + model, + data, + covariates[..., :t_obs, :], + num_steps=2, + log_every=1, + time_reparam=time_reparam, + dct_gradients=dct_gradients, + ) else: if dct_gradients is True: pytest.skip("Duplicated test.") - forecaster = HMCForecaster(model, data, covariates[..., :t_obs, :], max_tree_depth=1, - num_warmup=1, num_samples=1, - jit_compile=False) + forecaster = HMCForecaster( + model, + data, + covariates[..., :t_obs, :], + max_tree_depth=1, + num_warmup=1, + num_samples=1, + jit_compile=False, + ) num_samples = 5 samples = forecaster(data, covariates, num_samples) - assert samples.shape == (num_samples,) + batch_shape + (t_forecast, obs_dim,) + assert samples.shape == (num_samples,) + 
batch_shape + ( + t_forecast, + obs_dim, + ) samples = forecaster(data, covariates, num_samples, batch_size=2) - assert samples.shape == (num_samples,) + batch_shape + (t_forecast, obs_dim,) + assert samples.shape == (num_samples,) + batch_shape + ( + t_forecast, + obs_dim, + ) @pytest.mark.parametrize("t_obs", [1, 7]) @@ -148,8 +179,15 @@ def test_trace_smoke(Model, batch_shape, t_obs, obs_dim, cov_dim): data = torch.randn(batch_shape + (t_obs, obs_dim)) covariates = torch.randn(batch_shape + (t_obs, cov_dim)) forecaster = Forecaster(model, data, covariates, num_steps=2, log_every=1) - hmc_forecaster = HMCForecaster(model, data, covariates, max_tree_depth=1, - num_warmup=1, num_samples=1, jit_compile=False) + hmc_forecaster = HMCForecaster( + model, + data, + covariates, + max_tree_depth=1, + num_warmup=1, + num_samples=1, + jit_compile=False, + ) # This is the desired syntax for recording posterior latent samples. num_samples = 5 @@ -191,9 +229,16 @@ def test_svi_custom_smoke(subsample_aware): guide = AutoDelta(model) optim = Adam({}) - Forecaster(model, data, covariates[..., :t_obs, :], - guide=guide, optim=optim, subsample_aware=subsample_aware, - num_steps=2, log_every=1) + Forecaster( + model, + data, + covariates[..., :t_obs, :], + guide=guide, + optim=optim, + subsample_aware=subsample_aware, + num_steps=2, + log_every=1, + ) class SubsampleModel3(ForecastingModel): @@ -263,11 +308,23 @@ def create_plates(zero_data, covariates): size = len(zero_data) return pyro.plate("batch", size, subsample_size=2, dim=-2) - forecaster = Forecaster(model, data, covariates[..., :t_obs, :], - num_steps=2, log_every=1, create_plates=create_plates) + forecaster = Forecaster( + model, + data, + covariates[..., :t_obs, :], + num_steps=2, + log_every=1, + create_plates=create_plates, + ) num_samples = 5 samples = forecaster(data, covariates, num_samples) - assert samples.shape == (num_samples,) + batch_shape + (t_forecast, obs_dim,) + assert samples.shape == (num_samples,) + batch_shape + ( + t_forecast, + obs_dim, + ) samples = forecaster(data, covariates, num_samples, batch_size=2) - assert samples.shape == (num_samples,) + batch_shape + (t_forecast, obs_dim,) + assert samples.shape == (num_samples,) + batch_shape + ( + t_forecast, + obs_dim, + ) diff --git a/tests/contrib/forecast/test_util.py b/tests/contrib/forecast/test_util.py index af629eb2a8..f6f7f6725d 100644 --- a/tests/contrib/forecast/test_util.py +++ b/tests/contrib/forecast/test_util.py @@ -55,10 +55,12 @@ def random_dist(Dist, shape, transform=None): base_dist = random_dist(dist.Normal, shape) transforms = [ dist.transforms.ExpTransform(), - dist.transforms.ComposeTransform([ - dist.transforms.AffineTransform(1, 1), - dist.transforms.ExpTransform().inv, - ]), + dist.transforms.ComposeTransform( + [ + dist.transforms.AffineTransform(1, 1), + dist.transforms.ExpTransform().inv, + ] + ), ] return dist.TransformedDistribution(base_dist, transforms) elif Dist in (dist.GaussianHMM, dist.LinearHMM): @@ -66,13 +68,18 @@ def random_dist(Dist, shape, transform=None): hidden_dim = obs_dim + 1 init_dist = random_dist(dist.Normal, batch_shape + (hidden_dim,)).to_event(1) trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim)) - trans_dist = random_dist(dist.Normal, batch_shape + (duration, hidden_dim)).to_event(1) + trans_dist = random_dist( + dist.Normal, batch_shape + (duration, hidden_dim) + ).to_event(1) obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim)) - obs_dist = random_dist(dist.Normal, batch_shape + 
(duration, obs_dim)).to_event(1) + obs_dist = random_dist(dist.Normal, batch_shape + (duration, obs_dim)).to_event( + 1 + ) if Dist is dist.LinearHMM and transform is not None: obs_dist = dist.TransformedDistribution(obs_dist, transform) - return Dist(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, - duration=duration) + return Dist( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) elif Dist is dist.IndependentHMM: batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1] base_shape = batch_shape + (obs_dim, duration, 1) @@ -87,7 +94,8 @@ def random_dist(Dist, shape, transform=None): else: params = { name: transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5) - for name in UNIVARIATE_DISTS[Dist]} + for name in UNIVARIATE_DISTS[Dist] + } return Dist(**params) @@ -129,8 +137,10 @@ def test_reshape_batch(Dist, batch_shape, duration, dim): @pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)]) @pytest.mark.parametrize("transform", list(UNIVARIATE_TRANSFORMS.keys())) def test_reshape_transform_batch(transform, batch_shape, duration, dim): - params = {p: torch.rand(batch_shape + (duration, dim)) - for p in UNIVARIATE_TRANSFORMS[transform]} + params = { + p: torch.rand(batch_shape + (duration, dim)) + for p in UNIVARIATE_TRANSFORMS[transform] + } t = transform(**params) d = random_dist(dist.LinearHMM, batch_shape + (duration, dim), transform=t) d = d.to_event(2 - d.event_dim) diff --git a/tests/contrib/funsor/test_enum_funsor.py b/tests/contrib/funsor/test_enum_funsor.py index 55de235462..e8dbc60c69 100644 --- a/tests/contrib/funsor/test_enum_funsor.py +++ b/tests/contrib/funsor/test_enum_funsor.py @@ -19,6 +19,7 @@ import funsor import pyro.contrib.funsor + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro @@ -31,13 +32,20 @@ def _check_loss_and_grads(expected_loss, actual_loss): - assert_equal(actual_loss, expected_loss, - msg='Expected:\n{}\nActual:\n{}'.format(expected_loss.detach().cpu().numpy(), - actual_loss.detach().cpu().numpy())) + assert_equal( + actual_loss, + expected_loss, + msg="Expected:\n{}\nActual:\n{}".format( + expected_loss.detach().cpu().numpy(), actual_loss.detach().cpu().numpy() + ), + ) if "TEST_ENUM_PYRO_BACKEND" in os.environ: # only log if we manually set a backend - logging.debug('Expected:\n{}\nActual:\n{}'.format(expected_loss.detach().cpu().numpy(), - actual_loss.detach().cpu().numpy())) + logging.debug( + "Expected:\n{}\nActual:\n{}".format( + expected_loss.detach().cpu().numpy(), actual_loss.detach().cpu().numpy() + ) + ) names = pyro.get_param_store().keys() params = [pyro.param(name).unconstrained() for name in names] @@ -48,10 +56,15 @@ def _check_loss_and_grads(expected_loss, actual_loss): continue assert not torch_isnan(actual_grad) assert not torch_isnan(expected_grad) - assert_equal(actual_grad, expected_grad, - msg='{}\nExpected:\n{}\nActual:\n{}'.format(name, - expected_grad.detach().cpu().numpy(), - actual_grad.detach().cpu().numpy())) + assert_equal( + actual_grad, + expected_grad, + msg="{}\nExpected:\n{}\nActual:\n{}".format( + name, + expected_grad.detach().cpu().numpy(), + actual_grad.detach().cpu().numpy(), + ), + ) @pytest.mark.parametrize("inner_dim", [2]) @@ -61,7 +74,7 @@ def test_elbo_plate_plate(outer_dim, inner_dim): pyro.get_param_store().clear() q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True)) p = 0.2693204236205713 # for which kl(Categorical(q), Categorical(p)) = 0.5 - p = torch.tensor([p, 1-p]) + p = 
torch.tensor([p, 1 - p]) def model(): d = dist.Categorical(p) @@ -88,7 +101,8 @@ def guide(): pyro.sample("z", d, infer={"enumerate": "parallel"}) kl_node = torch.distributions.kl.kl_divergence( - torch.distributions.Categorical(q), torch.distributions.Categorical(p)) + torch.distributions.Categorical(q), torch.distributions.Categorical(p) + ) kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node expected_loss = kl expected_grad = grad(kl, [q.unconstrained()])[0] @@ -101,21 +115,23 @@ def guide(): assert_equal(actual_grad, expected_grad, prec=1e-5) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_1(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([0.3, 0.7]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", torch.tensor([0.3, 0.7]), constraint=constraints.simplex + ) @handlers.scale(scale=scale) def auto_model(): @@ -123,8 +139,7 @@ def auto_model(): probs_y = pyro.param("model_probs_y") probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) - pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + pyro.sample("y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"}) pyro.sample("z", dist.Categorical(probs_z), obs=torch.tensor(0)) @handlers.scale(scale=scale) @@ -146,21 +161,25 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_2(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) @handlers.scale(scale=scale) def auto_model(): @@ -168,8 +187,9 @@ def auto_model(): probs_y = pyro.param("model_probs_y") probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=torch.tensor(0)) @handlers.scale(scale=scale) @@ -193,21 +213,25 @@ def 
guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_3(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(): probs_x = pyro.param("model_probs_x") @@ -215,8 +239,9 @@ def auto_model(): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with handlers.scale(scale=scale): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=torch.tensor(0)) def hand_model(): @@ -239,28 +264,32 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(2, 2), (3, 2)], - ids=["batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_1(num_samples, num_masked, scale): # +---------+ # x ----> y ----> z | # | N | # +---------+ - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(data): probs_x = pyro.param("model_probs_x") @@ -268,8 +297,9 @@ def auto_model(data): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with handlers.scale(scale=scale): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) if num_masked == num_samples: with pyro.plate("data", len(data)): pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) @@ -284,8 +314,9 @@ def hand_model(data): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with handlers.scale(scale=scale): - y = 
pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) for i in pyro.plate("data", num_masked): pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @@ -302,28 +333,32 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(2, 2), (3, 2)], - ids=["batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_2(num_samples, num_masked, scale): # +-----------------+ # x ----> y ----> z | # | N | # +-----------------+ - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(data): probs_x = pyro.param("model_probs_x") @@ -333,13 +368,19 @@ def auto_model(data): with handlers.scale(scale=scale): with pyro.plate("data", len(data)): if num_masked == num_samples: - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) else: with handlers.mask(mask=torch.arange(num_samples) < num_masked): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) def hand_model(data): @@ -349,8 +390,11 @@ def hand_model(data): x = pyro.sample("x", dist.Categorical(probs_x)) with handlers.scale(scale=scale): for i in pyro.plate("data", num_masked): - y = pyro.sample("y_{}".format(i), dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y_{}".format(i), + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @infer.config_enumerate @@ -365,10 +409,10 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(2, 2), (3, 2)], - ids=["batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): # +-----------------------+ @@ -376,18 +420,22 @@ def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): # | N | # +-----------------------+ # This plate should 
remain unreduced since all enumeration is in a single plate. - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) @handlers.scale(scale=scale) def auto_model(data): @@ -397,14 +445,18 @@ def auto_model(data): with pyro.plate("data", len(data)): if num_masked == num_samples: x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) else: with handlers.mask(mask=torch.arange(num_samples) < num_masked): x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) @handlers.scale(scale=scale) @@ -425,8 +477,11 @@ def hand_model(data): probs_z = pyro.param("model_probs_z") for i in pyro.plate("data", num_masked): x = pyro.sample("x_{}".format(i), dist.Categorical(probs_x)) - y = pyro.sample("y_{}".format(i), dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y_{}".format(i), + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @handlers.scale(scale=scale) @@ -443,9 +498,10 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('outer_obs,inner_obs', - [(False, True), (True, False), (True, True)]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "outer_obs,inner_obs", [(False, True), (True, False), (True, True)] +) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): # a ---> outer_obs @@ -457,8 +513,8 @@ def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): # This tests two different observations, one outside and one inside an plate. 
pyro.param("probs_a", torch.tensor([0.4, 0.6]), constraint=constraints.simplex) pyro.param("probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex) - pyro.param("locs", torch.tensor([-1., 1.])) - pyro.param("scales", torch.tensor([1., 2.]), constraint=constraints.positive) + pyro.param("locs", torch.tensor([-1.0, 1.0])) + pyro.param("scales", torch.tensor([1.0, 2.0]), constraint=constraints.positive) outer_data = torch.tensor(2.0) inner_data = torch.tensor([0.5, 1.5]) @@ -468,17 +524,17 @@ def auto_model(): probs_b = pyro.param("probs_b") locs = pyro.param("locs") scales = pyro.param("scales") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) if outer_obs: - pyro.sample("outer_obs", dist.Normal(0., scales[a]), - obs=outer_data) + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) with pyro.plate("inner", 2): - b = pyro.sample("b", dist.Categorical(probs_b), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b), infer={"enumerate": "parallel"} + ) if inner_obs: - pyro.sample("inner_obs", dist.Normal(locs[b], scales[a]), - obs=inner_data) + pyro.sample( + "inner_obs", dist.Normal(locs[b], scales[a]), obs=inner_data + ) @handlers.scale(scale=scale) def hand_model(): @@ -486,17 +542,21 @@ def hand_model(): probs_b = pyro.param("probs_b") locs = pyro.param("locs") scales = pyro.param("scales") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) if outer_obs: - pyro.sample("outer_obs", dist.Normal(0., scales[a]), - obs=outer_data) + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) for i in pyro.plate("inner", 2): - b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b_{}".format(i), + dist.Categorical(probs_b), + infer={"enumerate": "parallel"}, + ) if inner_obs: - pyro.sample("inner_obs_{}".format(i), dist.Normal(locs[b], scales[a]), - obs=inner_data[i]) + pyro.sample( + "inner_obs_{}".format(i), + dist.Normal(locs[b], scales[a]), + obs=inner_data[i], + ) def guide(): pass @@ -517,19 +577,22 @@ def test_elbo_enumerate_plate_5(): # | M=2 V | # | b ----> c | # +------------------+ - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([1, 2]) @infer.config_enumerate @@ -540,8 +603,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("b_axis", 2): b = pyro.sample("b", dist.Categorical(probs_b)) - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, 
b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @infer.config_enumerate def guide_plate(): @@ -557,9 +619,9 @@ def model_iplate(): a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("b_axis", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) - pyro.sample("c_{}".format(i), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i]) + pyro.sample( + "c_{}".format(i), dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i] + ) @infer.config_enumerate def guide_iplate(): @@ -570,13 +632,15 @@ def guide_iplate(): elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) expected_loss = elbo.differentiable_loss(model_iplate, guide_iplate) elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) - with pytest.raises(ValueError, match="Expected model enumeration to be no more global than guide"): + with pytest.raises( + ValueError, match="Expected model enumeration to be no more global than guide" + ): actual_loss = elbo.differentiable_loss(model_plate, guide_plate) # This never gets run because we don't support this yet. _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.parametrize('enumerate1', ['parallel', 'sequential']) +@pytest.mark.parametrize("enumerate1", ["parallel", "sequential"]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_6(enumerate1): # Guide Model @@ -586,19 +650,22 @@ def test_elbo_enumerate_plate_6(enumerate1): # +-------+ # This tests that sequential enumeration over b works, even though # model-side enumeration moves c into b's plate via contraction. - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([1, 2]) @infer.config_enumerate @@ -609,8 +676,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample("b", dist.Categorical(probs_b)) with pyro.plate("b_axis", 2): - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @infer.config_enumerate def model_iplate(): @@ -620,9 +686,9 @@ def model_iplate(): a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample("b", dist.Categorical(probs_b)) for i in pyro.plate("b_axis", 2): - pyro.sample("c_{}".format(i), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i]) + pyro.sample( + "c_{}".format(i), dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i] + ) @infer.config_enumerate(default=enumerate1) def guide(): @@ -636,7 +702,7 @@ def guide(): _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plate_7(scale): # Guide Model @@ -647,27 +713,37 @@ 
def test_elbo_enumerate_plate_7(scale): # | c -----> d -----> e N=2 | # +---------------------------+ # This tests a mixture of model and guide enumeration. - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_d", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), - constraint=constraints.simplex) - pyro.param("model_probs_e", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.35, 0.64]), - constraint=constraints.simplex) - pyro.param("guide_probs_c", - torch.tensor([[0., 1.], [1., 0.]]), # deterministic - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_e", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.35, 0.64]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_c", + torch.tensor([[0.0, 1.0], [1.0, 0.0]]), # deterministic + constraint=constraints.simplex, + ) @handlers.scale(scale=scale) def auto_model(data): @@ -677,20 +753,23 @@ def auto_model(data): probs_d = pyro.param("model_probs_d") probs_e = pyro.param("model_probs_e") a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample("b", dist.Categorical(probs_b[a]), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) with pyro.plate("data", 2): c = pyro.sample("c", dist.Categorical(probs_c[a])) - d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}) + d = pyro.sample( + "d", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) @handlers.scale(scale=scale) def auto_guide(data): probs_a = pyro.param("guide_probs_a") probs_c = pyro.param("guide_probs_c") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) with pyro.plate("data", 2): pyro.sample("c", dist.Categorical(probs_c[a])) @@ -702,21 +781,23 @@ def hand_model(data): probs_d = pyro.param("model_probs_d") probs_e = pyro.param("model_probs_e") a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample("b", dist.Categorical(probs_b[a]), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) for i in pyro.plate("data", 2): c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) - d = pyro.sample("d_{}".format(i), - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}) + d = pyro.sample( + "d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) 
pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]), obs=data[i]) @handlers.scale(scale=scale) def hand_guide(data): probs_a = pyro.param("guide_probs_a") probs_c = pyro.param("guide_probs_c") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) for i in pyro.plate("data", 2): pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) @@ -728,7 +809,7 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_1(scale): # +-----------------+ @@ -739,18 +820,18 @@ def test_elbo_enumerate_plates_1(scale): # +-----------------+ # This tests two unrelated plates. # Each should remain uncontracted. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([0.75, 0.25]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param("probs_c", torch.tensor([0.75, 0.25]), constraint=constraints.simplex) + pyro.param( + "probs_d", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) b_data = torch.tensor([0, 1]) d_data = torch.tensor([0, 0, 1]) @@ -792,7 +873,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_2(scale): # +---------+ +---------+ @@ -800,15 +881,17 @@ def test_elbo_enumerate_plates_2(scale): # | M=2 | | N=3 | # +---------+ +---------+ # This tests two different plates with recycled dimension. 
- pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) b_data = torch.tensor([0, 1]) c_data = torch.tensor([0, 0, 1]) @@ -820,11 +903,9 @@ def auto_model(): probs_c = pyro.param("probs_c") a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("b_axis", 2): - pyro.sample("b", dist.Categorical(probs_b[a]), - obs=b_data) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data) with pyro.plate("c_axis", 3): - pyro.sample("c", dist.Categorical(probs_c[a]), - obs=c_data) + pyro.sample("c", dist.Categorical(probs_c[a]), obs=c_data) @infer.config_enumerate @handlers.scale(scale=scale) @@ -834,11 +915,9 @@ def hand_model(): probs_c = pyro.param("probs_c") a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("b_axis", 2): - pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]), - obs=b_data[i]) + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]), obs=b_data[i]) for j in pyro.plate("c_axis", 3): - pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a]), - obs=c_data[j]) + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a]), obs=c_data[j]) def guide(): pass @@ -850,7 +929,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_3(scale): # +--------------------+ @@ -861,12 +940,12 @@ def test_elbo_enumerate_plates_3(scale): # +--------------------+ # This is tests the case of multiple plate contractions in # a single step. 
- pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1], [0, 0]]) @infer.config_enumerate @@ -877,8 +956,7 @@ def auto_model(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("outer", 2): with pyro.plate("inner", 2): - pyro.sample("b", dist.Categorical(probs_b[a]), - obs=data) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=data) @infer.config_enumerate @handlers.scale(scale=scale) @@ -889,8 +967,9 @@ def hand_model(): a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("outer", 2): for j in inner: - pyro.sample("b_{}_{}".format(i, j), dist.Categorical(probs_b[a]), - obs=data[i, j]) + pyro.sample( + "b_{}_{}".format(i, j), dist.Categorical(probs_b[a]), obs=data[i, j] + ) def guide(): pass @@ -902,7 +981,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_4(scale): # +--------------------+ @@ -911,15 +990,17 @@ def test_elbo_enumerate_plates_4(scale): # | | N=2 | | # | M=2 +----------+ | # +--------------------+ - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -944,8 +1025,9 @@ def hand_model(data): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for j in inner: - pyro.sample("c_{}_{}".format(i, j), dist.Categorical(probs_c[b]), - obs=data[i, j]) + pyro.sample( + "c_{}_{}".format(i, j), dist.Categorical(probs_c[b]), obs=data[i, j] + ) def guide(data): pass @@ -958,7 +1040,7 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_5(scale): # a @@ -969,16 +1051,17 @@ def test_elbo_enumerate_plates_5(scale): # | | N=2 | | # | M=2 +----------+ | # +-------------------+ - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], - [[0.2, 0.8], [0.1, 0.9]]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.2, 0.8], [0.1, 0.9]]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1], [0, 0]]) 
@infer.config_enumerate @@ -991,8 +1074,7 @@ def auto_model(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b[a])) with pyro.plate("inner", 2): - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1005,9 +1087,11 @@ def hand_model(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for j in inner: - pyro.sample("c_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i, j]) + pyro.sample( + "c_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[i, j], + ) def guide(): pass @@ -1019,7 +1103,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_6(scale): # +----------+ @@ -1034,18 +1118,22 @@ def test_elbo_enumerate_plates_6(scale): # +-------------+ # This tests different ways of mixing two independence contexts, # where each can be either sequential or vectorized plate. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + constraint=constraints.simplex, + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1058,13 +1146,19 @@ def model_iplate_iplate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] for i in b_axis: for j in c_axis: - pyro.sample("d_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_d)[b[i], c[j]]), - obs=data[i, j]) + pyro.sample( + "d_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_d)[b[i], c[j]]), + obs=data[i, j], + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1081,9 +1175,11 @@ def model_iplate_plate(data): for i in b_axis: b_i = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) with c_axis: - pyro.sample("d_{}".format(i), - dist.Categorical(Vindex(probs_d)[b_i, c]), - obs=data[i]) + pyro.sample( + "d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b_i, c]), + obs=data[i], + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1098,12 +1194,16 @@ def model_plate_iplate(data): a = pyro.sample("a", dist.Categorical(probs_a)) with b_axis: b = pyro.sample("b", dist.Categorical(probs_b[a])) - c = [pyro.sample("c_{}".format(j), 
dist.Categorical(probs_c[a])) for j in c_axis] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] with b_axis: for j in c_axis: - pyro.sample("d_{}".format(j), - dist.Categorical(Vindex(probs_d)[b, c[j]]), - obs=data[:, j]) + pyro.sample( + "d_{}".format(j), + dist.Categorical(Vindex(probs_d)[b, c[j]]), + obs=data[:, j], + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1120,9 +1220,7 @@ def model_plate_plate(data): with c_axis: c = pyro.sample("c", dist.Categorical(probs_c[a])) with b_axis, c_axis: - pyro.sample("d", - dist.Categorical(Vindex(probs_d)[b, c]), - obs=data) + pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), obs=data) def guide(data): pass @@ -1143,7 +1241,7 @@ def guide(data): elbo.differentiable_loss(model_plate_plate, guide, data) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_elbo_enumerate_plates_7(scale): # +-------------+ @@ -1159,21 +1257,27 @@ def test_elbo_enumerate_plates_7(scale): # +----------------+ # This tests tree-structured dependencies among variables but # non-tree dependencies among plate nestings. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) - pyro.param("probs_e", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_d", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_e", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1187,14 +1291,24 @@ def model_iplate_iplate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] for i in b_axis: for j in c_axis: - pyro.sample("d_{}_{}".format(i, j), dist.Categorical(probs_d[b[i]]), - obs=data[i, j]) - pyro.sample("e_{}_{}".format(i, j), dist.Categorical(probs_e[c[j]]), - obs=data[i, j]) + pyro.sample( + "d_{}_{}".format(i, j), + dist.Categorical(probs_d[b[i]]), + obs=data[i, j], + ) + pyro.sample( + "e_{}_{}".format(i, j), + dist.Categorical(probs_e[c[j]]), + obs=data[i, j], + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1212,10 +1326,10 @@ def model_iplate_plate(data): for i in b_axis: b_i = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) with c_axis: - pyro.sample("d_{}".format(i), dist.Categorical(probs_d[b_i]), - obs=data[i]) - pyro.sample("e_{}".format(i), dist.Categorical(probs_e[c]), - obs=data[i]) + pyro.sample( 
+ "d_{}".format(i), dist.Categorical(probs_d[b_i]), obs=data[i] + ) + pyro.sample("e_{}".format(i), dist.Categorical(probs_e[c]), obs=data[i]) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1231,13 +1345,17 @@ def model_plate_iplate(data): a = pyro.sample("a", dist.Categorical(probs_a)) with b_axis: b = pyro.sample("b", dist.Categorical(probs_b[a])) - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] with b_axis: for j in c_axis: - pyro.sample("d_{}".format(j), dist.Categorical(probs_d[b]), - obs=data[:, j]) - pyro.sample("e_{}".format(j), dist.Categorical(probs_e[c[j]]), - obs=data[:, j]) + pyro.sample( + "d_{}".format(j), dist.Categorical(probs_d[b]), obs=data[:, j] + ) + pyro.sample( + "e_{}".format(j), dist.Categorical(probs_e[c[j]]), obs=data[:, j] + ) @infer.config_enumerate @handlers.scale(scale=scale) @@ -1275,12 +1393,17 @@ def guide(data): _check_loss_and_grads(loss_iplate_iplate, loss_plate_plate) -@pytest.mark.parametrize('guide_scale', [1]) -@pytest.mark.parametrize('model_scale', [1]) -@pytest.mark.parametrize('outer_vectorized', [False, xfail_param(True, reason="validation not yet implemented")]) -@pytest.mark.parametrize('inner_vectorized', [False, True]) +@pytest.mark.parametrize("guide_scale", [1]) +@pytest.mark.parametrize("model_scale", [1]) +@pytest.mark.parametrize( + "outer_vectorized", + [False, xfail_param(True, reason="validation not yet implemented")], +) +@pytest.mark.parametrize("inner_vectorized", [False, True]) @pyroapi.pyro_backend(_PYRO_BACKEND) -def test_elbo_enumerate_plates_8(model_scale, guide_scale, inner_vectorized, outer_vectorized): +def test_elbo_enumerate_plates_8( + model_scale, guide_scale, inner_vectorized, outer_vectorized +): # Guide Model # a # +-----------|--------+ @@ -1289,19 +1412,22 @@ def test_elbo_enumerate_plates_8(model_scale, guide_scale, inner_vectorized, out # | b ----> c | | # | +----------+ | # +--------------------+ - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([[0, 1], [0, 2]]) @infer.config_enumerate @@ -1314,9 +1440,7 @@ def model_plate_plate(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b)) with pyro.plate("inner", 2): - pyro.sample("c", - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @infer.config_enumerate @handlers.scale(scale=model_scale) @@ -1329,9 +1453,11 @@ def model_iplate_plate(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) with inner: - pyro.sample("c_{}".format(i), - 
dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[:, i]) + pyro.sample( + "c_{}".format(i), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[:, i], + ) @infer.config_enumerate @handlers.scale(scale=model_scale) @@ -1343,9 +1469,11 @@ def model_plate_iplate(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b)) for j in pyro.plate("inner", 2): - pyro.sample("c_{}".format(j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[j]) + pyro.sample( + "c_{}".format(j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j], + ) @infer.config_enumerate @handlers.scale(scale=model_scale) @@ -1358,9 +1486,11 @@ def model_iplate_iplate(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) for j in inner: - pyro.sample("c_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[j, i]) + pyro.sample( + "c_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j, i], + ) @infer.config_enumerate @handlers.scale(scale=guide_scale) @@ -1403,21 +1533,27 @@ def test_elbo_enumerate_plate_9(): # | M=2 V | # | b -> c | # +---------------+ - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.3, 0.7], [0.6, 0.4]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([[0.3, 0.7], [0.8, 0.2]]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.3, 0.7], [0.6, 0.4]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_b", + torch.tensor([[0.3, 0.7], [0.8, 0.2]]), + constraint=constraints.simplex, + ) data = torch.tensor([1, 2]) @infer.config_enumerate @@ -1428,8 +1564,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("b_axis", 2): b = pyro.sample("b", dist.Categorical(probs_b[a])) - pyro.sample("c", dist.Categorical(probs_c[b]), - obs=data) + pyro.sample("c", dist.Categorical(probs_c[b]), obs=data) @infer.config_enumerate def guide_plate(): @@ -1447,8 +1582,7 @@ def model_iplate(): a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("b_axis", 2): b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) - pyro.sample(f"c_{i}", dist.Categorical(probs_c[b]), - obs=data[i]) + pyro.sample(f"c_{i}", dist.Categorical(probs_c[b]), obs=data[i]) @infer.config_enumerate def guide_iplate(): @@ -1471,21 +1605,27 @@ def test_elbo_enumerate_plate_10(): # a -> [ [ bij -> cij ] ] # Guide # a -> [ [ bij ] ] - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.3, 0.7], [0.6, 0.4]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([[0.3, 0.7], [0.8, 0.2]]), - 
constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.3, 0.7], [0.6, 0.4]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_b", + torch.tensor([[0.3, 0.7], [0.8, 0.2]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1, 2], [1, 2, 2]]) @infer.config_enumerate @@ -1497,8 +1637,7 @@ def model_plate(): with pyro.plate("i", 2, dim=-2): with pyro.plate("j", 3, dim=-1): b = pyro.sample("b", dist.Categorical(probs_b[a])) - pyro.sample("c", dist.Categorical(probs_c[b]), - obs=data) + pyro.sample("c", dist.Categorical(probs_c[b]), obs=data) @infer.config_enumerate def guide_plate(): @@ -1518,8 +1657,7 @@ def model_iplate(): for i in pyro.plate("i", 2): for j in pyro.plate("j", 3): b = pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a])) - pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), - obs=data[i, j]) + pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), obs=data[i, j]) @infer.config_enumerate def guide_iplate(): @@ -1543,21 +1681,27 @@ def test_elbo_enumerate_plate_11(): # [ ai -> [ bij -> cij ] ] # Guide # [ ai -> [ bij ] ] - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.3, 0.7], [0.6, 0.4]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([[0.3, 0.7], [0.8, 0.2]]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.3, 0.7], [0.6, 0.4]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_b", + torch.tensor([[0.3, 0.7], [0.8, 0.2]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1, 2], [1, 2, 2]]) @infer.config_enumerate @@ -1569,8 +1713,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("j", 3, dim=-1): b = pyro.sample("b", dist.Categorical(probs_b[a])) - pyro.sample("c", dist.Categorical(probs_c[b]), - obs=data) + pyro.sample("c", dist.Categorical(probs_c[b]), obs=data) @infer.config_enumerate def guide_plate(): @@ -1590,8 +1733,7 @@ def model_iplate(): a = pyro.sample(f"a_{i}", dist.Categorical(probs_a)) for j in pyro.plate("j", 3): b = pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a])) - pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), - obs=data[i, j]) + pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), obs=data[i, j]) @infer.config_enumerate def guide_iplate(): @@ -1615,27 +1757,37 @@ def test_elbo_enumerate_plate_12(): # a -> [ bi -> [ cij -> dij ] ] # Guide # a -> [ bi -> [ cij ] ] - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.3, 0.7], [0.6, 
0.4]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("model_probs_d", - torch.tensor([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([[0.3, 0.7], [0.8, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_c", - torch.tensor([[0.3, 0.3, 0.4], [0.2, 0.4, 0.4]]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.3, 0.7], [0.6, 0.4]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_d", + torch.tensor([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_b", + torch.tensor([[0.3, 0.7], [0.8, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_c", + torch.tensor([[0.3, 0.3, 0.4], [0.2, 0.4, 0.4]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1, 2], [1, 2, 2]]) @infer.config_enumerate @@ -1649,8 +1801,7 @@ def model_plate(): b = pyro.sample("b", dist.Categorical(probs_b[a])) with pyro.plate("j", 3, dim=-1): c = pyro.sample("c", dist.Categorical(probs_c[b])) - pyro.sample("d", dist.Categorical(probs_d[c]), - obs=data) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=data) @infer.config_enumerate def guide_plate(): @@ -1674,8 +1825,7 @@ def model_iplate(): b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) for j in pyro.plate("j", 3): c = pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b])) - pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), - obs=data[i, j]) + pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), obs=data[i, j]) @infer.config_enumerate def guide_iplate(): @@ -1707,27 +1857,37 @@ def test_elbo_enumerate_plate_13(): # | # v # [ bi ] - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.3, 0.7], [0.6, 0.4]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.3, 0.7], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("model_probs_d", - torch.tensor([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([[0.3, 0.7], [0.8, 0.2]]), - constraint=constraints.simplex) - pyro.param("guide_probs_c", - torch.tensor([[0.2, 0.8], [0.9, 0.1]]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.3, 0.7], [0.6, 0.4]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.3, 0.7], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_d", + torch.tensor([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.45, 0.55]), 
constraint=constraints.simplex + ) + pyro.param( + "guide_probs_b", + torch.tensor([[0.3, 0.7], [0.8, 0.2]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_c", + torch.tensor([[0.2, 0.8], [0.9, 0.1]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1, 2], [1, 2, 2]]) @infer.config_enumerate @@ -1741,8 +1901,7 @@ def model_plate(): pyro.sample("b", dist.Categorical(probs_b[a])) with pyro.plate("j", 3, dim=-1): c = pyro.sample("c", dist.Categorical(probs_c[a])) - pyro.sample("d", dist.Categorical(probs_d[c]), - obs=data) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=data) @infer.config_enumerate def guide_plate(): @@ -1766,8 +1925,7 @@ def model_iplate(): pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) for j in pyro.plate("j", 3): c = pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[a])) - pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), - obs=data[i, j]) + pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), obs=data[i, j]) @infer.config_enumerate def guide_iplate(): diff --git a/tests/contrib/funsor/test_infer_discrete.py b/tests/contrib/funsor/test_infer_discrete.py index a3e90cbcaf..bdfcf71f7c 100644 --- a/tests/contrib/funsor/test_infer_discrete.py +++ b/tests/contrib/funsor/test_infer_discrete.py @@ -16,6 +16,7 @@ import funsor import pyro.contrib.funsor + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro @@ -27,8 +28,8 @@ _PYRO_BACKEND = os.environ.get("TEST_ENUM_PYRO_BACKEND", "contrib.funsor") -@pytest.mark.parametrize('length', [1, 2, 10, 100]) -@pytest.mark.parametrize('temperature', [0, 1]) +@pytest.mark.parametrize("length", [1, 2, 10, 100]) +@pytest.mark.parametrize("temperature", [0, 1]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_hmm_smoke(length, temperature): @@ -38,11 +39,14 @@ def hmm(data, hidden_dim=10): means = torch.arange(float(hidden_dim)) states = [0] for t in pyro.markov(range(len(data))): - states.append(pyro.sample("states_{}".format(t), - dist.Categorical(transition[states[-1]]))) - data[t] = pyro.sample("obs_{}".format(t), - dist.Normal(means[states[-1]], 1.), - obs=data[t]) + states.append( + pyro.sample( + "states_{}".format(t), dist.Categorical(transition[states[-1]]) + ) + ) + data[t] = pyro.sample( + "obs_{}".format(t), dist.Normal(means[states[-1]], 1.0), obs=data[t] + ) return states, data true_states, data = hmm([None] * length) @@ -64,36 +68,45 @@ def test_distribution_1(temperature): # z --|--> x | # +-------+ num_particles = 10000 - data = torch.tensor([1., 2., 3.]) + data = torch.tensor([1.0, 2.0, 3.0]) @infer.config_enumerate def model(z=None): p = pyro.param("p", torch.tensor([0.75, 0.25])) iz = pyro.sample("z", dist.Categorical(p), obs=z) - z = torch.tensor([0., 1.])[iz] + z = torch.tensor([0.0, 1.0])[iz] logger.info("z.shape = {}".format(z.shape)) with pyro.plate("data", 3): - pyro.sample("x", dist.Normal(z, 1.), obs=data) + pyro.sample("x", dist.Normal(z, 1.0), obs=data) first_available_dim = -3 - vectorized_model = model if temperature == 0 else \ - pyro.plate("particles", size=num_particles, dim=-2)(model) + vectorized_model = ( + model + if temperature == 0 + else pyro.plate("particles", size=num_particles, dim=-2)(model) + ) sampled_model = infer.infer_discrete( - vectorized_model, - first_available_dim, - temperature + vectorized_model, first_available_dim, temperature ) sampled_trace = handlers.trace(sampled_model).get_trace() - conditioned_traces = {z: handlers.trace(model).get_trace(z=torch.tensor(z).long()) for z 
in [0., 1.]} + conditioned_traces = { + z: handlers.trace(model).get_trace(z=torch.tensor(z).long()) for z in [0.0, 1.0] + } # Check posterior over z. actual_z_mean = sampled_trace.nodes["z"]["value"].float().mean() if temperature: - expected_z_mean = 1 / (1 + (conditioned_traces[0].log_prob_sum() - - conditioned_traces[1].log_prob_sum()).exp()) + expected_z_mean = 1 / ( + 1 + + ( + conditioned_traces[0].log_prob_sum() + - conditioned_traces[1].log_prob_sum() + ).exp() + ) else: - expected_z_mean = (conditioned_traces[1].log_prob_sum() > - conditioned_traces[0].log_prob_sum()).float() + expected_z_mean = ( + conditioned_traces[1].log_prob_sum() > conditioned_traces[0].log_prob_sum() + ).float() expected_max = max(t.log_prob_sum() for t in conditioned_traces.values()) actual_max = sampled_trace.log_prob_sum() assert_equal(expected_max, actual_max, prec=1e-5) @@ -110,40 +123,51 @@ def test_distribution_2(temperature): # z2 --|--> x2 | # +--------+ num_particles = 10000 - data = torch.tensor([[-1., -1., 0.], [-1., 1., 1.]]) + data = torch.tensor([[-1.0, -1.0, 0.0], [-1.0, 1.0, 1.0]]) @infer.config_enumerate def model(z1=None, z2=None): p = pyro.param("p", torch.tensor([[0.25, 0.75], [0.1, 0.9]])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) z1 = pyro.sample("z1", dist.Categorical(p[0]), obs=z1) z2 = pyro.sample("z2", dist.Categorical(p[z1]), obs=z2) logger.info("z1.shape = {}".format(z1.shape)) logger.info("z2.shape = {}".format(z2.shape)) with pyro.plate("data", 3): - pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0]) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1]) + pyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1]) first_available_dim = -3 - vectorized_model = model if temperature == 0 else \ - pyro.plate("particles", size=num_particles, dim=-2)(model) + vectorized_model = ( + model + if temperature == 0 + else pyro.plate("particles", size=num_particles, dim=-2)(model) + ) sampled_model = infer.infer_discrete( - vectorized_model, - first_available_dim, - temperature + vectorized_model, first_available_dim, temperature ) sampled_trace = handlers.trace(sampled_model).get_trace() - conditioned_traces = {(z1, z2): handlers.trace(model).get_trace(z1=torch.tensor(z1), - z2=torch.tensor(z2)) - for z1 in [0, 1] for z2 in [0, 1]} + conditioned_traces = { + (z1, z2): handlers.trace(model).get_trace( + z1=torch.tensor(z1), z2=torch.tensor(z2) + ) + for z1 in [0, 1] + for z2 in [0, 1] + } # Check joint posterior over (z1, z2). 
actual_probs = torch.empty(2, 2) expected_probs = torch.empty(2, 2) for (z1, z2), tr in conditioned_traces.items(): expected_probs[z1, z2] = tr.log_prob_sum().exp() - actual_probs[z1, z2] = ((sampled_trace.nodes["z1"]["value"] == z1) & - (sampled_trace.nodes["z2"]["value"] == z2)).float().mean() + actual_probs[z1, z2] = ( + ( + (sampled_trace.nodes["z1"]["value"] == z1) + & (sampled_trace.nodes["z2"]["value"] == z2) + ) + .float() + .mean() + ) if temperature: expected_probs = expected_probs / expected_probs.sum() @@ -164,35 +188,45 @@ def test_distribution_3_simple(temperature): # | 2 | # +---------------+ num_particles = 10000 - data = torch.tensor([-1., 1.]) + data = torch.tensor([-1.0, 1.0]) @infer.config_enumerate def model(z2=None): p = pyro.param("p", torch.tensor([0.25, 0.75])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) with pyro.plate("data", 2): z2 = pyro.sample("z2", dist.Categorical(p), obs=z2) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data) first_available_dim = -3 - vectorized_model = model if temperature == 0 else \ - pyro.plate("particles", size=num_particles, dim=-2)(model) + vectorized_model = ( + model + if temperature == 0 + else pyro.plate("particles", size=num_particles, dim=-2)(model) + ) sampled_model = infer.infer_discrete( - vectorized_model, - first_available_dim, - temperature + vectorized_model, first_available_dim, temperature ) sampled_trace = handlers.trace(sampled_model).get_trace() - conditioned_traces = {(z20, z21): handlers.trace(model).get_trace(z2=torch.tensor([z20, z21])) - for z20 in [0, 1] for z21 in [0, 1]} + conditioned_traces = { + (z20, z21): handlers.trace(model).get_trace(z2=torch.tensor([z20, z21])) + for z20 in [0, 1] + for z21 in [0, 1] + } # Check joint posterior over (z2[0], z2[1]). 
actual_probs = torch.empty(2, 2) expected_probs = torch.empty(2, 2) for (z20, z21), tr in conditioned_traces.items(): expected_probs[z20, z21] = tr.log_prob_sum().exp() - actual_probs[z20, z21] = ((sampled_trace.nodes["z2"]["value"][..., :1] == z20) & - (sampled_trace.nodes["z2"]["value"][..., 1:] == z21)).float().mean() + actual_probs[z20, z21] = ( + ( + (sampled_trace.nodes["z2"]["value"][..., :1] == z20) + & (sampled_trace.nodes["z2"]["value"][..., 1:] == z21) + ) + .float() + .mean() + ) if temperature: expected_probs = expected_probs / expected_probs.sum() else: @@ -212,40 +246,52 @@ def test_distribution_3(temperature): # | 3 | | 2 | # +---------+ +---------------+ num_particles = 10000 - data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])] + data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])] @infer.config_enumerate def model(z1=None, z2=None): p = pyro.param("p", torch.tensor([0.25, 0.75])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) z1 = pyro.sample("z1", dist.Categorical(p), obs=z1) with pyro.plate("data[0]", 3): - pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0]) + pyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) with pyro.plate("data[1]", 2): z2 = pyro.sample("z2", dist.Categorical(p), obs=z2) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1]) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1]) first_available_dim = -3 - vectorized_model = model if temperature == 0 else \ - pyro.plate("particles", size=num_particles, dim=-2)(model) + vectorized_model = ( + model + if temperature == 0 + else pyro.plate("particles", size=num_particles, dim=-2)(model) + ) sampled_model = infer.infer_discrete( - vectorized_model, - first_available_dim, - temperature + vectorized_model, first_available_dim, temperature ) sampled_trace = handlers.trace(sampled_model).get_trace() - conditioned_traces = {(z1, z20, z21): handlers.trace(model).get_trace(z1=torch.tensor(z1), - z2=torch.tensor([z20, z21])) - for z1 in [0, 1] for z20 in [0, 1] for z21 in [0, 1]} + conditioned_traces = { + (z1, z20, z21): handlers.trace(model).get_trace( + z1=torch.tensor(z1), z2=torch.tensor([z20, z21]) + ) + for z1 in [0, 1] + for z20 in [0, 1] + for z21 in [0, 1] + } # Check joint posterior over (z1, z2[0], z2[1]). actual_probs = torch.empty(2, 2, 2) expected_probs = torch.empty(2, 2, 2) for (z1, z20, z21), tr in conditioned_traces.items(): expected_probs[z1, z20, z21] = tr.log_prob_sum().exp() - actual_probs[z1, z20, z21] = ((sampled_trace.nodes["z1"]["value"] == z1) & - (sampled_trace.nodes["z2"]["value"][..., :1] == z20) & - (sampled_trace.nodes["z2"]["value"][..., 1:] == z21)).float().mean() + actual_probs[z1, z20, z21] = ( + ( + (sampled_trace.nodes["z1"]["value"] == z1) + & (sampled_trace.nodes["z2"]["value"][..., :1] == z20) + & (sampled_trace.nodes["z2"]["value"][..., 1:] == z21) + ) + .float() + .mean() + ) if temperature: expected_probs = expected_probs / expected_probs.sum() else: @@ -264,7 +310,7 @@ def model_zzxx(): # z1 --|--> x1 | | z2 ---> x2 | # | 3 | | 2 | # +---------+ +---------------+ - data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])] + data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])] p = pyro.param("p", torch.tensor([0.25, 0.75])) loc = pyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1)) # FIXME results in infinite loop in transformeddist_to_funsor. 
@@ -280,13 +326,13 @@ def model_zzxx(): def model2(): - data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])] + data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])] p = pyro.param("p", torch.tensor([0.25, 0.75])) loc = pyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1)) # FIXME results in infinite loop in transformeddist_to_funsor. # scale = pyro.sample("scale", dist.LogNormal(0, 1)) z1 = pyro.sample("z1", dist.Categorical(p)) - scale = pyro.sample("scale", dist.Normal(torch.tensor([0., 1.])[z1], 1)).exp() + scale = pyro.sample("scale", dist.Normal(torch.tensor([0.0, 1.0])[z1], 1)).exp() with pyro.plate("data[0]", 3): pyro.sample("x1", dist.Normal(loc[z1], scale), obs=data[0]) with pyro.plate("data[1]", 2): @@ -300,12 +346,17 @@ def model2(): def test_svi_model_side_enumeration(model, temperature): # Perform fake inference. # This has the wrong distribution but the right type for tests. - guide = AutoNormal(handlers.enum(handlers.block(infer.config_enumerate(model), expose=["loc", "scale"]))) + guide = AutoNormal( + handlers.enum( + handlers.block(infer.config_enumerate(model), expose=["loc", "scale"]) + ) + ) guide() # Initialize but don't bother to train. guide_trace = handlers.trace(guide).get_trace() guide_data = { name: site["value"] - for name, site in guide_trace.nodes.items() if site["type"] == "sample" + for name, site in guide_trace.nodes.items() + if site["type"] == "sample" } # MAP estimate discretes, conditioned on posterior sampled continous latents. @@ -314,7 +365,7 @@ def test_svi_model_side_enumeration(model, temperature): # TODO support replayed sites in infer_discrete. # handlers.replay(infer.config_enumerate(model), guide_trace) handlers.condition(infer.config_enumerate(model), guide_data), - temperature=temperature + temperature=temperature, ) ).get_trace() @@ -333,13 +384,13 @@ def test_mcmc_model_side_enumeration(model, temperature): # This has the wrong distribution but the right type for tests. mcmc_trace = handlers.trace( handlers.block( - handlers.enum(infer.config_enumerate(model)), - expose=["loc", "scale"] + handlers.enum(infer.config_enumerate(model)), expose=["loc", "scale"] ) ).get_trace() mcmc_data = { name: site["value"] - for name, site in mcmc_trace.nodes.items() if site["type"] == "sample" + for name, site in mcmc_trace.nodes.items() + if site["type"] == "sample" } # MAP estimate discretes, conditioned on posterior sampled continous latents. @@ -348,7 +399,7 @@ def test_mcmc_model_side_enumeration(model, temperature): # TODO support replayed sites in infer_discrete. 
# handlers.replay(infer.config_enumerate(model), mcmc_trace), handlers.condition(infer.config_enumerate(model), mcmc_data), - temperature=temperature + temperature=temperature, ), ).get_trace() @@ -358,14 +409,14 @@ def test_mcmc_model_side_enumeration(model, temperature): assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs -@pytest.mark.parametrize('temperature', [0, 1]) +@pytest.mark.parametrize("temperature", [0, 1]) @pyroapi.pyro_backend(_PYRO_BACKEND) def test_distribution_masked(temperature): # +-------+ # z --|--> x | # +-------+ num_particles = 10000 - data = torch.tensor([1., 2., 3.]) + data = torch.tensor([1.0, 2.0, 3.0]) mask = torch.tensor([True, False, False]) @infer.config_enumerate @@ -374,25 +425,34 @@ def model(z=None): z = pyro.sample("z", dist.Categorical(p), obs=z) logger.info("z.shape = {}".format(z.shape)) with pyro.plate("data", 3), handlers.mask(mask=mask): - pyro.sample("x", dist.Normal(z.type_as(data), 1.), obs=data) + pyro.sample("x", dist.Normal(z.type_as(data), 1.0), obs=data) first_available_dim = -3 - vectorized_model = model if temperature == 0 else \ - pyro.plate("particles", size=num_particles, dim=-2)(model) + vectorized_model = ( + model + if temperature == 0 + else pyro.plate("particles", size=num_particles, dim=-2)(model) + ) sampled_model = infer.infer_discrete( - vectorized_model, - first_available_dim, - temperature + vectorized_model, first_available_dim, temperature ) sampled_trace = handlers.trace(sampled_model).get_trace() - conditioned_traces = {z: handlers.trace(model).get_trace(z=torch.tensor(z)) for z in [0., 1.]} + conditioned_traces = { + z: handlers.trace(model).get_trace(z=torch.tensor(z)) for z in [0.0, 1.0] + } # Check posterior over z. actual_z_mean = sampled_trace.nodes["z"]["value"].type_as(data).mean() if temperature: - expected_z_mean = 1 / (1 + (conditioned_traces[0].log_prob_sum() - - conditioned_traces[1].log_prob_sum()).exp()) + expected_z_mean = 1 / ( + 1 + + ( + conditioned_traces[0].log_prob_sum() + - conditioned_traces[1].log_prob_sum() + ).exp() + ) else: - expected_z_mean = (conditioned_traces[1].log_prob_sum() > - conditioned_traces[0].log_prob_sum()).float() + expected_z_mean = ( + conditioned_traces[1].log_prob_sum() > conditioned_traces[0].log_prob_sum() + ).float() assert_equal(actual_z_mean, expected_z_mean, prec=1e-2) diff --git a/tests/contrib/funsor/test_named_handlers.py b/tests/contrib/funsor/test_named_handlers.py index c4c57b7bd5..ea447b59d2 100644 --- a/tests/contrib/funsor/test_named_handlers.py +++ b/tests/contrib/funsor/test_named_handlers.py @@ -14,6 +14,7 @@ import pyro.contrib.funsor from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger + funsor.set_backend("torch") from pyroapi import pyro, pyro_backend except ImportError: @@ -24,11 +25,14 @@ def test_iteration(): - def testing(): for i in pyro.markov(range(5)): - v1 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(i), funsor.Bint[2])]), 'real')) - v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real')) + v1 = pyro.to_data( + Tensor(torch.ones(2), OrderedDict([(str(i), funsor.Bint[2])]), "real") + ) + v2 = pyro.to_data( + Tensor(torch.zeros(2), OrderedDict([("a", funsor.Bint[2])]), "real") + ) fv1 = pyro.to_funsor(v1, funsor.Real) fv2 = pyro.to_funsor(v2, funsor.Real) print(i, v1.shape) # shapes should alternate @@ -38,34 +42,51 @@ def testing(): assert v1.shape == (2, 1, 1) assert v2.shape == (2, 1) print(i, fv1.inputs) - print('a', v2.shape) # shapes should stay the same - 
print('a', fv2.inputs) + print("a", v2.shape) # shapes should stay the same + print("a", fv2.inputs) with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1): testing() def test_nesting(): - def testing(): with pyro.markov(): - v1 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(1), funsor.Bint[2])]), 'real')) + v1 = pyro.to_data( + Tensor(torch.ones(2), OrderedDict([(str(1), funsor.Bint[2])]), "real") + ) print(1, v1.shape) # shapes should alternate assert v1.shape == (2,) with pyro.markov(): - v2 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(2), funsor.Bint[2])]), 'real')) + v2 = pyro.to_data( + Tensor( + torch.ones(2), OrderedDict([(str(2), funsor.Bint[2])]), "real" + ) + ) print(2, v2.shape) # shapes should alternate assert v2.shape == (2, 1) with pyro.markov(): - v3 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(3), funsor.Bint[2])]), 'real')) + v3 = pyro.to_data( + Tensor( + torch.ones(2), + OrderedDict([(str(3), funsor.Bint[2])]), + "real", + ) + ) print(3, v3.shape) # shapes should alternate assert v3.shape == (2,) with pyro.markov(): - v4 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(4), funsor.Bint[2])]), 'real')) + v4 = pyro.to_data( + Tensor( + torch.ones(2), + OrderedDict([(str(4), funsor.Bint[2])]), + "real", + ) + ) print(4, v4.shape) # shapes should alternate assert v4.shape == (2, 1) @@ -75,26 +96,28 @@ def testing(): def test_staggered(): - def testing(): for i in pyro.markov(range(12)): if i % 4 == 0: - v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real')) + v2 = pyro.to_data( + Tensor(torch.zeros(2), OrderedDict([("a", funsor.Bint[2])]), "real") + ) fv2 = pyro.to_funsor(v2, funsor.Real) assert v2.shape == (2,) - print('a', v2.shape) - print('a', fv2.inputs) + print("a", v2.shape) + print("a", fv2.inputs) with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1): testing() def test_fresh_inputs_to_funsor(): - def testing(): - x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: "x"}) + x = pyro.to_funsor(torch.tensor([0.0, 1.0]), funsor.Real, dim_to_name={-1: "x"}) assert set(x.inputs) == {"x"} - px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"}) + px = pyro.to_funsor( + torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"} + ) assert px.inputs["x"].dtype == 2 and px.inputs["y"].dtype == 3 with pyro_backend("contrib.funsor"), NamedMessenger(): @@ -102,7 +125,6 @@ def testing(): def test_iteration_fresh(): - def testing(): for i in pyro.markov(range(5)): fv1 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: str(i)}) @@ -116,23 +138,22 @@ def testing(): assert v1.shape == (2, 1, 1) assert v2.shape == (2, 1) print(i, fv1.inputs) - print('a', v2.shape) # shapes should stay the same - print('a', fv2.inputs) + print("a", v2.shape) # shapes should stay the same + print("a", fv2.inputs) with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1): testing() def test_staggered_fresh(): - def testing(): for i in pyro.markov(range(12)): if i % 4 == 0: - fv2 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: 'a'}) + fv2 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: "a"}) v2 = pyro.to_data(fv2) assert v2.shape == (2,) - print('a', v2.shape) - print('a', fv2.inputs) + print("a", v2.shape) + print("a", fv2.inputs) with pyro_backend("contrib.funsor"), NamedMessenger(first_available_dim=-1): testing() diff --git a/tests/contrib/funsor/test_pyroapi_funsor.py 
b/tests/contrib/funsor/test_pyroapi_funsor.py index 407b9fb89a..21a57f5ece 100644 --- a/tests/contrib/funsor/test_pyroapi_funsor.py +++ b/tests/contrib/funsor/test_pyroapi_funsor.py @@ -8,6 +8,7 @@ import funsor import pyro.contrib.funsor # noqa: F401 + funsor.set_backend("torch") except ImportError: pytestmark = pytest.mark.skip() diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py index fec050cb48..26c45c7b1c 100644 --- a/tests/contrib/funsor/test_tmc.py +++ b/tests/contrib/funsor/test_tmc.py @@ -16,6 +16,7 @@ import funsor import pyro.contrib.funsor + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import infer, pyro, pyro_backend @@ -30,26 +31,28 @@ @pytest.mark.parametrize("max_plate_nesting", [2, 3]) @pytest.mark.parametrize("tmc_strategy", ["diagonal", "mixture"]) def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy): - def model(): x = pyro.sample("x0", dist.Categorical(pyro.param("q0"))) with pyro.plate("local", 3): for i in range(1, depth): - x = pyro.sample("x{}".format(i), - dist.Categorical(pyro.param("q{}".format(i))[..., x, :])) + x = pyro.sample( + "x{}".format(i), + dist.Categorical(pyro.param("q{}".format(i))[..., x, :]), + ) with pyro.plate("data", 4): - pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), - obs=data) + pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), obs=data) with pyro_backend("pyro"): # initialize qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))] for i in range(1, depth): - qs.append(pyro.param( - "q{}".format(i), - torch.randn(2, 2).abs().detach().requires_grad_(), - constraint=constraints.simplex - )) + qs.append( + pyro.param( + "q{}".format(i), + torch.randn(2, 2).abs().detach().requires_grad_(), + constraint=constraints.simplex, + ) + ) qs.append(pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True))) qs = [q.unconstrained() for q in qs] data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype, device=qs[-1].device) @@ -57,28 +60,52 @@ def model(): with pyro_backend("pyro"): elbo = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) enum_model = infer.config_enumerate( - model, default="parallel", expand=False, num_samples=num_samples, tmc=tmc_strategy) + model, + default="parallel", + expand=False, + num_samples=num_samples, + tmc=tmc_strategy, + ) expected_loss = (-elbo.differentiable_loss(enum_model, lambda: None)).exp() expected_grads = grad(expected_loss, qs) with pyro_backend("contrib.funsor"): tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) tmc_model = infer.config_enumerate( - model, default="parallel", expand=False, num_samples=num_samples, tmc=tmc_strategy) + model, + default="parallel", + expand=False, + num_samples=num_samples, + tmc=tmc_strategy, + ) actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp() actual_grads = grad(actual_loss, qs) prec = 0.05 - assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) + assert_equal( + actual_loss, + expected_loss, + prec=prec, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = 
{}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=prec, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("depth", [1, 2, 3, 4]) @@ -87,26 +114,45 @@ def model(): @pytest.mark.parametrize("guide_type", ["prior", "factorized", "nonfactorized"]) @pytest.mark.parametrize("reparameterized", [False, True], ids=["dice", "pathwise"]) @pytest.mark.parametrize("tmc_strategy", ["diagonal", "mixture"]) -def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting, expand, - guide_type, reparameterized, tmc_strategy): +def test_tmc_normals_chain_gradient( + depth, + num_samples, + max_plate_nesting, + expand, + guide_type, + reparameterized, + tmc_strategy, +): def model(reparameterized): - Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + Normal = ( + dist.Normal + if reparameterized + else dist.testing.fakes.NonreparameterizedNormal + ) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth))) - pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1))) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) + pyro.sample("y", Normal(x, 1.0), obs=torch.tensor(float(1))) def factorized_guide(reparameterized): - Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal - pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + Normal = ( + dist.Normal + if reparameterized + else dist.testing.fakes.NonreparameterizedNormal + ) + pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - pyro.sample("x{}".format(i), Normal(0., math.sqrt(float(i+1) / depth))) + pyro.sample("x{}".format(i), Normal(0.0, math.sqrt(float(i + 1) / depth))) def nonfactorized_guide(reparameterized): - Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + Normal = ( + dist.Normal + if reparameterized + else dist.testing.fakes.NonreparameterizedNormal + ) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. 
/ depth))) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) with pyro_backend("contrib.funsor"): # compare reparameterized and nonreparameterized gradient estimates @@ -115,27 +161,50 @@ def nonfactorized_guide(reparameterized): tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) tmc_model = infer.config_enumerate( - model, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) - guide = factorized_guide if guide_type == "factorized" else \ - nonfactorized_guide if guide_type == "nonfactorized" else \ - lambda *args: None + model, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) + guide = ( + factorized_guide + if guide_type == "factorized" + else nonfactorized_guide + if guide_type == "nonfactorized" + else lambda *args: None + ) tmc_guide = infer.config_enumerate( - guide, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) + guide, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) # convert to linear space for unbiasedness - actual_loss = (-tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp() + actual_loss = ( + -tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized) + ).exp() actual_grads = grad(actual_loss, qs) # gold values from Funsor - expected_grads = (torch.tensor( - {1: 0.0999, 2: 0.0860, 3: 0.0802, 4: 0.0771}[depth] - ),) + expected_grads = ( + torch.tensor({1: 0.0999, 2: 0.0860, 3: 0.0802, 4: 0.0771}[depth]), + ) grad_prec = 0.05 if reparameterized else 0.1 for actual_grad, expected_grad in zip(actual_grads, expected_grads): print(actual_loss) - assert_equal(actual_grad, expected_grad, prec=grad_prec, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=grad_prec, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) diff --git a/tests/contrib/funsor/test_valid_models_enum.py b/tests/contrib/funsor/test_valid_models_enum.py index 7df1b23a90..8ad27d6bdd 100644 --- a/tests/contrib/funsor/test_valid_models_enum.py +++ b/tests/contrib/funsor/test_valid_models_enum.py @@ -22,6 +22,7 @@ import pyro.contrib.funsor from pyro.contrib.funsor.handlers.runtime import _DIM_STACK + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro, pyro_backend @@ -51,21 +52,38 @@ def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs): while not q_pyro.empty() and not q_funsor.empty(): with pyro_backend("pyro"): with handlers.enum(first_available_dim=-max_plate_nesting - 1): - guide_tr_pyro = handlers.trace(handlers.queue( - guide, q_pyro, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend - )).get_trace(**kwargs) - tr_pyro = handlers.trace(handlers.replay(model, trace=guide_tr_pyro)).get_trace(**kwargs) + guide_tr_pyro = handlers.trace( + handlers.queue( + guide, + q_pyro, + escape_fn=iter_discrete_escape, + extend_fn=iter_discrete_extend, + ) + ).get_trace(**kwargs) + tr_pyro = handlers.trace( + handlers.replay(model, trace=guide_tr_pyro) + ).get_trace(**kwargs) with pyro_backend("contrib.funsor"): with handlers.enum(first_available_dim=-max_plate_nesting - 1): - guide_tr_funsor = handlers.trace(handlers.queue( - guide, q_funsor, escape_fn=iter_discrete_escape, 
extend_fn=iter_discrete_extend - )).get_trace(**kwargs) - tr_funsor = handlers.trace(handlers.replay(model, trace=guide_tr_funsor)).get_trace(**kwargs) + guide_tr_funsor = handlers.trace( + handlers.queue( + guide, + q_funsor, + escape_fn=iter_discrete_escape, + extend_fn=iter_discrete_extend, + ) + ).get_trace(**kwargs) + tr_funsor = handlers.trace( + handlers.replay(model, trace=guide_tr_funsor) + ).get_trace(**kwargs) # make sure all dimensions were cleaned up assert _DIM_STACK.local_frame is _DIM_STACK.global_frame - assert not _DIM_STACK.global_frame.name_to_dim and not _DIM_STACK.global_frame.dim_to_name + assert ( + not _DIM_STACK.global_frame.name_to_dim + and not _DIM_STACK.global_frame.dim_to_name + ) assert _DIM_STACK.outermost is None tr_pyro = prune_subsample_sites(tr_pyro.copy()) @@ -81,13 +99,15 @@ def _check_traces(tr_pyro, tr_funsor): tr_pyro.pack_tensors() symbol_to_name = { - node['infer']['_enumerate_symbol']: name + node["infer"]["_enumerate_symbol"]: name for name, node in tr_pyro.nodes.items() - if node['type'] == 'sample' and not node['is_observed'] - and node['infer'].get('enumerate') == 'parallel' + if node["type"] == "sample" + and not node["is_observed"] + and node["infer"].get("enumerate") == "parallel" } - symbol_to_name.update({ - symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()}) + symbol_to_name.update( + {symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()} + ) if _NAMED_TEST_STRENGTH >= 1: # coarser check: enumeration requirements satisfied @@ -96,78 +116,106 @@ def _check_traces(tr_pyro, tr_funsor): try: # coarser check: number of elements and squeezed shapes for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - assert pyro_node['packed']['log_prob'].numel() == funsor_node['log_prob'].numel() - assert pyro_node['packed']['log_prob'].shape == funsor_node['log_prob'].squeeze().shape - assert frozenset(f for f in pyro_node['cond_indep_stack'] if f.vectorized) == \ - frozenset(f for f in funsor_node['cond_indep_stack'] if f.vectorized) + assert ( + pyro_node["packed"]["log_prob"].numel() + == funsor_node["log_prob"].numel() + ) + assert ( + pyro_node["packed"]["log_prob"].shape + == funsor_node["log_prob"].squeeze().shape + ) + assert frozenset( + f for f in pyro_node["cond_indep_stack"] if f.vectorized + ) == frozenset( + f for f in funsor_node["cond_indep_stack"] if f.vectorized + ) except AssertionError: for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - pyro_packed_shape = pyro_node['packed']['log_prob'].shape - funsor_packed_shape = funsor_node['log_prob'].squeeze().shape + pyro_packed_shape = pyro_node["packed"]["log_prob"].shape + funsor_packed_shape = funsor_node["log_prob"].squeeze().shape if pyro_packed_shape != funsor_packed_shape: err_str = "==> (dep mismatch) {}".format(name) else: err_str = name - print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_packed_shape, funsor_packed_shape)) + print( + err_str, + "Pyro: {} vs Funsor: {}".format( + pyro_packed_shape, funsor_packed_shape + ), + ) raise if _NAMED_TEST_STRENGTH >= 2: try: # medium check: unordered packed shapes match for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - pyro_names = frozenset(symbol_to_name[d] for d in 
pyro_node['packed']['log_prob']._pyro_dims) - funsor_names = frozenset(funsor_node['funsor']['log_prob'].inputs) - assert pyro_names == frozenset(name.replace('__PARTICLES', '') for name in funsor_names) + pyro_names = frozenset( + symbol_to_name[d] + for d in pyro_node["packed"]["log_prob"]._pyro_dims + ) + funsor_names = frozenset(funsor_node["funsor"]["log_prob"].inputs) + assert pyro_names == frozenset( + name.replace("__PARTICLES", "") for name in funsor_names + ) except AssertionError: for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - pyro_names = frozenset(symbol_to_name[d] for d in pyro_node['packed']['log_prob']._pyro_dims) - funsor_names = frozenset(funsor_node['funsor']['log_prob'].inputs) + pyro_names = frozenset( + symbol_to_name[d] + for d in pyro_node["packed"]["log_prob"]._pyro_dims + ) + funsor_names = frozenset(funsor_node["funsor"]["log_prob"].inputs) if pyro_names != funsor_names: err_str = "==> (packed mismatch) {}".format(name) else: err_str = name - print(err_str, "Pyro: {} vs Funsor: {}".format(sorted(tuple(pyro_names)), sorted(tuple(funsor_names)))) + print( + err_str, + "Pyro: {} vs Funsor: {}".format( + sorted(tuple(pyro_names)), sorted(tuple(funsor_names)) + ), + ) raise if _NAMED_TEST_STRENGTH >= 3: try: # finer check: exact match with unpacked Pyro shapes for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - assert pyro_node['log_prob'].shape == funsor_node['log_prob'].shape - assert pyro_node['value'].shape == funsor_node['value'].shape + assert pyro_node["log_prob"].shape == funsor_node["log_prob"].shape + assert pyro_node["value"].shape == funsor_node["value"].shape except AssertionError: for name, pyro_node in tr_pyro.nodes.items(): - if pyro_node['type'] != 'sample': + if pyro_node["type"] != "sample": continue funsor_node = tr_funsor.nodes[name] - pyro_shape = pyro_node['log_prob'].shape - funsor_shape = funsor_node['log_prob'].shape + pyro_shape = pyro_node["log_prob"].shape + funsor_shape = funsor_node["log_prob"].shape if pyro_shape != funsor_shape: err_str = "==> (unpacked mismatch) {}".format(name) else: err_str = name - print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape)) + print( + err_str, "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape) + ) raise @pytest.mark.parametrize("history", [1, 2, 3]) def test_enum_recycling_chain_iter(history): - @infer.config_enumerate def model(): p = torch.tensor([[0.2, 0.8], [0.1, 0.9]]) @@ -182,14 +230,18 @@ def model(): @pytest.mark.parametrize("history", [2, 3]) def test_enum_recycling_chain_iter_interleave_parallel_sequential(history): - def model(): p = torch.tensor([[0.2, 0.8], [0.1, 0.9]]) xs = [0] for t in pyro.markov(range(10), history=history): - xs.append(pyro.sample("x_{}".format(t), dist.Categorical(p[xs[-1]]), - infer={"enumerate": ("sequential", "parallel")[t % 2]})) + xs.append( + pyro.sample( + "x_{}".format(t), + dist.Categorical(p[xs[-1]]), + infer={"enumerate": ("sequential", "parallel")[t % 2]}, + ) + ) assert all(x.dim() <= history + 1 for x in xs[1:]) assert_ok(model, max_plate_nesting=0) @@ -197,7 +249,6 @@ def model(): @pytest.mark.parametrize("history", [1, 2, 3]) def test_enum_recycling_chain_while(history): - @infer.config_enumerate def model(): p = torch.tensor([[0.2, 0.8], [0.1, 0.9]]) @@ -215,7 +266,6 @@ def model(): 
@pytest.mark.parametrize("history", [1, 2, 3]) def test_enum_recycling_chain_recur(history): - @infer.config_enumerate def model(): p = torch.tensor([[0.2, 0.8], [0.1, 0.9]]) @@ -233,8 +283,8 @@ def fn(t, x): assert_ok(model, max_plate_nesting=0) -@pytest.mark.parametrize('use_vindex', [False, True]) -@pytest.mark.parametrize('markov', [False, True]) +@pytest.mark.parametrize("use_vindex", [False, True]) +@pytest.mark.parametrize("markov", [False, True]) def test_enum_recycling_dbn(markov, use_vindex): # x --> x --> x enum "state" # y | y | y | enum "occlusion" @@ -257,8 +307,9 @@ def model(): else: z_ind = torch.arange(4, dtype=torch.long) probs = r[x.unsqueeze(-1), y.unsqueeze(-1), z_ind] - pyro.sample("z_{}".format(t), dist.Categorical(probs), - obs=torch.tensor(0.)) + pyro.sample( + "z_{}".format(t), dist.Categorical(probs), obs=torch.tensor(0.0) + ) assert_ok(model, max_plate_nesting=0) @@ -306,7 +357,7 @@ def model(): @pytest.mark.xfail(reason="Pyro behavior here appears to be incorrect") @pytest.mark.parametrize("grid_size", [4, 20]) -@pytest.mark.parametrize('use_vindex', [False, True]) +@pytest.mark.parametrize("use_vindex", [False, True]) def test_enum_recycling_grid(grid_size, use_vindex): # x---x---x---x -----> i # | | | | | @@ -327,38 +378,39 @@ def model(): probs = Vindex(p)[x[i - 1, j], x[i, j - 1]] else: ind = torch.arange(2, dtype=torch.long) - probs = p[x[i - 1, j].unsqueeze(-1), - x[i, j - 1].unsqueeze(-1), ind] - x[i, j] = pyro.sample("x_{}_{}".format(i, j), - dist.Categorical(probs)) + probs = p[x[i - 1, j].unsqueeze(-1), x[i, j - 1].unsqueeze(-1), ind] + x[i, j] = pyro.sample("x_{}_{}".format(i, j), dist.Categorical(probs)) assert_ok(model, max_plate_nesting=0) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2]) @pytest.mark.parametrize("depth", [3, 5, 7]) -@pytest.mark.parametrize('history', [1, 2]) +@pytest.mark.parametrize("history", [1, 2]) def test_enum_recycling_reentrant_history(max_plate_nesting, depth, history): data = (True, False) for i in range(depth): data = (data, data, False) def model_(**kwargs): - @pyro.markov(history=history) def model(data, state=0, address=""): if isinstance(data, bool): p = pyro.param("p_leaf", torch.ones(10)) - pyro.sample("leaf_{}".format(address), - dist.Bernoulli(p[state]), - obs=torch.tensor(1. 
if data else 0.)) + pyro.sample( + "leaf_{}".format(address), + dist.Bernoulli(p[state]), + obs=torch.tensor(1.0 if data else 0.0), + ) else: assert isinstance(data, tuple) p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model(branch, next_state, address + letter) return model(**kwargs) @@ -366,7 +418,7 @@ def model(data, state=0, address=""): assert_ok(model_, max_plate_nesting=max_plate_nesting, data=data) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2]) @pytest.mark.parametrize("depth", [3, 5, 7]) def test_enum_recycling_mutual_recursion(max_plate_nesting, depth): data = (True, False) @@ -374,12 +426,13 @@ def test_enum_recycling_mutual_recursion(max_plate_nesting, depth): data = (data, data, False) def model_(**kwargs): - def model_leaf(data, state=0, address=""): p = pyro.param("p_leaf", torch.ones(10)) - pyro.sample("leaf_{}".format(address), - dist.Bernoulli(p[state]), - obs=torch.tensor(1. if data else 0.)) + pyro.sample( + "leaf_{}".format(address), + dist.Bernoulli(p[state]), + obs=torch.tensor(1.0 if data else 0.0), + ) @pyro.markov def model1(data, state=0, address=""): @@ -388,9 +441,11 @@ def model1(data, state=0, address=""): else: p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model2(branch, next_state, address + letter) @pyro.markov @@ -400,9 +455,11 @@ def model2(data, state=0, address=""): else: p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model1(branch, next_state, address + letter) return model1(**kwargs) @@ -410,23 +467,24 @@ def model2(data, state=0, address=""): assert_ok(model_, max_plate_nesting=0, data=data) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2]) def test_enum_recycling_interleave(max_plate_nesting): - def model(): with pyro.markov() as m: with pyro.markov(): with m: # error here - pyro.sample("x", dist.Categorical(torch.ones(4)), - infer={"enumerate": "parallel"}) + pyro.sample( + "x", + dist.Categorical(torch.ones(4)), + infer={"enumerate": "parallel"}, + ) assert_ok(model, max_plate_nesting=max_plate_nesting) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2]) -@pytest.mark.parametrize('history', [2, 3]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2]) +@pytest.mark.parametrize("history", [2, 3]) def test_markov_history(max_plate_nesting, history): - @infer.config_enumerate def model(): p = pyro.param("p", 0.25 * torch.ones(2, 2)) @@ -435,8 +493,12 @@ def model(): x_curr = torch.tensor(0) for t in pyro.markov(range(10), history=history): probs = p[x_prev, x_curr] - x_prev, x_curr = 
x_curr, pyro.sample("x_{}".format(t), dist.Bernoulli(probs)).long() - pyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), - obs=torch.tensor(0.)) + x_prev, x_curr = ( + x_curr, + pyro.sample("x_{}".format(t), dist.Bernoulli(probs)).long(), + ) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=torch.tensor(0.0) + ) assert_ok(model, max_plate_nesting=max_plate_nesting) diff --git a/tests/contrib/funsor/test_valid_models_plate.py b/tests/contrib/funsor/test_valid_models_plate.py index 23d9f2305e..4b2f75d844 100644 --- a/tests/contrib/funsor/test_valid_models_plate.py +++ b/tests/contrib/funsor/test_valid_models_plate.py @@ -14,6 +14,7 @@ import funsor import pyro.contrib.funsor + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import infer, pyro @@ -25,31 +26,32 @@ logger = logging.getLogger(__name__) -@pytest.mark.parametrize('enumerate_', [None, "parallel", "sequential"]) +@pytest.mark.parametrize("enumerate_", [None, "parallel", "sequential"]) def test_enum_discrete_non_enumerated_plate_ok(enumerate_): - def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) with pyro.plate("non_enum", 2): - a = pyro.sample("a", dist.Bernoulli(0.5), infer={'enumerate': None}) + a = pyro.sample("a", dist.Bernoulli(0.5), infer={"enumerate": None}) p = (1.0 + a.sum(-1)) / (2.0 + a.shape[0]) # introduce dependency of b on a with pyro.plate("enum_1", 3): - pyro.sample("b", dist.Bernoulli(p), infer={'enumerate': enumerate_}) + pyro.sample("b", dist.Bernoulli(p), infer={"enumerate": enumerate_}) assert_ok(model, max_plate_nesting=1) -@pytest.mark.parametrize("plate_dims", [ - (None, None, None, None), - (-3, None, None, None), - (None, -3, None, None), - (-2, -3, None, None), -]) +@pytest.mark.parametrize( + "plate_dims", + [ + (None, None, None, None), + (-3, None, None, None), + (None, -3, None, None), + (-2, -3, None, None), + ], +) def test_plate_dim_allocation_ok(plate_dims): - def model(): p = torch.tensor(0.5, requires_grad=True) with pyro.plate("plate_outer", 5, dim=plate_dims[0]): @@ -64,18 +66,28 @@ def model(): assert_ok(model, max_plate_nesting=4) -@pytest.mark.parametrize("tmc_strategy", [None, xfail_param("diagonal", reason="strategy not implemented yet")]) +@pytest.mark.parametrize( + "tmc_strategy", + [None, xfail_param("diagonal", reason="strategy not implemented yet")], +) @pytest.mark.parametrize("subsampling", [False, True]) @pytest.mark.parametrize("reuse_plate", [False, True]) def test_enum_recycling_plate(subsampling, reuse_plate, tmc_strategy): - - @infer.config_enumerate(default="parallel", tmc=tmc_strategy, num_samples=2 if tmc_strategy else None) + @infer.config_enumerate( + default="parallel", tmc=tmc_strategy, num_samples=2 if tmc_strategy else None + ) def model(): p = pyro.param("p", torch.ones(3, 3)) q = pyro.param("q", torch.tensor([0.5, 0.5])) - plate_x = pyro.plate("plate_x", 4, subsample_size=3 if subsampling else None, dim=-1) - plate_y = pyro.plate("plate_y", 5, subsample_size=3 if subsampling else None, dim=-1) - plate_z = pyro.plate("plate_z", 6, subsample_size=3 if subsampling else None, dim=-2) + plate_x = pyro.plate( + "plate_x", 4, subsample_size=3 if subsampling else None, dim=-1 + ) + plate_y = pyro.plate( + "plate_y", 5, subsample_size=3 if subsampling else None, dim=-1 + ) + plate_z = pyro.plate( + "plate_z", 6, subsample_size=3 if subsampling else None, dim=-2 + ) a = pyro.sample("a", 
dist.Bernoulli(q[0])).long() w = 0 @@ -116,7 +128,6 @@ def model(): @pytest.mark.parametrize("enumerate_", [None, "parallel", "sequential"]) @pytest.mark.parametrize("reuse_plate", [True, False]) def test_enum_discrete_plates_dependency_ok(enumerate_, reuse_plate): - @infer.config_enumerate(default=enumerate_) def model(): x_plate = pyro.plate("x_plate", 10, dim=-1) @@ -136,14 +147,17 @@ def model(): assert_ok(model, max_plate_nesting=2) -@pytest.mark.parametrize('subsampling', [False, True]) -@pytest.mark.parametrize('enumerate_', [None, "parallel", "sequential"]) +@pytest.mark.parametrize("subsampling", [False, True]) +@pytest.mark.parametrize("enumerate_", [None, "parallel", "sequential"]) def test_enum_discrete_plate_shape_broadcasting_ok(subsampling, enumerate_): - @infer.config_enumerate(default=enumerate_) def model(): - x_plate = pyro.plate("x_plate", 5, subsample_size=2 if subsampling else None, dim=-1) - y_plate = pyro.plate("y_plate", 6, subsample_size=3 if subsampling else None, dim=-2) + x_plate = pyro.plate( + "x_plate", 5, subsample_size=2 if subsampling else None, dim=-1 + ) + y_plate = pyro.plate( + "y_plate", 6, subsample_size=3 if subsampling else None, dim=-2 + ) with pyro.plate("num_particles", 50, dim=-3): with x_plate: b = pyro.sample("b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1))) @@ -172,11 +186,10 @@ def model(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) @pytest.mark.parametrize("num_samples", [None, 2]) def test_plate_subsample_primitive_ok(subsample_size, num_samples): - @infer.config_enumerate(num_samples=num_samples, tmc="full") def model(): with pyro.plate("plate", 10, subsample_size=subsample_size, dim=None): - p0 = torch.tensor(0.) + p0 = torch.tensor(0.0) p0 = pyro.subsample(p0, event_dim=0) assert p0.shape == () p = 0.5 * torch.ones(10) diff --git a/tests/contrib/funsor/test_valid_models_sequential_plate.py b/tests/contrib/funsor/test_valid_models_sequential_plate.py index 31095caba8..44bba9b8d8 100644 --- a/tests/contrib/funsor/test_valid_models_sequential_plate.py +++ b/tests/contrib/funsor/test_valid_models_sequential_plate.py @@ -13,6 +13,7 @@ import funsor import pyro.contrib.funsor + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import infer, pyro @@ -24,24 +25,27 @@ logger = logging.getLogger(__name__) -@pytest.mark.parametrize('subsampling', [False, True]) -@pytest.mark.parametrize('enumerate_', [None, "parallel", "sequential"]) +@pytest.mark.parametrize("subsampling", [False, True]) +@pytest.mark.parametrize("enumerate_", [None, "parallel", "sequential"]) def test_enum_discrete_iplate_plate_dependency_ok(subsampling, enumerate_): - def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) inner_plate = pyro.plate("plate", 10, subsample_size=4 if subsampling else None) - for i in pyro.plate("iplate", 10, subsample=torch.arange(3) if subsampling else None): + for i in pyro.plate( + "iplate", 10, subsample=torch.arange(3) if subsampling else None + ): pyro.sample("y_{}".format(i), dist.Bernoulli(0.5)) with inner_plate: - pyro.sample("x_{}".format(i), dist.Bernoulli(0.5), - infer={'enumerate': enumerate_}) + pyro.sample( + "x_{}".format(i), + dist.Bernoulli(0.5), + infer={"enumerate": enumerate_}, + ) assert_ok(model, max_plate_nesting=1) def test_enum_iplate_iplate_ok(): - @infer.config_enumerate def model(data=None): probs_a = torch.tensor([0.45, 0.55]) @@ -52,20 
+56,25 @@ def model(data=None): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] for i in b_axis: for j in c_axis: - pyro.sample("d_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_d)[b[i], c[j]]), - obs=data[i, j]) + pyro.sample( + "d_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_d)[b[i], c[j]]), + obs=data[i, j], + ) data = torch.tensor([[0, 1], [0, 0]]) assert_ok(model, max_plate_nesting=1, data=data) def test_enum_plate_iplate_ok(): - @infer.config_enumerate def model(data=None): probs_a = torch.tensor([0.45, 0.55]) @@ -81,16 +90,17 @@ def model(data=None): with b_axis: for j in c_axis: c_j = pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) - pyro.sample("d_{}".format(j), - dist.Categorical(Vindex(probs_d)[b, c_j]), - obs=data[:, j]) + pyro.sample( + "d_{}".format(j), + dist.Categorical(Vindex(probs_d)[b, c_j]), + obs=data[:, j], + ) data = torch.tensor([[0, 1], [0, 0]]) assert_ok(model, max_plate_nesting=1, data=data) def test_enum_iplate_plate_ok(): - @infer.config_enumerate def model(data=None): probs_a = torch.tensor([0.45, 0.55]) @@ -106,9 +116,11 @@ def model(data=None): for i in b_axis: b_i = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) with c_axis: - pyro.sample("d_{}".format(i), - dist.Categorical(Vindex(probs_d)[b_i, c]), - obs=data[i]) + pyro.sample( + "d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b_i, c]), + obs=data[i], + ) data = torch.tensor([[0, 1], [0, 0]]) assert_ok(model, max_plate_nesting=1, data=data) diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 5d56c942b6..9fc24fdd43 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -15,6 +15,7 @@ import pyro.contrib.funsor from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace + funsor.set_backend("torch") from pyroapi import distributions as dist from pyroapi import handlers, infer, pyro @@ -29,22 +30,34 @@ def model_0(data, history, vectorized): x_dim = 3 init = pyro.param("init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex) + trans = pyro.param( + "trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex + ) locs = pyro.param("locs", lambda: torch.rand(x_dim)) with pyro.plate("sequences", data.shape[0], dim=-3) as sequences: sequences = sequences[:, None] x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=data.shape[1], dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov( + name="time", size=data.shape[1], dim=-2, history=history + ) + if vectorized else pyro.markov(range(data.shape[1]), history=history) + ) for i in markov_loop: x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - init if isinstance(i, int) and i < 1 else trans[x_prev])) + "x_{}".format(i), + dist.Categorical( + init if isinstance(i, int) and i < 1 else trans[x_prev] + ), + ) with pyro.plate("tones", data.shape[2], dim=-1): - pyro.sample("y_{}".format(i), 
dist.Normal(Vindex(locs)[..., x_curr], 1.), - obs=Vindex(data)[sequences, i]) + pyro.sample( + "y_{}".format(i), + dist.Normal(Vindex(locs)[..., x_curr], 1.0), + obs=Vindex(data)[sequences, i], + ) x_prev = x_curr @@ -55,19 +68,28 @@ def model_0(data, history, vectorized): def model_1(data, history, vectorized): x_dim = 3 init = pyro.param("init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex) + trans = pyro.param( + "trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex + ) locs = pyro.param("locs", lambda: torch.rand(x_dim)) x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - init if isinstance(i, int) and i < 1 else trans[x_prev])) + "x_{}".format(i), + dist.Categorical(init if isinstance(i, int) and i < 1 else trans[x_prev]), + ) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) + pyro.sample( + "y_{}".format(i), + dist.Normal(Vindex(locs)[..., x_curr], 1.0), + obs=data[i], + ) x_prev = x_curr @@ -77,23 +99,44 @@ def model_1(data, history, vectorized): # y[t-1] --> y[t] --> y[t+1] def model_2(data, history, vectorized): x_dim, y_dim = 3, 2 - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex) - y_init = pyro.param("y_init", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex) - y_trans = pyro.param("y_trans", lambda: torch.rand((x_dim, y_dim, y_dim)), constraint=constraints.simplex) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex + ) + y_init = pyro.param( + "y_init", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex + ) + y_trans = pyro.param( + "y_trans", + lambda: torch.rand((x_dim, y_dim, y_dim)), + constraint=constraints.simplex, + ) x_prev = y_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - x_init if isinstance(i, int) and i < 1 else x_trans[x_prev])) + "x_{}".format(i), + dist.Categorical( + x_init if isinstance(i, int) and i < 1 else x_trans[x_prev] + ), + ) with pyro.plate("tones", data.shape[-1], dim=-1): - y_curr = pyro.sample("y_{}".format(i), dist.Categorical( - y_init[x_curr] if isinstance(i, int) and i < 1 else Vindex(y_trans)[x_curr, y_prev]), - obs=data[i]) + y_curr = pyro.sample( + "y_{}".format(i), + dist.Categorical( + y_init[x_curr] + if isinstance(i, int) and i < 1 + else Vindex(y_trans)[x_curr, y_prev] + ), + obs=data[i], + ) x_prev, y_prev = x_curr, y_curr @@ -104,27 +147,49 @@ def model_2(data, history, vectorized): # y[t-1] y[t] y[t+1] def model_3(data, history, vectorized): w_dim, 
x_dim, y_dim = 2, 3, 2 - w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex) - w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex) - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex) - y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim), constraint=constraints.simplex) + w_init = pyro.param( + "w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex + ) + w_trans = pyro.param( + "w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex + ) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex + ) + y_probs = pyro.param( + "y_probs", + lambda: torch.rand(w_dim, x_dim, y_dim), + constraint=constraints.simplex, + ) w_prev = x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: w_curr = pyro.sample( - "w_{}".format(i), dist.Categorical( - w_init if isinstance(i, int) and i < 1 else w_trans[w_prev])) + "w_{}".format(i), + dist.Categorical( + w_init if isinstance(i, int) and i < 1 else w_trans[w_prev] + ), + ) x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - x_init if isinstance(i, int) and i < 1 else x_trans[x_prev])) + "x_{}".format(i), + dist.Categorical( + x_init if isinstance(i, int) and i < 1 else x_trans[x_prev] + ), + ) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Categorical( - Vindex(y_probs)[w_curr, x_curr]), - obs=data[i]) + pyro.sample( + "y_{}".format(i), + dist.Categorical(Vindex(y_probs)[w_curr, x_curr]), + obs=data[i], + ) x_prev, w_prev = x_curr, w_curr @@ -136,27 +201,53 @@ def model_3(data, history, vectorized): # y[t-1] y[t] y[t+1] def model_4(data, history, vectorized): w_dim, x_dim, y_dim = 2, 3, 2 - w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex) - w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex) - x_init = pyro.param("x_init", lambda: torch.rand(w_dim, x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((w_dim, x_dim, x_dim)), constraint=constraints.simplex) - y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim), constraint=constraints.simplex) + w_init = pyro.param( + "w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex + ) + w_trans = pyro.param( + "w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex + ) + x_init = pyro.param( + "x_init", lambda: torch.rand(w_dim, x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", + lambda: torch.rand((w_dim, x_dim, x_dim)), + constraint=constraints.simplex, + ) + y_probs = pyro.param( + "y_probs", + lambda: torch.rand(w_dim, x_dim, y_dim), + constraint=constraints.simplex, + ) w_prev = x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, 
history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: w_curr = pyro.sample( - "w_{}".format(i), dist.Categorical( - w_init if isinstance(i, int) and i < 1 else w_trans[w_prev])) + "w_{}".format(i), + dist.Categorical( + w_init if isinstance(i, int) and i < 1 else w_trans[w_prev] + ), + ) x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - x_init[w_curr] if isinstance(i, int) and i < 1 else x_trans[w_curr, x_prev])) + "x_{}".format(i), + dist.Categorical( + x_init[w_curr] + if isinstance(i, int) and i < 1 + else x_trans[w_curr, x_prev] + ), + ) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Categorical( - Vindex(y_probs)[w_curr, x_curr]), - obs=data[i]) + pyro.sample( + "y_{}".format(i), + dist.Categorical(Vindex(y_probs)[w_curr, x_curr]), + obs=data[i], + ) x_prev, w_prev = x_curr, w_curr @@ -169,15 +260,27 @@ def model_4(data, history, vectorized): # y[t-1] y[t] y[t+1] y[t+2] def model_5(data, history, vectorized): x_dim, y_dim = 3, 2 - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_init_2 = pyro.param("x_init_2", lambda: torch.rand(x_dim, x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim, x_dim)), constraint=constraints.simplex) - y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_init_2 = pyro.param( + "x_init_2", lambda: torch.rand(x_dim, x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", + lambda: torch.rand((x_dim, x_dim, x_dim)), + constraint=constraints.simplex, + ) + y_probs = pyro.param( + "y_probs", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex + ) x_prev = x_prev_2 = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: if isinstance(i, int) and i == 0: x_probs = x_init @@ -186,12 +289,11 @@ def model_5(data, history, vectorized): else: x_probs = Vindex(x_trans)[x_prev_2, x_prev] - x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical(x_probs)) + x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs)) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Categorical( - Vindex(y_probs)[x_curr]), - obs=data[i]) + pyro.sample( + "y_{}".format(i), dist.Categorical(Vindex(y_probs)[x_curr]), obs=data[i] + ) x_prev_2, x_prev = x_prev, x_curr @@ -203,26 +305,37 @@ def model_5(data, history, vectorized): # y[t-1] y[t] y[t+1] def model_6(data, history, vectorized): x_dim = 3 - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((len(data)-1, x_dim, x_dim)), constraint=constraints.simplex) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", + lambda: torch.rand((len(data) - 1, x_dim, x_dim)), + constraint=constraints.simplex, + ) locs = pyro.param("locs", lambda: torch.rand(x_dim)) x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( 
+ pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: if isinstance(i, int) and i < 1: x_probs = x_init elif isinstance(i, int): - x_probs = x_trans[i-1, x_prev] + x_probs = x_trans[i - 1, x_prev] else: - x_probs = Vindex(x_trans)[(i-1)[:, None], x_prev] + x_probs = Vindex(x_trans)[(i - 1)[:, None], x_prev] - x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical(x_probs)) + x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs)) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) + pyro.sample( + "y_{}".format(i), + dist.Normal(Vindex(locs)[..., x_curr], 1.0), + obs=data[i], + ) x_prev = x_curr @@ -236,52 +349,79 @@ def model_6(data, history, vectorized): # x[t-1] x[t] x[t+1] def model_7(data, history, vectorized): w_dim, x_dim, y_dim = 2, 3, 2 - w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex) - w_trans = pyro.param("w_trans", lambda: torch.rand((x_dim, w_dim)), constraint=constraints.simplex) - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((w_dim, x_dim)), constraint=constraints.simplex) - y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim), constraint=constraints.simplex) + w_init = pyro.param( + "w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex + ) + w_trans = pyro.param( + "w_trans", lambda: torch.rand((x_dim, w_dim)), constraint=constraints.simplex + ) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", lambda: torch.rand((w_dim, x_dim)), constraint=constraints.simplex + ) + y_probs = pyro.param( + "y_probs", + lambda: torch.rand(w_dim, x_dim, y_dim), + constraint=constraints.simplex, + ) w_prev = x_prev = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: w_curr = pyro.sample( - "w_{}".format(i), dist.Categorical( - w_init if isinstance(i, int) and i < 1 else w_trans[x_prev])) + "w_{}".format(i), + dist.Categorical( + w_init if isinstance(i, int) and i < 1 else w_trans[x_prev] + ), + ) x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical( - x_init if isinstance(i, int) and i < 1 else x_trans[w_prev])) + "x_{}".format(i), + dist.Categorical( + x_init if isinstance(i, int) and i < 1 else x_trans[w_prev] + ), + ) with pyro.plate("tones", data.shape[-1], dim=-1): - pyro.sample("y_{}".format(i), dist.Categorical( - Vindex(y_probs)[w_curr, x_curr]), - obs=data[i]) + pyro.sample( + "y_{}".format(i), + dist.Categorical(Vindex(y_probs)[w_curr, x_curr]), + obs=data[i], + ) x_prev, w_prev = x_curr, w_curr def _guide_from_model(model): try: with pyro_backend("contrib.funsor"): - return handlers.block(infer.config_enumerate(model, default="parallel"), - lambda msg: msg.get("is_observed", False)) + return handlers.block( + infer.config_enumerate(model, default="parallel"), + lambda msg: msg.get("is_observed", False), + ) except KeyError: # for test collection without funsor return model @pytest.mark.parametrize("use_replay", [True, False]) 
-@pytest.mark.parametrize("model,data,var,history", [ - (model_0, torch.rand(3, 5, 4), "xy", 1), - (model_1, torch.rand(5, 4), "xy", 1), - (model_2, torch.ones((5, 4), dtype=torch.long), "xy", 1), - (model_3, torch.ones((5, 4), dtype=torch.long), "wxy", 1), - (model_4, torch.ones((5, 4), dtype=torch.long), "wxy", 1), - (model_5, torch.ones((5, 4), dtype=torch.long), "xy", 2), - (model_6, torch.rand(5, 4), "xy", 1), - (model_6, torch.rand(100, 4), "xy", 1), - (model_7, torch.ones((5, 4), dtype=torch.long), "wxy", 1), - (model_7, torch.ones((50, 4), dtype=torch.long), "wxy", 1), -]) +@pytest.mark.parametrize( + "model,data,var,history", + [ + (model_0, torch.rand(3, 5, 4), "xy", 1), + (model_1, torch.rand(5, 4), "xy", 1), + (model_2, torch.ones((5, 4), dtype=torch.long), "xy", 1), + (model_3, torch.ones((5, 4), dtype=torch.long), "wxy", 1), + (model_4, torch.ones((5, 4), dtype=torch.long), "wxy", 1), + (model_5, torch.ones((5, 4), dtype=torch.long), "xy", 2), + (model_6, torch.rand(5, 4), "xy", 1), + (model_6, torch.rand(100, 4), "xy", 1), + (model_7, torch.ones((5, 4), dtype=torch.long), "wxy", 1), + (model_7, torch.ones((50, 4), dtype=torch.long), "wxy", 1), + ], +) def test_enumeration(model, data, var, history, use_replay): pyro.clear_param_store() @@ -292,11 +432,16 @@ def test_enumeration(model, data, var, history, use_replay): trace = handlers.trace(enum_model).get_trace(data, history, False) # vectorized trace if use_replay: - guide_trace = handlers.trace(_guide_from_model(model)).get_trace(data, history, True) + guide_trace = handlers.trace(_guide_from_model(model)).get_trace( + data, history, True + ) vectorized_trace = handlers.trace( - handlers.replay(model, trace=guide_trace)).get_trace(data, history, True) + handlers.replay(model, trace=guide_trace) + ).get_trace(data, history, True) else: - vectorized_trace = handlers.trace(enum_model).get_trace(data, history, True) + vectorized_trace = handlers.trace(enum_model).get_trace( + data, history, True + ) # sequential factors factors = list() @@ -308,15 +453,25 @@ def test_enumeration(model, data, var, history, use_replay): vectorized_factors = list() for i in range(history): for v in var: - vectorized_factors.append(vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) + vectorized_factors.append( + vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"] + ) for i in range(history, data.shape[-2]): for v in var: vectorized_factors.append( - vectorized_trace.nodes["{}_{}".format(v, slice(history, data.shape[-2]))]["funsor"]["log_prob"] - (**{"time": i-history}, - **{"{}_{}".format(k, slice(history-j, data.shape[-2]-j)): "{}_{}".format(k, i-j) - for j in range(history+1) for k in var}) + vectorized_trace.nodes[ + "{}_{}".format(v, slice(history, data.shape[-2])) + ]["funsor"]["log_prob"]( + **{"time": i - history}, + **{ + "{}_{}".format( + k, slice(history - j, data.shape[-2] - j) + ): "{}_{}".format(k, i - j) + for j in range(history + 1) + for k in var + } ) + ) # assert correct factors for f1, f2 in zip(factors, vectorized_factors): @@ -328,8 +483,10 @@ def test_enumeration(model, data, var, history, use_replay): expected_step = frozenset() expected_measure_vars = frozenset() for v in var[:-1]: - v_step = tuple("{}_{}".format(v, i) for i in range(history)) \ - + tuple("{}_{}".format(v, slice(j, data.shape[-2]-history+j)) for j in range(history+1)) + v_step = tuple("{}_{}".format(v, i) for i in range(history)) + tuple( + "{}_{}".format(v, slice(j, data.shape[-2] - history + j)) + for j in range(history + 
1) + ) expected_step |= frozenset({v_step}) # grab measure_vars, found only at sites that are not replayed if not use_replay: @@ -352,65 +509,103 @@ def test_enumeration(model, data, var, history, use_replay): # z[j-1] z[j] z[j+1] def model_8(weeks_data, days_data, history, vectorized): x_dim, y_dim, w_dim, z_dim = 3, 2, 2, 3 - x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex) - x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex) - y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex) - w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex) - w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex) - z_probs = pyro.param("z_probs", lambda: torch.rand(w_dim, z_dim), constraint=constraints.simplex) + x_init = pyro.param( + "x_init", lambda: torch.rand(x_dim), constraint=constraints.simplex + ) + x_trans = pyro.param( + "x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=constraints.simplex + ) + y_probs = pyro.param( + "y_probs", lambda: torch.rand(x_dim, y_dim), constraint=constraints.simplex + ) + w_init = pyro.param( + "w_init", lambda: torch.rand(w_dim), constraint=constraints.simplex + ) + w_trans = pyro.param( + "w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=constraints.simplex + ) + z_probs = pyro.param( + "z_probs", lambda: torch.rand(w_dim, z_dim), constraint=constraints.simplex + ) x_prev = None - weeks_loop = \ - pyro.vectorized_markov(name="weeks", size=len(weeks_data), dim=-1, history=history) if vectorized \ + weeks_loop = ( + pyro.vectorized_markov( + name="weeks", size=len(weeks_data), dim=-1, history=history + ) + if vectorized else pyro.markov(range(len(weeks_data)), history=history) + ) for i in weeks_loop: if isinstance(i, int) and i == 0: x_probs = x_init else: x_probs = Vindex(x_trans)[x_prev] - x_curr = pyro.sample( - "x_{}".format(i), dist.Categorical(x_probs)) - pyro.sample("y_{}".format(i), dist.Categorical(Vindex(y_probs)[x_curr]), obs=weeks_data[i]) + x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs)) + pyro.sample( + "y_{}".format(i), + dist.Categorical(Vindex(y_probs)[x_curr]), + obs=weeks_data[i], + ) x_prev = x_curr w_prev = None - days_loop = \ - pyro.vectorized_markov(name="days", size=len(days_data), dim=-1, history=history) if vectorized \ + days_loop = ( + pyro.vectorized_markov( + name="days", size=len(days_data), dim=-1, history=history + ) + if vectorized else pyro.markov(range(len(days_data)), history=history) + ) for j in days_loop: if isinstance(j, int) and j == 0: w_probs = w_init else: w_probs = Vindex(w_trans)[w_prev] - w_curr = pyro.sample( - "w_{}".format(j), dist.Categorical(w_probs)) - pyro.sample("z_{}".format(j), dist.Categorical(Vindex(z_probs)[w_curr]), obs=days_data[j]) + w_curr = pyro.sample("w_{}".format(j), dist.Categorical(w_probs)) + pyro.sample( + "z_{}".format(j), + dist.Categorical(Vindex(z_probs)[w_curr]), + obs=days_data[j], + ) w_prev = w_curr @pytest.mark.parametrize("use_replay", [True, False]) -@pytest.mark.parametrize("model,weeks_data,days_data,vars1,vars2,history", [ - (model_8, torch.ones(3), torch.zeros(9), "xy", "wz", 1), - (model_8, torch.ones(30), torch.zeros(50), "xy", "wz", 1), -]) -def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay): +@pytest.mark.parametrize( + "model,weeks_data,days_data,vars1,vars2,history", + [ + (model_8, 
torch.ones(3), torch.zeros(9), "xy", "wz", 1), + (model_8, torch.ones(30), torch.zeros(50), "xy", "wz", 1), + ], +) +def test_enumeration_multi( + model, weeks_data, days_data, vars1, vars2, history, use_replay +): pyro.clear_param_store() with pyro_backend("contrib.funsor"): with handlers.enum(): enum_model = infer.config_enumerate(model, default="parallel") # sequential factors - trace = handlers.trace(enum_model).get_trace(weeks_data, days_data, history, False) + trace = handlers.trace(enum_model).get_trace( + weeks_data, days_data, history, False + ) # vectorized trace if use_replay: - guide_trace = handlers.trace(_guide_from_model(model)).get_trace(weeks_data, days_data, history, True) + guide_trace = handlers.trace(_guide_from_model(model)).get_trace( + weeks_data, days_data, history, True + ) vectorized_trace = handlers.trace( - handlers.replay(model, trace=guide_trace)).get_trace(weeks_data, days_data, history, True) + handlers.replay(model, trace=guide_trace) + ).get_trace(weeks_data, days_data, history, True) else: - vectorized_trace = handlers.trace(enum_model).get_trace(weeks_data, days_data, history, True) + vectorized_trace = handlers.trace(enum_model).get_trace( + weeks_data, days_data, history, True + ) factors = list() # sequential weeks factors @@ -426,29 +621,47 @@ def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, # vectorized weeks factors for i in range(history): for v in vars1: - vectorized_factors.append(vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) + vectorized_factors.append( + vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"] + ) for i in range(history, len(weeks_data)): for v in vars1: vectorized_factors.append( vectorized_trace.nodes[ - "{}_{}".format(v, slice(history, len(weeks_data)))]["funsor"]["log_prob"] - (**{"weeks": i-history}, - **{"{}_{}".format(k, slice(history-j, len(weeks_data)-j)): "{}_{}".format(k, i-j) - for j in range(history+1) for k in vars1}) + "{}_{}".format(v, slice(history, len(weeks_data))) + ]["funsor"]["log_prob"]( + **{"weeks": i - history}, + **{ + "{}_{}".format( + k, slice(history - j, len(weeks_data) - j) + ): "{}_{}".format(k, i - j) + for j in range(history + 1) + for k in vars1 + } ) + ) # vectorized days factors for i in range(history): for v in vars2: - vectorized_factors.append(vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) + vectorized_factors.append( + vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"] + ) for i in range(history, len(days_data)): for v in vars2: vectorized_factors.append( vectorized_trace.nodes[ - "{}_{}".format(v, slice(history, len(days_data)))]["funsor"]["log_prob"] - (**{"days": i-history}, - **{"{}_{}".format(k, slice(history-j, len(days_data)-j)): "{}_{}".format(k, i-j) - for j in range(history+1) for k in vars2}) + "{}_{}".format(v, slice(history, len(days_data))) + ]["funsor"]["log_prob"]( + **{"days": i - history}, + **{ + "{}_{}".format( + k, slice(history - j, len(days_data) - j) + ): "{}_{}".format(k, i - j) + for j in range(history + 1) + for k in vars2 + } ) + ) # assert correct factors for f1, f2 in zip(factors, vectorized_factors): @@ -461,8 +674,10 @@ def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, # expected step: assume that all but the last var is markov expected_weeks_step = frozenset() for v in vars1[:-1]: - v_step = tuple("{}_{}".format(v, i) for i in range(history)) \ - + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j)) for j in 
range(history+1)) + v_step = tuple("{}_{}".format(v, i) for i in range(history)) + tuple( + "{}_{}".format(v, slice(j, len(weeks_data) - history + j)) + for j in range(history + 1) + ) expected_weeks_step |= frozenset({v_step}) # grab measure_vars, found only at sites that are not replayed if not use_replay: @@ -472,8 +687,10 @@ def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, # expected step: assume that all but the last var is markov expected_days_step = frozenset() for v in vars2[:-1]: - v_step = tuple("{}_{}".format(v, i) for i in range(history)) \ - + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j)) for j in range(history+1)) + v_step = tuple("{}_{}".format(v, i) for i in range(history)) + tuple( + "{}_{}".format(v, slice(j, len(days_data) - history + j)) + for j in range(history + 1) + ) expected_days_step |= frozenset({v_step}) # grab measure_vars, found only at sites that are not replayed if not use_replay: @@ -492,18 +709,21 @@ def guide_empty(data, history, vectorized): @pytest.mark.xfail(reason="funsor version drift") -@pytest.mark.parametrize("model,guide,data,history", [ - (model_0, guide_empty, torch.rand(3, 5, 4), 1), - (model_1, guide_empty, torch.rand(5, 4), 1), - (model_2, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), - (model_3, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), - (model_4, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), - (model_5, guide_empty, torch.ones((5, 4), dtype=torch.long), 2), - (model_6, guide_empty, torch.rand(5, 4), 1), - (model_6, guide_empty, torch.rand(100, 4), 1), - (model_7, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), - (model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1), -]) +@pytest.mark.parametrize( + "model,guide,data,history", + [ + (model_0, guide_empty, torch.rand(3, 5, 4), 1), + (model_1, guide_empty, torch.rand(5, 4), 1), + (model_2, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_3, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_4, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_5, guide_empty, torch.ones((5, 4), dtype=torch.long), 2), + (model_6, guide_empty, torch.rand(5, 4), 1), + (model_6, guide_empty, torch.rand(100, 4), 1), + (model_7, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1), + ], +) def test_model_enumerated_elbo(model, guide, data, history): pyro.clear_param_store() @@ -512,11 +732,15 @@ def test_model_enumerated_elbo(model, guide, data, history): model = infer.config_enumerate(model, default="parallel") elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) expected_loss = elbo.loss_and_grads(model, guide, data, history, False) - expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + expected_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True) - actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + actual_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) assert_close(actual_loss, expected_loss) for actual_grad, expected_grad in zip(actual_grads, expected_grads): @@ -528,10 +752,13 @@ def guide_empty_multi(weeks_data, days_data, history, vectorized): @pytest.mark.xfail(reason="funsor version drift") 
-@pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [ - (model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1), - (model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1), -]) +@pytest.mark.parametrize( + "model,guide,weeks_data,days_data,history", + [ + (model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1), + (model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1), + ], +) def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history): pyro.clear_param_store() @@ -539,12 +766,20 @@ def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, histor model = infer.config_enumerate(model, default="parallel") elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.loss_and_grads(model, guide, weeks_data, days_data, history, False) - expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + expected_loss = elbo.loss_and_grads( + model, guide, weeks_data, days_data, history, False + ) + expected_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.loss_and_grads(model, guide, weeks_data, days_data, history, True) - actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + actual_loss = vectorized_elbo.loss_and_grads( + model, guide, weeks_data, days_data, history, True + ) + actual_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) assert_close(actual_loss, expected_loss) for actual_grad, expected_grad in zip(actual_grads, expected_grads): @@ -553,51 +788,64 @@ def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, histor def model_10(data, history, vectorized): init_probs = torch.tensor([0.5, 0.5]) - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - emission_probs = pyro.param("emission_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) + emission_probs = pyro.param( + "emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = None - markov_loop = \ - pyro.vectorized_markov(name="time", size=len(data), history=history) if vectorized \ + markov_loop = ( + pyro.vectorized_markov(name="time", size=len(data), history=history) + if vectorized else pyro.markov(range(len(data)), history=history) + ) for i in markov_loop: probs = init_probs if x is None else transition_probs[x] x = pyro.sample("x_{}".format(i), dist.Categorical(probs)) pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i]) -@pytest.mark.parametrize("model,guide,data,history", [ - (model_0, _guide_from_model(model_0), torch.rand(3, 5, 4), 1), - (model_1, _guide_from_model(model_1), torch.rand(5, 4), 1), - (model_2, _guide_from_model(model_2), torch.ones((5, 4), dtype=torch.long), 1), - (model_3, _guide_from_model(model_3), torch.ones((5, 4), dtype=torch.long), 1), - (model_4, _guide_from_model(model_4), torch.ones((5, 4), dtype=torch.long), 1), - (model_5, _guide_from_model(model_5), torch.ones((5, 4), dtype=torch.long), 2), - (model_6, _guide_from_model(model_6), torch.rand(5, 4), 1), - (model_7, _guide_from_model(model_7), torch.ones((5, 4), 
dtype=torch.long), 1), - (model_10, _guide_from_model(model_10), torch.ones(5), 1), -]) +@pytest.mark.parametrize( + "model,guide,data,history", + [ + (model_0, _guide_from_model(model_0), torch.rand(3, 5, 4), 1), + (model_1, _guide_from_model(model_1), torch.rand(5, 4), 1), + (model_2, _guide_from_model(model_2), torch.ones((5, 4), dtype=torch.long), 1), + (model_3, _guide_from_model(model_3), torch.ones((5, 4), dtype=torch.long), 1), + (model_4, _guide_from_model(model_4), torch.ones((5, 4), dtype=torch.long), 1), + (model_5, _guide_from_model(model_5), torch.ones((5, 4), dtype=torch.long), 2), + (model_6, _guide_from_model(model_6), torch.rand(5, 4), 1), + (model_7, _guide_from_model(model_7), torch.ones((5, 4), dtype=torch.long), 1), + (model_10, _guide_from_model(model_10), torch.ones(5), 1), + ], +) def test_guide_enumerated_elbo(model, guide, data, history): pyro.clear_param_store() - with pyro_backend("contrib.funsor"), \ - pytest.raises( - NotImplementedError, - match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"): + with pyro_backend("contrib.funsor"), pytest.raises( + NotImplementedError, + match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration", + ): if history > 1: pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) expected_loss = elbo.loss_and_grads(model, guide, data, history, False) - expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + expected_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True) - actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) + actual_grads = ( + value.grad for name, value in pyro.get_param_store().named_parameters() + ) assert_close(actual_loss, expected_loss) for actual_grad, expected_grad in zip(actual_grads, expected_grads): diff --git a/tests/contrib/gp/test_conditional.py b/tests/contrib/gp/test_conditional.py index 693dbf0e14..a12721c8c2 100644 --- a/tests/contrib/gp/test_conditional.py +++ b/tests/contrib/gp/test_conditional.py @@ -11,11 +11,12 @@ from pyro.contrib.gp.util import conditional from tests.common import assert_equal -T = namedtuple("TestConditional", ["Xnew", "X", "kernel", "f_loc", "f_scale_tril", - "loc", "cov"]) +T = namedtuple( + "TestConditional", ["Xnew", "X", "kernel", "f_loc", "f_scale_tril", "loc", "cov"] +) -Xnew = torch.tensor([[2., 3.], [4., 6.]]) -X = torch.tensor([[1., 5.], [2., 1.], [3., 2.]]) +Xnew = torch.tensor([[2.0, 3.0], [4.0, 6.0]]) +X = torch.tensor([[1.0, 5.0], [2.0, 1.0], [3.0, 2.0]]) kernel = Matern52(input_dim=2) Kff = kernel(X) + torch.eye(3) * 1e-6 Lff = torch.linalg.cholesky(Kff) @@ -25,34 +26,29 @@ f_cov = f_scale_tril.matmul(f_scale_tril.t()) TEST_CASES = [ + T(Xnew, X, kernel, torch.zeros(3), Lff, torch.zeros(2), None), + T(Xnew, X, kernel, torch.zeros(3), None, torch.zeros(2), None), + T(Xnew, X, kernel, f_loc, Lff, None, kernel(Xnew)), + T(X, X, kernel, f_loc, f_scale_tril, f_loc, f_cov), + T(X, X, kernel, f_loc, None, f_loc, torch.zeros(3, 3)), T( - Xnew, X, kernel, torch.zeros(3), Lff, torch.zeros(2), None - ), - T( - Xnew, X, kernel, torch.zeros(3), None, torch.zeros(2), None - ), - T( - Xnew, X, kernel, f_loc, Lff, None, kernel(Xnew) - ), - T( - X, X, kernel, f_loc, f_scale_tril, f_loc, f_cov - ), - T( - X, X, 
kernel, f_loc, None, f_loc, torch.zeros(3, 3) - ), - T( - Xnew, X, WhiteNoise(input_dim=2), f_loc, f_scale_tril, torch.zeros(2), torch.eye(2) - ), - T( - Xnew, X, WhiteNoise(input_dim=2), f_loc, None, torch.zeros(2), torch.eye(2) + Xnew, + X, + WhiteNoise(input_dim=2), + f_loc, + f_scale_tril, + torch.zeros(2), + torch.eye(2), ), + T(Xnew, X, WhiteNoise(input_dim=2), f_loc, None, torch.zeros(2), torch.eye(2)), ] TEST_IDS = [str(i) for i in range(len(TEST_CASES))] -@pytest.mark.parametrize("Xnew, X, kernel, f_loc, f_scale_tril, loc, cov", - TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "Xnew, X, kernel, f_loc, f_scale_tril, loc, cov", TEST_CASES, ids=TEST_IDS +) def test_conditional(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov): loc0, cov0 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True) loc1, var1 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=False) @@ -61,26 +57,31 @@ def test_conditional(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov): assert_equal(loc0, loc) assert_equal(loc1, loc) n = cov0.shape[-1] - var0 = torch.stack([mat.diag() for mat in cov0.view(-1, n, n)]).reshape(cov0.shape[:-1]) + var0 = torch.stack([mat.diag() for mat in cov0.view(-1, n, n)]).reshape( + cov0.shape[:-1] + ) assert_equal(var0, var1) if cov is not None: assert_equal(cov0, cov) -@pytest.mark.parametrize("Xnew, X, kernel, f_loc, f_scale_tril, loc, cov", - TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "Xnew, X, kernel, f_loc, f_scale_tril, loc, cov", TEST_CASES, ids=TEST_IDS +) def test_conditional_whiten(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov): if f_scale_tril is None: return - loc0, cov0 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True, - whiten=False) + loc0, cov0 = conditional( + Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True, whiten=False + ) Kff = kernel(X) + torch.eye(3) * 1e-6 Lff = torch.linalg.cholesky(Kff) whiten_f_loc = Lff.inverse().matmul(f_loc) whiten_f_scale_tril = Lff.inverse().matmul(f_scale_tril) - loc1, cov1 = conditional(Xnew, X, kernel, whiten_f_loc, whiten_f_scale_tril, - full_cov=True, whiten=True) + loc1, cov1 = conditional( + Xnew, X, kernel, whiten_f_loc, whiten_f_scale_tril, full_cov=True, whiten=True + ) assert_equal(loc0, loc1) assert_equal(cov0, cov1) diff --git a/tests/contrib/gp/test_kernels.py b/tests/contrib/gp/test_kernels.py index cc6363dd3b..5a9d47c4da 100644 --- a/tests/contrib/gp/test_kernels.py +++ b/tests/contrib/gp/test_kernels.py @@ -36,89 +36,52 @@ Z = torch.tensor([[4.0, 5.0, 6.0], [3.0, 1.0, 7.0], [3.0, 1.0, 2.0]]) TEST_CASES = [ - T( - Constant(3, variance), - X=X, Z=Z, K_sum=18 - ), + T(Constant(3, variance), X=X, Z=Z, K_sum=18), T( Brownian(1, variance), # only work on 1D input - X=X[:, 0], Z=Z[:, 0], K_sum=27 - ), - T( - Cosine(3, variance, lengthscale), - X=X, Z=Z, K_sum=-0.193233 - ), - T( - Linear(3, variance), - X=X, Z=Z, K_sum=291 - ), - T( - Exponential(3, variance, lengthscale), - X=X, Z=Z, K_sum=2.685679 - ), - T( - Matern32(3, variance, lengthscale), - X=X, Z=Z, K_sum=3.229314 - ), - T( - Matern52(3, variance, lengthscale), - X=X, Z=Z, K_sum=3.391847 - ), - T( - Periodic(3, variance, lengthscale, period=torch.ones(1)), - X=X, Z=Z, K_sum=18 - ), - T( - Polynomial(3, variance, degree=2), - X=X, Z=Z, K_sum=7017 - ), + X=X[:, 0], + Z=Z[:, 0], + K_sum=27, + ), + T(Cosine(3, variance, lengthscale), X=X, Z=Z, K_sum=-0.193233), + T(Linear(3, variance), X=X, Z=Z, K_sum=291), + T(Exponential(3, variance, lengthscale), X=X, Z=Z, K_sum=2.685679), + T(Matern32(3, variance, lengthscale), 
X=X, Z=Z, K_sum=3.229314), + T(Matern52(3, variance, lengthscale), X=X, Z=Z, K_sum=3.391847), + T(Periodic(3, variance, lengthscale, period=torch.ones(1)), X=X, Z=Z, K_sum=18), + T(Polynomial(3, variance, degree=2), X=X, Z=Z, K_sum=7017), T( RationalQuadratic(3, variance, lengthscale, scale_mixture=torch.ones(1)), - X=X, Z=Z, K_sum=5.684670 - ), - T( - RBF(3, variance, lengthscale), - X=X, Z=Z, K_sum=3.681117 - ), - T( - WhiteNoise(3, variance, lengthscale), - X=X, Z=Z, K_sum=0 - ), - T( - WhiteNoise(3, variance, lengthscale), - X=X, Z=None, K_sum=6 + X=X, + Z=Z, + K_sum=5.684670, ), + T(RBF(3, variance, lengthscale), X=X, Z=Z, K_sum=3.681117), + T(WhiteNoise(3, variance, lengthscale), X=X, Z=Z, K_sum=0), + T(WhiteNoise(3, variance, lengthscale), X=X, Z=None, K_sum=6), T( Coregionalize(3, components=torch.eye(3, 3)), - X=torch.tensor([[1., 0., 0.], - [0.5, 0., 0.5]]), - Z=torch.tensor([[1., 0., 0.], - [0., 1., 0.]]), + X=torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.0, 0.5]]), + Z=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), K_sum=2.25, ), T( Coregionalize(3, rank=2), - X=torch.tensor([[1., 0., 0.], - [0.5, 0., 0.5]]), - Z=torch.tensor([[1., 0., 0.], - [0., 1., 0.]]), + X=torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.0, 0.5]]), + Z=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), K_sum=None, # kernel is randomly initialized ), T( Coregionalize(3), - X=torch.tensor([[1., 0., 0.], - [0.5, 0., 0.5]]), - Z=torch.tensor([[1., 0., 0.], - [0., 1., 0.]]), + X=torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.0, 0.5]]), + Z=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), K_sum=None, # kernel is randomly initialized ), T( Coregionalize(3, rank=2, diagonal=0.01 * torch.ones(3)), - X=torch.tensor([[1., 0., 0.], - [0.5, 0., 0.5]]), - Z=torch.tensor([[1., 0., 0.], - [0., 1., 0.]]), + X=torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.0, 0.5]]), + Z=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), K_sum=None, # kernel is randomly initialized ), ] @@ -133,7 +96,9 @@ def test_kernel_forward(kernel, X, Z, K_sum): if K_sum is not None: assert_equal(K.sum().item(), K_sum) assert_equal(kernel(X).diag(), kernel(X, diag=True)) - if not isinstance(kernel, WhiteNoise): # WhiteNoise avoids computing a delta function by assuming X != Z + if not isinstance( + kernel, WhiteNoise + ): # WhiteNoise avoids computing a delta function by assuming X != Z assert_equal(kernel(X), kernel(X, X)) if Z is not None: assert_equal(kernel(X, Z), kernel(Z, X).t()) @@ -141,7 +106,7 @@ def test_kernel_forward(kernel, X, Z, K_sum): def test_combination(): k0 = TEST_CASES[0][0] - k5 = TEST_CASES[5][0] # TEST_CASES[1] is Brownian, only work for 1D + k5 = TEST_CASES[5][0] # TEST_CASES[1] is Brownian, only work for 1D k2 = TEST_CASES[2][0] k3 = TEST_CASES[3][0] k4 = TEST_CASES[4][0] @@ -172,7 +137,7 @@ def vscaling_fn(x): return x.sum(dim=1) def iwarping_fn(x): - return x**2 + return x ** 2 owarping_coef = [2, 0, 1, 3, 0] diff --git a/tests/contrib/gp/test_likelihoods.py b/tests/contrib/gp/test_likelihoods.py index c63c1ce9a5..b193bdadbb 100644 --- a/tests/contrib/gp/test_likelihoods.py +++ b/tests/contrib/gp/test_likelihoods.py @@ -14,7 +14,7 @@ T = namedtuple("TestGPLikelihood", ["model_class", "X", "y", "kernel", "likelihood"]) X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0], [3.0, 4.0, 6.0]]) -kernel = RBF(input_dim=3, variance=torch.tensor(1.), lengthscale=torch.tensor(3.)) +kernel = RBF(input_dim=3, variance=torch.tensor(1.0), lengthscale=torch.tensor(3.0)) noise = torch.tensor(1e-6) y_binary1D = torch.tensor([0.0, 1.0, 0.0]) y_binary2D = torch.tensor([[0.0, 
1.0, 1.0], [1.0, 0.0, 1.0]]) @@ -27,62 +27,31 @@ multiclass_likelihood = MultiClass(num_classes=3) TEST_CASES = [ - T( - VariationalGP, - X, y_binary1D, kernel, binary_likelihood - ), - T( - VariationalGP, - X, y_binary2D, kernel, binary_likelihood - ), - T( - VariationalGP, - X, y_multiclass1D, kernel, multiclass_likelihood - ), - T( - VariationalGP, - X, y_multiclass2D, kernel, multiclass_likelihood - ), - T( - VariationalGP, - X, y_count1D, kernel, poisson_likelihood - ), - T( - VariationalGP, - X, y_count2D, kernel, poisson_likelihood - ), - T( - VariationalSparseGP, - X, y_binary1D, kernel, binary_likelihood - ), - T( - VariationalSparseGP, - X, y_binary2D, kernel, binary_likelihood - ), - T( - VariationalSparseGP, - X, y_multiclass1D, kernel, multiclass_likelihood - ), - T( - VariationalSparseGP, - X, y_multiclass2D, kernel, multiclass_likelihood - ), - T( - VariationalSparseGP, - X, y_count1D, kernel, poisson_likelihood - ), - T( - VariationalSparseGP, - X, y_count2D, kernel, poisson_likelihood - ), + T(VariationalGP, X, y_binary1D, kernel, binary_likelihood), + T(VariationalGP, X, y_binary2D, kernel, binary_likelihood), + T(VariationalGP, X, y_multiclass1D, kernel, multiclass_likelihood), + T(VariationalGP, X, y_multiclass2D, kernel, multiclass_likelihood), + T(VariationalGP, X, y_count1D, kernel, poisson_likelihood), + T(VariationalGP, X, y_count2D, kernel, poisson_likelihood), + T(VariationalSparseGP, X, y_binary1D, kernel, binary_likelihood), + T(VariationalSparseGP, X, y_binary2D, kernel, binary_likelihood), + T(VariationalSparseGP, X, y_multiclass1D, kernel, multiclass_likelihood), + T(VariationalSparseGP, X, y_multiclass2D, kernel, multiclass_likelihood), + T(VariationalSparseGP, X, y_count1D, kernel, poisson_likelihood), + T(VariationalSparseGP, X, y_count2D, kernel, poisson_likelihood), ] -TEST_IDS = ["_".join([t[0].__name__, t[4].__class__.__name__.split(".")[-1], - str(t[2].dim()) + "D"]) - for t in TEST_CASES] +TEST_IDS = [ + "_".join( + [t[0].__name__, t[4].__class__.__name__.split(".")[-1], str(t[2].dim()) + "D"] + ) + for t in TEST_CASES +] -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS +) def test_inference(model_class, X, y, kernel, likelihood): if isinstance(likelihood, MultiClass): latent_shape = y.shape[:-1] + (likelihood.num_classes,) @@ -96,7 +65,9 @@ def test_inference(model_class, X, y, kernel, likelihood): train(gp, num_steps=1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS +) def test_inference_with_empty_latent_shape(model_class, X, y, kernel, likelihood): if isinstance(likelihood, MultiClass): latent_shape = torch.Size([likelihood.num_classes]) @@ -110,7 +81,9 @@ def test_inference_with_empty_latent_shape(model_class, X, y, kernel, likelihood train(gp, num_steps=1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS +) def test_forward(model_class, X, y, kernel, likelihood): if isinstance(likelihood, MultiClass): latent_shape = y.shape[:-1] + (likelihood.num_classes,) @@ -129,7 +102,9 @@ def test_forward(model_class, X, y, kernel, likelihood): assert ynew.shape == y.shape[:-1] + (Xnew.shape[0],) -@pytest.mark.parametrize("model_class, X, y, 
kernel, likelihood", TEST_CASES, ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", TEST_CASES, ids=TEST_IDS +) def test_forward_with_empty_latent_shape(model_class, X, y, kernel, likelihood): if isinstance(likelihood, MultiClass): latent_shape = torch.Size([likelihood.num_classes]) diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index 2aa38ea8f2..1a93c72db8 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -27,14 +27,14 @@ T = namedtuple("TestGPModel", ["model_class", "X", "y", "kernel", "likelihood"]) -X = torch.tensor([[1., 5., 3.], [4., 3., 7.]]) -y1D = torch.tensor([2., 1.]) -y2D = torch.tensor([[1., 2.], [3., 3.], [1., 4.], [-1., 1.]]) +X = torch.tensor([[1.0, 5.0, 3.0], [4.0, 3.0, 7.0]]) +y1D = torch.tensor([2.0, 1.0]) +y2D = torch.tensor([[1.0, 2.0], [3.0, 3.0], [1.0, 4.0], [-1.0, 1.0]]) noise = torch.tensor(1e-7) def _kernel(): - return RBF(input_dim=3, variance=torch.tensor(3.), lengthscale=torch.tensor(2.)) + return RBF(input_dim=3, variance=torch.tensor(3.0), lengthscale=torch.tensor(2.0)) def _likelihood(): @@ -43,48 +43,25 @@ def _likelihood(): def _TEST_CASES(): TEST_CASES = [ - T( - GPRegression, - X, y1D, _kernel(), noise - ), - T( - GPRegression, - X, y2D, _kernel(), noise - ), - T( - SparseGPRegression, - X, y1D, _kernel(), noise - ), - T( - SparseGPRegression, - X, y2D, _kernel(), noise - ), - T( - VariationalGP, - X, y1D, _kernel(), _likelihood() - ), - T( - VariationalGP, - X, y2D, _kernel(), _likelihood() - ), - T( - VariationalSparseGP, - X, y1D, _kernel(), _likelihood() - ), - T( - VariationalSparseGP, - X, y2D, _kernel(), _likelihood() - ), + T(GPRegression, X, y1D, _kernel(), noise), + T(GPRegression, X, y2D, _kernel(), noise), + T(SparseGPRegression, X, y1D, _kernel(), noise), + T(SparseGPRegression, X, y2D, _kernel(), noise), + T(VariationalGP, X, y1D, _kernel(), _likelihood()), + T(VariationalGP, X, y2D, _kernel(), _likelihood()), + T(VariationalSparseGP, X, y1D, _kernel(), _likelihood()), + T(VariationalSparseGP, X, y2D, _kernel(), _likelihood()), ] return TEST_CASES -TEST_IDS = [t[0].__name__ + "_y{}D".format(str(t[2].dim())) - for t in _TEST_CASES()] +TEST_IDS = [t[0].__name__ + "_y{}D".format(str(t[2].dim())) for t in _TEST_CASES()] -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_model(model_class, X, y, kernel, likelihood): if model_class is SparseGPRegression or model_class is VariationalSparseGP: gp = model_class(X, None, kernel, X, likelihood) @@ -100,7 +77,9 @@ def test_model(model_class, X, y, kernel, likelihood): assert_equal(var, kernel(X).diag()) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_forward(model_class, X, y, kernel, likelihood): if model_class is SparseGPRegression or model_class is VariationalSparseGP: gp = model_class(X, y, kernel, X, likelihood) @@ -120,7 +99,9 @@ def test_forward(model_class, X, y, kernel, likelihood): assert cov0.shape[-1] == Xnew.shape[0] assert_equal(loc0, loc1) n = Xnew.shape[0] - cov0_diag = torch.stack([mat.diag() for mat in cov0.view(-1, n, n)]).reshape(var1.shape) + cov0_diag = torch.stack([mat.diag() for mat in cov0.view(-1, n, n)]).reshape( + var1.shape + ) assert_equal(cov0_diag, var1) # test 
trivial forward: Xnew = X @@ -141,13 +122,15 @@ def test_forward(model_class, X, y, kernel, likelihood): assert_equal(cov_diff.norm().item(), 0) # test noise kernel forward: kernel = WhiteNoise - gp.kernel = WhiteNoise(input_dim=3, variance=torch.tensor(10.)) + gp.kernel = WhiteNoise(input_dim=3, variance=torch.tensor(10.0)) loc, cov = gp(X, full_cov=True) assert_equal(loc.norm().item(), 0) assert_equal(cov, torch.eye(cov.shape[-1]).expand(cov.shape) * 10) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_forward_with_empty_latent_shape(model_class, X, y, kernel, likelihood): # regression models don't use latent_shape, no need for test if model_class is GPRegression or model_class is SparseGPRegression: @@ -171,7 +154,9 @@ def test_forward_with_empty_latent_shape(model_class, X, y, kernel, likelihood): assert_equal(cov0.diag(), var1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) @pytest.mark.init(rng_seed=0) def test_inference(model_class, X, y, kernel, likelihood): # skip variational GP models because variance/lengthscale highly @@ -199,17 +184,20 @@ def test_inference(model_class, X, y, kernel, likelihood): @pytest.mark.init(rng_seed=0) def test_inference_sgpr(): N = 1000 - X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() - y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() + X = dist.Uniform(torch.zeros(N), torch.ones(N) * 5).sample() + y = ( + 0.5 * torch.sin(3 * X) + + dist.Normal(torch.zeros(N), torch.ones(N) * 0.5).sample() + ) kernel = RBF(input_dim=1) - Xu = torch.arange(0., 5.5, 0.5) + Xu = torch.arange(0.0, 5.5, 0.5) sgpr = SparseGPRegression(X, y, kernel, Xu) train(sgpr) - Xnew = torch.arange(0., 5.05, 0.05) + Xnew = torch.arange(0.0, 5.05, 0.05) loc, var = sgpr(Xnew, full_cov=False) - target = 0.5 * torch.sin(3*Xnew) + target = 0.5 * torch.sin(3 * Xnew) assert_equal((loc - target).abs().mean().item(), 0, prec=0.07) @@ -217,18 +205,21 @@ def test_inference_sgpr(): @pytest.mark.init(rng_seed=0) def test_inference_vsgp(): N = 1000 - X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() - y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() + X = dist.Uniform(torch.zeros(N), torch.ones(N) * 5).sample() + y = ( + 0.5 * torch.sin(3 * X) + + dist.Normal(torch.zeros(N), torch.ones(N) * 0.5).sample() + ) kernel = RBF(input_dim=1) - Xu = torch.arange(0., 5.5, 0.5) + Xu = torch.arange(0.0, 5.5, 0.5) vsgp = VariationalSparseGP(X, y, kernel, Xu, Gaussian()) optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.03) train(vsgp, optimizer) - Xnew = torch.arange(0., 5.05, 0.05) + Xnew = torch.arange(0.0, 5.05, 0.05) loc, var = vsgp(Xnew, full_cov=False) - target = 0.5 * torch.sin(3*Xnew) + target = 0.5 * torch.sin(3 * Xnew) assert_equal((loc - target).abs().mean().item(), 0, prec=0.06) @@ -236,22 +227,27 @@ def test_inference_vsgp(): @pytest.mark.init(rng_seed=0) def test_inference_whiten_vsgp(): N = 1000 - X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() - y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() + X = dist.Uniform(torch.zeros(N), torch.ones(N) * 5).sample() + y = ( + 0.5 * torch.sin(3 * X) + + dist.Normal(torch.zeros(N), torch.ones(N) * 0.5).sample() + ) kernel = 
RBF(input_dim=1) - Xu = torch.arange(0., 5.5, 0.5) + Xu = torch.arange(0.0, 5.5, 0.5) vsgp = VariationalSparseGP(X, y, kernel, Xu, Gaussian(), whiten=True) train(vsgp) - Xnew = torch.arange(0., 5.05, 0.05) + Xnew = torch.arange(0.0, 5.05, 0.05) loc, var = vsgp(Xnew, full_cov=False) - target = 0.5 * torch.sin(3*Xnew) + target = 0.5 * torch.sin(3 * Xnew) assert_equal((loc - target).abs().mean().item(), 0, prec=0.07) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_inference_with_empty_latent_shape(model_class, X, y, kernel, likelihood): # regression models don't use latent_shape (default=torch.Size([])) if model_class is GPRegression or model_class is SparseGPRegression: @@ -259,12 +255,16 @@ def test_inference_with_empty_latent_shape(model_class, X, y, kernel, likelihood elif model_class is VariationalGP: gp = model_class(X, y, kernel, likelihood, latent_shape=torch.Size([])) else: # model_class is SparseVariationalGP - gp = model_class(X, y, kernel, X.clone(), likelihood, latent_shape=torch.Size([])) + gp = model_class( + X, y, kernel, X.clone(), likelihood, latent_shape=torch.Size([]) + ) train(gp, num_steps=1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_inference_with_whiten(model_class, X, y, kernel, likelihood): # regression models don't use whiten if model_class is GPRegression or model_class is SparseGPRegression: @@ -277,7 +277,9 @@ def test_inference_with_whiten(model_class, X, y, kernel, likelihood): train(gp, num_steps=1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_hmc(model_class, X, y, kernel, likelihood): if model_class is SparseGPRegression or model_class is VariationalSparseGP: gp = model_class(X, y, kernel, X.clone(), likelihood) @@ -298,11 +300,15 @@ def test_hmc(model_class, X, y, kernel, likelihood): def test_inference_deepGP(): - gp1 = GPRegression(X, None, RBF(input_dim=3, variance=torch.tensor(3.), - lengthscale=torch.tensor(2.))) + gp1 = GPRegression( + X, + None, + RBF(input_dim=3, variance=torch.tensor(3.0), lengthscale=torch.tensor(2.0)), + ) Z, _ = gp1.model() - gp2 = VariationalSparseGP(Z, y2D, Matern32(input_dim=3), Z.clone(), - Gaussian(torch.tensor(1e-6))) + gp2 = VariationalSparseGP( + Z, y2D, Matern32(input_dim=3), Z.clone(), Gaussian(torch.tensor(1e-6)) + ) class DeepGP(torch.nn.Module): def __init__(self, gp1, gp2): @@ -323,7 +329,9 @@ def guide(self): train(deepgp, num_steps=1) -@pytest.mark.parametrize("model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS) +@pytest.mark.parametrize( + "model_class, X, y, kernel, likelihood", _TEST_CASES(), ids=TEST_IDS +) def test_gplvm(model_class, X, y, kernel, likelihood): if model_class is SparseGPRegression or model_class is VariationalSparseGP: gp = model_class(X, y, kernel, X.clone(), likelihood) @@ -351,8 +359,8 @@ def f(x): class Trend(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.nn.Parameter(torch.tensor(0.)) - self.b = torch.nn.Parameter(torch.tensor(1.)) + self.a = torch.nn.Parameter(torch.tensor(0.0)) + self.b = torch.nn.Parameter(torch.tensor(1.0)) def forward(self, x): return self.a * x + 
self.b @@ -399,7 +407,9 @@ def test_mean_function_SGPR_DTC(): def test_mean_function_SGPR_FITC(): X, y, Xnew, ynew, kernel, mean_fn = _pre_test_mean_function() Xu = X[::20].clone() - gpmodule = SparseGPRegression(X, y, kernel, Xu, mean_function=mean_fn, approx="FITC") + gpmodule = SparseGPRegression( + X, y, kernel, Xu, mean_function=mean_fn, approx="FITC" + ) train(gpmodule) _post_test_mean_function(gpmodule, Xnew, ynew) @@ -415,8 +425,9 @@ def test_mean_function_VGP(): def test_mean_function_VGP_whiten(): X, y, Xnew, ynew, kernel, mean_fn = _pre_test_mean_function() likelihood = Gaussian() - gpmodule = VariationalGP(X, y, kernel, likelihood, mean_function=mean_fn, - whiten=True) + gpmodule = VariationalGP( + X, y, kernel, likelihood, mean_function=mean_fn, whiten=True + ) optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.1) train(gpmodule, optimizer) _post_test_mean_function(gpmodule, Xnew, ynew) @@ -436,8 +447,9 @@ def test_mean_function_VSGP_whiten(): X, y, Xnew, ynew, kernel, mean_fn = _pre_test_mean_function() Xu = X[::20].clone() likelihood = Gaussian() - gpmodule = VariationalSparseGP(X, y, kernel, Xu, likelihood, mean_function=mean_fn, - whiten=True) + gpmodule = VariationalSparseGP( + X, y, kernel, Xu, likelihood, mean_function=mean_fn, whiten=True + ) optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.1) train(gpmodule, optimizer) _post_test_mean_function(gpmodule, Xnew, ynew) diff --git a/tests/contrib/gp/test_parameterized.py b/tests/contrib/gp/test_parameterized.py index 45728832df..f7ab5d3c33 100644 --- a/tests/contrib/gp/test_parameterized.py +++ b/tests/contrib/gp/test_parameterized.py @@ -17,7 +17,7 @@ class Linear(Parameterized): def __init__(self): super().__init__() self._pyro_name = "Linear" - self.a = PyroParam(torch.tensor(1.), constraints.positive) + self.a = PyroParam(torch.tensor(1.0), constraints.positive) self.b = PyroSample(dist.Normal(0, 1)) self.c = PyroSample(dist.Normal(0, 1)) self.d = PyroSample(dist.Normal(0, 4).expand([1]).to_event()) @@ -26,7 +26,16 @@ def __init__(self): self.g = PyroSample(dist.Exponential(1)) def forward(self, x): - return self.a * x + self.b + self.c + self.d + self.e + self.f + self.g + self.e + return ( + self.a * x + + self.b + + self.c + + self.d + + self.e + + self.f + + self.g + + self.e + ) linear = Linear() linear.autoguide("c", dist.Normal) @@ -34,9 +43,16 @@ def forward(self, x): linear.autoguide("e", dist.Normal) assert set(dict(linear.named_parameters()).keys()) == { - "a_unconstrained", "b_map", "c_loc", "c_scale_unconstrained", - "d_loc", "d_scale_tril_unconstrained", - "e_loc", "e_scale_unconstrained", "f_map", "g_map_unconstrained" + "a_unconstrained", + "b_map", + "c_loc", + "c_scale_unconstrained", + "d_loc", + "d_scale_tril_unconstrained", + "e_loc", + "e_scale_unconstrained", + "f_map", + "g_map_unconstrained", } def model(x): @@ -47,8 +63,8 @@ def guide(x): linear.mode = "guide" return linear(x) - model_trace = pyro.poutine.trace(model).get_trace(torch.tensor(5.)) - guide_trace = pyro.poutine.trace(guide).get_trace(torch.tensor(5.)) + model_trace = pyro.poutine.trace(model).get_trace(torch.tensor(5.0)) + guide_trace = pyro.poutine.trace(guide).get_trace(torch.tensor(5.0)) for p in ["b", "c", "d"]: assert "Linear.{}".format(p) in model_trace.nodes assert "Linear.{}".format(p) in guide_trace.nodes @@ -80,18 +96,18 @@ def __init__(self, linear1, linear2, a): def forward(self, x): return self.linear1(x) * x + self.linear2(self.a) - linear1 = Linear(torch.tensor(1.)) + linear1 = 
Linear(torch.tensor(1.0)) linear1.a = PyroSample(dist.Normal(0, 1)) - linear2 = Linear(torch.tensor(1.)) + linear2 = Linear(torch.tensor(1.0)) linear2.a = PyroSample(dist.Normal(0, 1)) - q = Quadratic(linear1, linear2, torch.tensor(2.)) + q = Quadratic(linear1, linear2, torch.tensor(2.0)) q.a = PyroSample(dist.Cauchy(0, 1)) def model(x): q.set_mode("model") return q(x) - trace = pyro.poutine.trace(model).get_trace(torch.tensor(5.)) + trace = pyro.poutine.trace(model).get_trace(torch.tensor(5.0)) assert "Quadratic.a" in trace.nodes assert "Quadratic.linear1.a" in trace.nodes assert "Quadratic.linear2.a" in trace.nodes @@ -106,10 +122,10 @@ def __init__(self, a): def forward(self, x): return self.a * x - target_a = torch.tensor(2.) + target_a = torch.tensor(2.0) x_train = torch.rand(100) y_train = target_a * x_train + torch.rand(100) * 0.001 - linear = Linear(torch.tensor(1.)) + linear = Linear(torch.tensor(1.0)) linear.a = PyroSample(dist.Normal(0, 10)) linear.autoguide("a", dist.Normal) diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py index 94ba2fa02a..e1e570c72c 100644 --- a/tests/contrib/mue/test_dataloaders.py +++ b/tests/contrib/mue/test_dataloaders.py @@ -7,33 +7,32 @@ from pyro.contrib.mue.dataloaders import BiosequenceDataset, alphabets -@pytest.mark.parametrize('source_type', ['list', 'fasta']) -@pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', 'ATC']) -@pytest.mark.parametrize('include_stop', [False, True]) +@pytest.mark.parametrize("source_type", ["list", "fasta"]) +@pytest.mark.parametrize("alphabet", ["amino-acid", "dna", "ATC"]) +@pytest.mark.parametrize("include_stop", [False, True]) def test_biosequencedataset(source_type, alphabet, include_stop): # Define dataset. - seqs = ['AATC', 'CA', 'T'] + seqs = ["AATC", "CA", "T"] # Encode dataset, alternate approach. if alphabet in alphabets: - alphabet_list = list(alphabets[alphabet]) + include_stop*['*'] + alphabet_list = list(alphabets[alphabet]) + include_stop * ["*"] else: - alphabet_list = list(alphabet) + include_stop*['*'] + alphabet_list = list(alphabet) + include_stop * ["*"] L_data_check = [len(seq) + include_stop for seq in seqs] max_length_check = max(L_data_check) data_size_check = len(seqs) - seq_data_check = torch.zeros([len(seqs), max_length_check, - len(alphabet_list)]) + seq_data_check = torch.zeros([len(seqs), max_length_check, len(alphabet_list)]) for i in range(len(seqs)): - for j, s in enumerate(seqs[i] + include_stop*'*'): + for j, s in enumerate(seqs[i] + include_stop * "*"): seq_data_check[i, j, list(alphabet_list).index(s)] = 1 # Setup data source. - if source_type == 'fasta': + if source_type == "fasta": # Save as external file. - source = 'test_seqs.fasta' - with open(source, 'w') as fw: + source = "test_seqs.fasta" + with open(source, "w") as fw: text = """>one AAT C @@ -43,27 +42,31 @@ def test_biosequencedataset(source_type, alphabet, include_stop): T """ fw.write(text) - elif source_type == 'list': + elif source_type == "list": source = seqs # Load dataset. - dataset = BiosequenceDataset(source, source_type, alphabet, - include_stop=include_stop) + dataset = BiosequenceDataset( + source, source_type, alphabet, include_stop=include_stop + ) # Check. 
- assert torch.allclose(dataset.L_data, - torch.tensor(L_data_check, dtype=torch.float64)) + assert torch.allclose( + dataset.L_data, torch.tensor(L_data_check, dtype=torch.float64) + ) assert dataset.max_length == max_length_check assert len(dataset) == data_size_check assert dataset.data_size == data_size_check assert dataset.alphabet_length == len(alphabet_list) assert torch.allclose(dataset.seq_data, seq_data_check) ind = torch.tensor([0, 2]) - assert torch.allclose(dataset[ind][0], - torch.cat([seq_data_check[0, None, :, :], - seq_data_check[2, None, :, :]])) - assert torch.allclose(dataset[ind][1], torch.tensor([4. + include_stop, - 1. + include_stop])) + assert torch.allclose( + dataset[ind][0], + torch.cat([seq_data_check[0, None, :, :], seq_data_check[2, None, :, :]]), + ) + assert torch.allclose( + dataset[ind][1], torch.tensor([4.0 + include_stop, 1.0 + include_stop]) + ) dataload = torch.utils.data.DataLoader(dataset, batch_size=2) for seq_data, L_data in dataload: assert seq_data.shape[0] == L_data.shape[0] diff --git a/tests/contrib/mue/test_missingdatahmm.py b/tests/contrib/mue/test_missingdatahmm.py index ee19f4b31d..ba7a38e6e7 100644 --- a/tests/contrib/mue/test_missingdatahmm.py +++ b/tests/contrib/mue/test_missingdatahmm.py @@ -14,15 +14,11 @@ def test_hmm_log_prob(): a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]]) e = torch.tensor([[0.99, 0.01], [0.01, 0.99], [0.5, 0.5]]) - x = torch.tensor([[0., 1.], - [1., 0.], - [0., 1.], - [0., 1.], - [1., 0.], - [0., 0.]]) - - hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), - torch.log(e)) + x = torch.tensor( + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ) + + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e)) lp = hmm_distr.log_prob(x) f = a0 * e[:, 1] @@ -35,14 +31,15 @@ def test_hmm_log_prob(): assert torch.allclose(lp, chk_lp) # Batch values. - x = torch.cat([ - x[None, :, :], - torch.tensor([[1., 0.], - [1., 0.], - [1., 0.], - [0., 0.], - [0., 0.], - [0., 0.]])[None, :, :]], dim=0) + x = torch.cat( + [ + x[None, :, :], + torch.tensor( + [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]] + )[None, :, :], + ], + dim=0, + ) lp = hmm_distr.log_prob(x) f = a0 * e[:, 0] @@ -54,16 +51,23 @@ def test_hmm_log_prob(): # Batch both parameters and values. 
a0 = torch.cat([a0[None, :], torch.tensor([0.2, 0.7, 0.1])[None, :]]) - a = torch.cat([ - a[None, :, :], - torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]] - )[None, :, :]], dim=0) - e = torch.cat([ - e[None, :, :], - torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :]], - dim=0) - hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), - torch.log(e)) + a = torch.cat( + [ + a[None, :, :], + torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]])[ + None, :, : + ], + ], + dim=0, + ) + e = torch.cat( + [ + e[None, :, :], + torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :], + ], + dim=0, + ) + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e)) lp = hmm_distr.log_prob(x) f = a0[1, :] * e[1, :, 0] @@ -74,10 +78,10 @@ def test_hmm_log_prob(): assert torch.allclose(lp, chk_lp) -@pytest.mark.parametrize('batch_initial', [False, True]) -@pytest.mark.parametrize('batch_transition', [False, True]) -@pytest.mark.parametrize('batch_observation', [False, True]) -@pytest.mark.parametrize('batch_data', [False, True]) +@pytest.mark.parametrize("batch_initial", [False, True]) +@pytest.mark.parametrize("batch_transition", [False, True]) +@pytest.mark.parametrize("batch_observation", [False, True]) +@pytest.mark.parametrize("batch_data", [False, True]) def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): # Dimensions. @@ -85,87 +89,90 @@ def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): state_dim, observation_dim, num_steps = 4, 5, 6 # Model initialization. - initial_logits = torch.randn([batch_size]*batch_initial + [state_dim]) - initial_logits = (initial_logits - - initial_logits.logsumexp(-1, True)) - transition_logits = torch.randn([batch_size]*batch_transition - + [state_dim, state_dim]) - transition_logits = (transition_logits - - transition_logits.logsumexp(-1, True)) - observation_logits = torch.randn([batch_size]*batch_observation - + [state_dim, observation_dim]) - observation_logits = (observation_logits - - observation_logits.logsumexp(-1, True)) - - hmm = MissingDataDiscreteHMM(initial_logits, transition_logits, - observation_logits) + initial_logits = torch.randn([batch_size] * batch_initial + [state_dim]) + initial_logits = initial_logits - initial_logits.logsumexp(-1, True) + transition_logits = torch.randn( + [batch_size] * batch_transition + [state_dim, state_dim] + ) + transition_logits = transition_logits - transition_logits.logsumexp(-1, True) + observation_logits = torch.randn( + [batch_size] * batch_observation + [state_dim, observation_dim] + ) + observation_logits = observation_logits - observation_logits.logsumexp(-1, True) + + hmm = MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits) # Random observations. - value = (torch.randint(observation_dim, - [batch_size]*batch_data + [num_steps]).unsqueeze(-1) - == torch.arange(observation_dim)).double() + value = ( + torch.randint( + observation_dim, [batch_size] * batch_data + [num_steps] + ).unsqueeze(-1) + == torch.arange(observation_dim) + ).double() # Log probability. 
lp = hmm.log_prob(value) # Check shapes: - if all([not batch_initial, not batch_transition, not batch_observation, - not batch_data]): + if all( + [not batch_initial, not batch_transition, not batch_observation, not batch_data] + ): assert lp.shape == () else: assert lp.shape == (batch_size,) -@pytest.mark.parametrize('batch_initial', [False, True]) -@pytest.mark.parametrize('batch_transition', [False, True]) -@pytest.mark.parametrize('batch_observation', [False, True]) -@pytest.mark.parametrize('batch_data', [False, True]) -def test_DiscreteHMM_comparison(batch_initial, batch_transition, - batch_observation, batch_data): +@pytest.mark.parametrize("batch_initial", [False, True]) +@pytest.mark.parametrize("batch_transition", [False, True]) +@pytest.mark.parametrize("batch_observation", [False, True]) +@pytest.mark.parametrize("batch_data", [False, True]) +def test_DiscreteHMM_comparison( + batch_initial, batch_transition, batch_observation, batch_data +): # Dimensions. batch_size = 3 state_dim, observation_dim, num_steps = 4, 5, 6 # -- Model setup --. - transition_logits_vldhmm = torch.randn([batch_size]*batch_transition - + [state_dim, state_dim]) - transition_logits_vldhmm = (transition_logits_vldhmm - - transition_logits_vldhmm.logsumexp(-1, True)) + transition_logits_vldhmm = torch.randn( + [batch_size] * batch_transition + [state_dim, state_dim] + ) + transition_logits_vldhmm = ( + transition_logits_vldhmm - transition_logits_vldhmm.logsumexp(-1, True) + ) # Adjust for DiscreteHMM broadcasting convention. transition_logits_dhmm = transition_logits_vldhmm.unsqueeze(-3) # Convert between discrete HMM convention for initial state and variable # length HMM convention. - initial_logits_dhmm = torch.randn([batch_size]*batch_initial + [state_dim]) - initial_logits_dhmm = (initial_logits_dhmm - - initial_logits_dhmm.logsumexp(-1, True)) - initial_logits_vldhmm = (initial_logits_dhmm.unsqueeze(-1) + - transition_logits_vldhmm).logsumexp(-2) - observation_logits = torch.randn([batch_size]*batch_observation - + [state_dim, observation_dim]) - observation_logits = (observation_logits - - observation_logits.logsumexp(-1, True)) + initial_logits_dhmm = torch.randn([batch_size] * batch_initial + [state_dim]) + initial_logits_dhmm = initial_logits_dhmm - initial_logits_dhmm.logsumexp(-1, True) + initial_logits_vldhmm = ( + initial_logits_dhmm.unsqueeze(-1) + transition_logits_vldhmm + ).logsumexp(-2) + observation_logits = torch.randn( + [batch_size] * batch_observation + [state_dim, observation_dim] + ) + observation_logits = observation_logits - observation_logits.logsumexp(-1, True) # Create distribution object for DiscreteHMM observation_dist = Categorical(logits=observation_logits.unsqueeze(-3)) - vldhmm = MissingDataDiscreteHMM(initial_logits_vldhmm, - transition_logits_vldhmm, - observation_logits) - dhmm = DiscreteHMM(initial_logits_dhmm, transition_logits_dhmm, - observation_dist) + vldhmm = MissingDataDiscreteHMM( + initial_logits_vldhmm, transition_logits_vldhmm, observation_logits + ) + dhmm = DiscreteHMM(initial_logits_dhmm, transition_logits_dhmm, observation_dist) # Random observations. - value = torch.randint(observation_dim, - [batch_size]*batch_data + [num_steps]) - value_oh = (value.unsqueeze(-1) - == torch.arange(observation_dim)).double() + value = torch.randint(observation_dim, [batch_size] * batch_data + [num_steps]) + value_oh = (value.unsqueeze(-1) == torch.arange(observation_dim)).double() # -- Check. -- # Log probability. 
lp_vldhmm = vldhmm.log_prob(value_oh) lp_dhmm = dhmm.log_prob(value) # Shapes. - if all([not batch_initial, not batch_transition, not batch_observation, - not batch_data]): + if all( + [not batch_initial, not batch_transition, not batch_observation, not batch_data] + ): assert lp_vldhmm.shape == () else: assert lp_vldhmm.shape == (batch_size,) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py index 5f2ba9634b..691313dff9 100644 --- a/tests/contrib/mue/test_models.py +++ b/tests/contrib/mue/test_models.py @@ -12,19 +12,23 @@ from pyro.optim import MultiStepLR -@pytest.mark.parametrize('jit', [False, True]) +@pytest.mark.parametrize("jit", [False, True]) def test_ProfileHMM_smoke(jit): # Setup dataset. - seqs = ['BABBA', 'BAAB', 'BABBB'] - alph = 'AB' - dataset = BiosequenceDataset(seqs, 'list', alph) + seqs = ["BABBA", "BAAB", "BABBB"] + alph = "AB" + dataset = BiosequenceDataset(seqs, "list", alph) # Infer. - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.1}, - 'milestones': [20, 100, 1000, 2000], - 'gamma': 0.5}) - model = ProfileHMM(int(dataset.max_length*1.1), dataset.alphabet_length) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": 0.1}, + "milestones": [20, 100, 1000, 2000], + "gamma": 0.5, + } + ) + model = ProfileHMM(int(dataset.max_length * 1.1), dataset.alphabet_length) n_epochs = 5 batch_size = 2 losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) @@ -33,41 +37,50 @@ def test_ProfileHMM_smoke(jit): # Evaluate. train_lp, test_lp, train_perplex, test_perplex = model.evaluate( - dataset, dataset, jit) - assert train_lp < 0. - assert test_lp < 0. - assert train_perplex > 0. - assert test_perplex > 0. + dataset, dataset, jit + ) + assert train_lp < 0.0 + assert test_lp < 0.0 + assert train_perplex > 0.0 + assert test_perplex > 0.0 -@pytest.mark.parametrize('indel_factor_dependence', [False, True]) -@pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) -@pytest.mark.parametrize('ARD_prior', [False, True]) -@pytest.mark.parametrize('substitution_matrix', [False, True]) -@pytest.mark.parametrize('jit', [False, True]) -def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, - ARD_prior, substitution_matrix, jit): +@pytest.mark.parametrize("indel_factor_dependence", [False, True]) +@pytest.mark.parametrize("z_prior_distribution", ["Normal", "Laplace"]) +@pytest.mark.parametrize("ARD_prior", [False, True]) +@pytest.mark.parametrize("substitution_matrix", [False, True]) +@pytest.mark.parametrize("jit", [False, True]) +def test_FactorMuE_smoke( + indel_factor_dependence, z_prior_distribution, ARD_prior, substitution_matrix, jit +): # Setup dataset. - seqs = ['BABBA', 'BAAB', 'BABBB'] - alph = 'AB' - dataset = BiosequenceDataset(seqs, 'list', alph) + seqs = ["BABBA", "BAAB", "BABBB"] + alph = "AB" + dataset = BiosequenceDataset(seqs, "list", alph) # Infer. 
z_dim = 2 - scheduler = MultiStepLR({'optimizer': Adam, - 'optim_args': {'lr': 0.1}, - 'milestones': [20, 100, 1000, 2000], - 'gamma': 0.5}) - model = FactorMuE(dataset.max_length, dataset.alphabet_length, z_dim, - indel_factor_dependence=indel_factor_dependence, - z_prior_distribution=z_prior_distribution, - ARD_prior=ARD_prior, - substitution_matrix=substitution_matrix) + scheduler = MultiStepLR( + { + "optimizer": Adam, + "optim_args": {"lr": 0.1}, + "milestones": [20, 100, 1000, 2000], + "gamma": 0.5, + } + ) + model = FactorMuE( + dataset.max_length, + dataset.alphabet_length, + z_dim, + indel_factor_dependence=indel_factor_dependence, + z_prior_distribution=z_prior_distribution, + ARD_prior=ARD_prior, + substitution_matrix=substitution_matrix, + ) n_epochs = 5 anneal_length = 2 batch_size = 2 - losses = model.fit_svi(dataset, n_epochs, anneal_length, batch_size, - scheduler, jit) + losses = model.fit_svi(dataset, n_epochs, anneal_length, batch_size, scheduler, jit) # Reconstruct. recon = model._reconstruct_regressor_seq(dataset, 1, pyro.param) @@ -76,18 +89,19 @@ def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, assert recon.shape == (1, max([len(seq) for seq in seqs]), len(alph)) assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5)) - assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.)) + assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.0)) # Evaluate. train_lp, test_lp, train_perplex, test_perplex = model.evaluate( - dataset, dataset, jit) - assert train_lp < 0. - assert test_lp < 0. - assert train_perplex > 0. - assert test_perplex > 0. + dataset, dataset, jit + ) + assert train_lp < 0.0 + assert test_lp < 0.0 + assert train_perplex > 0.0 + assert test_perplex > 0.0 # Embedding. z_locs, z_scales = model.embed(dataset) assert z_locs.shape == (len(dataset), z_dim) assert z_scales.shape == (len(dataset), z_dim) - assert torch.all(z_scales > 0.) + assert torch.all(z_scales > 0.0) diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py index 1fd4f672a9..c17bdcd50a 100644 --- a/tests/contrib/mue/test_statearrangers.py +++ b/tests/contrib/mue/test_statearrangers.py @@ -10,57 +10,63 @@ def simpleprod(lst): # Product of list of scalar tensors, as numpy would do it. if len(lst) == 0: - return torch.tensor(1.) + return torch.tensor(1.0) else: return torch.prod(torch.cat([elem[None] for elem in lst])) -@pytest.mark.parametrize('M', [2, 20]) -@pytest.mark.parametrize('batch_size', [None, 5]) -@pytest.mark.parametrize('substitute', [False, True]) +@pytest.mark.parametrize("M", [2, 20]) +@pytest.mark.parametrize("batch_size", [None, 5]) +@pytest.mark.parametrize("substitute", [False, True]) def test_profile_alternate_imp(M, batch_size, substitute): # --- Setup random model. --- pf_arranger = Profile(M) - u1 = torch.rand((M+1, 3)) + u1 = torch.rand((M + 1, 3)) u1[M, :] = 0 # Assume u_{M+1, j} = 0 for j in {0, 1, 2} in Eqn. S40. - u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) - r1 = torch.rand((M+1, 3)) + u = torch.cat([(1 - u1)[:, :, None], u1[:, :, None]], dim=2) + r1 = torch.rand((M + 1, 3)) r1[M, :] = 1 # Assume r_{M+1, j} = 1 for j in {0, 1, 2} in Eqn. S40. 
- r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) + r = torch.cat([(1 - r1)[:, :, None], r1[:, :, None]], dim=2) s = torch.rand((M, 4)) - s = s/torch.sum(s, dim=1, keepdim=True) - c = torch.rand((M+1, 4)) - c = c/torch.sum(c, dim=1, keepdim=True) + s = s / torch.sum(s, dim=1, keepdim=True) + c = torch.rand((M + 1, 4)) + c = c / torch.sum(c, dim=1, keepdim=True) if batch_size is not None: s = torch.rand((batch_size, M, 4)) - s = s/torch.sum(s, dim=2, keepdim=True) - u1 = torch.rand((batch_size, M+1, 3)) + s = s / torch.sum(s, dim=2, keepdim=True) + u1 = torch.rand((batch_size, M + 1, 3)) u1[:, M, :] = 0 - u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) + u = torch.cat([(1 - u1)[:, :, :, None], u1[:, :, :, None]], dim=3) # Compute forward pass of state arranger to get HMM parameters. # Don't use dimension M, assumed fixed by statearranger. if substitute: ll = torch.rand((4, 5)) - ll = ll/torch.sum(ll, dim=1, keepdim=True) + ll = ll / torch.sum(ll, dim=1, keepdim=True) a0ln, aln, eln = pf_arranger.forward( - torch.log(s), torch.log(c), - torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :]), - torch.log(ll)) + torch.log(s), + torch.log(c), + torch.log(r[:-1, :]), + torch.log(u[..., :-1, :, :]), + torch.log(ll), + ) else: a0ln, aln, eln = pf_arranger.forward( - torch.log(s), torch.log(c), - torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :])) + torch.log(s), + torch.log(c), + torch.log(r[:-1, :]), + torch.log(u[..., :-1, :, :]), + ) # - Remake HMM parameters to check. - # Here we implement Equation S40 from the MuE paper # (https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf) # more directly, iterating over all the indices of the transition matrix # and initial transition vector. - K = 2*M + 1 + K = 2 * M + 1 if batch_size is None: batch_dim_size = 1 r1 = r1.unsqueeze(0) @@ -82,56 +88,81 @@ def test_profile_alternate_imp(M, batch_size, substitute): m, g = -1, 0 u1[b][-1] = 1e-32 for gp in range(2): - for mp in range(M+gp): + for mp in range(M + gp): kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: - chk_a0[b, kp] = (1 - r1[b, m+1-g, g])*(1 - u1[b, m+1-g, g]) + chk_a0[b, kp] = (1 - r1[b, m + 1 - g, g]) * ( + 1 - u1[b, m + 1 - g, g] + ) elif m + 1 - g < mp and gp == 0: chk_a0[b, kp] = ( - (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * - simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] - for mpp in - range(m+2-g, mp)]) * - (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + (1 - r1[b, m + 1 - g, g]) + * u1[b, m + 1 - g, g] + * simpleprod( + [ + (1 - r1[b, mpp, 2]) * u1[b, mpp, 2] + for mpp in range(m + 2 - g, mp) + ] + ) + * (1 - r1[b, mp, 2]) + * (1 - u1[b, mp, 2]) + ) elif m + 1 - g == mp and gp == 1: - chk_a0[b, kp] = r1[b, m+1-g, g] + chk_a0[b, kp] = r1[b, m + 1 - g, g] elif m + 1 - g < mp and gp == 1: chk_a0[b, kp] = ( - (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * - simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] - for mpp in - range(m+2-g, mp)]) * r1[b, mp, 2]) + (1 - r1[b, m + 1 - g, g]) + * u1[b, m + 1 - g, g] + * simpleprod( + [ + (1 - r1[b, mpp, 2]) * u1[b, mpp, 2] + for mpp in range(m + 2 - g, mp) + ] + ) + * r1[b, mp, 2] + ) for g in range(2): - for m in range(M+g): + for m in range(M + g): k = mg2k(m, g, M) for gp in range(2): - for mp in range(M+gp): + for mp in range(M + gp): kp = mg2k(mp, gp, M) if m + 1 - g == mp and gp == 0: - chk_a[b, k, kp] = (1 - r1[b, m+1-g, g] - )*(1 - u1[b, m+1-g, g]) + chk_a[b, k, kp] = (1 - r1[b, m + 1 - g, g]) * ( + 1 - u1[b, m + 1 - g, g] + ) elif m + 1 - g < mp and gp == 0: chk_a[b, k, kp] = ( - (1 - r1[b, m+1-g, g]) * u1[b, 
m+1-g, g] * - simpleprod([(1 - r1[b, mpp, 2]) * - u1[b, mpp, 2] - for mpp in range(m+2-g, mp)]) * - (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + (1 - r1[b, m + 1 - g, g]) + * u1[b, m + 1 - g, g] + * simpleprod( + [ + (1 - r1[b, mpp, 2]) * u1[b, mpp, 2] + for mpp in range(m + 2 - g, mp) + ] + ) + * (1 - r1[b, mp, 2]) + * (1 - u1[b, mp, 2]) + ) elif m + 1 - g == mp and gp == 1: - chk_a[b, k, kp] = r1[b, m+1-g, g] + chk_a[b, k, kp] = r1[b, m + 1 - g, g] elif m + 1 - g < mp and gp == 1: chk_a[b, k, kp] = ( - (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * - simpleprod([(1 - r1[b, mpp, 2]) * - u1[b, mpp, 2] - for mpp in - range(m+2-g, mp)] - ) * r1[b, mp, 2]) + (1 - r1[b, m + 1 - g, g]) + * u1[b, m + 1 - g, g] + * simpleprod( + [ + (1 - r1[b, mpp, 2]) * u1[b, mpp, 2] + for mpp in range(m + 2 - g, mp) + ] + ) + * r1[b, mp, 2] + ) elif m == M and mp == M and g == 0 and gp == 0: - chk_a[b, k, kp] = 1. + chk_a[b, k, kp] = 1.0 for g in range(2): - for m in range(M+g): + for m in range(M + g): k = mg2k(m, g, M) if g == 0: chk_e[b, k, :] = s[b, m, :] @@ -146,42 +177,46 @@ def test_profile_alternate_imp(M, batch_size, substitute): chk_a0 = chk_a0.squeeze() chk_e = chk_e.squeeze() - assert torch.allclose(torch.sum(torch.exp(a0ln)), torch.tensor(1.), - atol=1e-3, rtol=1e-3) - assert torch.allclose(torch.sum(torch.exp(aln), axis=1), - torch.ones(2*M+1), atol=1e-3, - rtol=1e-3) + assert torch.allclose( + torch.sum(torch.exp(a0ln)), torch.tensor(1.0), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + torch.sum(torch.exp(aln), axis=1), + torch.ones(2 * M + 1), + atol=1e-3, + rtol=1e-3, + ) assert torch.allclose(chk_a0, torch.exp(a0ln)) assert torch.allclose(chk_a, torch.exp(aln)) assert torch.allclose(chk_e, torch.exp(eln)) -@pytest.mark.parametrize('batch_ancestor_seq', [False, True]) -@pytest.mark.parametrize('batch_insert_seq', [False, True]) -@pytest.mark.parametrize('batch_insert', [False, True]) -@pytest.mark.parametrize('batch_delete', [False, True]) -@pytest.mark.parametrize('batch_substitute', [False, True]) -def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, - batch_delete, batch_substitute): +@pytest.mark.parametrize("batch_ancestor_seq", [False, True]) +@pytest.mark.parametrize("batch_insert_seq", [False, True]) +@pytest.mark.parametrize("batch_insert", [False, True]) +@pytest.mark.parametrize("batch_delete", [False, True]) +@pytest.mark.parametrize("batch_substitute", [False, True]) +def test_profile_shapes( + batch_ancestor_seq, batch_insert_seq, batch_insert, batch_delete, batch_substitute +): M, D, B = 5, 2, 3 - K = 2*M + 1 + K = 2 * M + 1 batch_size = 6 pf_arranger = Profile(M) - sln = torch.randn([batch_size]*batch_ancestor_seq + [M, D]) + sln = torch.randn([batch_size] * batch_ancestor_seq + [M, D]) sln = sln - sln.logsumexp(-1, True) - cln = torch.randn([batch_size]*batch_insert_seq + [M+1, D]) + cln = torch.randn([batch_size] * batch_insert_seq + [M + 1, D]) cln = cln - cln.logsumexp(-1, True) - rln = torch.randn([batch_size]*batch_insert + [M, 3, 2]) + rln = torch.randn([batch_size] * batch_insert + [M, 3, 2]) rln = rln - rln.logsumexp(-1, True) - uln = torch.randn([batch_size]*batch_delete + [M, 3, 2]) + uln = torch.randn([batch_size] * batch_delete + [M, 3, 2]) uln = uln - uln.logsumexp(-1, True) - lln = torch.randn([batch_size]*batch_substitute + [D, B]) + lln = torch.randn([batch_size] * batch_substitute + [D, B]) lln = lln - lln.logsumexp(-1, True) a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) - if all([not batch_ancestor_seq, not 
batch_insert_seq, - not batch_substitute]): + if all([not batch_ancestor_seq, not batch_insert_seq, not batch_substitute]): assert eln.shape == (K, B) assert torch.allclose(eln.logsumexp(-1), torch.zeros(K)) else: @@ -200,7 +235,7 @@ def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, assert torch.allclose(aln.logsumexp(-1), torch.zeros((batch_size, K))) -@pytest.mark.parametrize('M', [2, 20]) # , 20 +@pytest.mark.parametrize("M", [2, 20]) # , 20 def test_profile_trivial_cases(M): # Trivial case: indel probabability of zero. Expected value of # HMM should match ancestral sequence times substitution matrix. @@ -211,12 +246,16 @@ def test_profile_trivial_cases(M): pf_arranger = Profile(M) sln = torch.randn([batch_size, M, D]) sln = sln - sln.logsumexp(-1, True) - cln = torch.randn([batch_size, M+1, D]) + cln = torch.randn([batch_size, M + 1, D]) cln = cln - cln.logsumexp(-1, True) - rln = torch.cat([torch.zeros([M, 3, 1]), - -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) - uln = torch.cat([torch.zeros([M, 3, 1]), - -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) + rln = torch.cat( + [torch.zeros([M, 3, 1]), -1 / pf_arranger.epsilon * torch.ones([M, 3, 1])], + axis=-1, + ) + uln = torch.cat( + [torch.zeros([M, 3, 1]), -1 / pf_arranger.epsilon * torch.ones([M, 3, 1])], + axis=-1, + ) lln = torch.randn([D, B]) lln = lln - lln.logsumexp(-1, True) diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py index aa8df92188..14eb605c0b 100644 --- a/tests/contrib/oed/test_ewma.py +++ b/tests/contrib/oed/test_ewma.py @@ -29,22 +29,22 @@ def test_ewma(alpha, NS=10000, D=1): def test_ewma_log(): ewma_log = EwmaLog(alpha=0.5) - input1 = torch.tensor(2.) - ewma_log(input1, torch.tensor(0.)) + input1 = torch.tensor(2.0) + ewma_log(input1, torch.tensor(0.0)) assert_equal(ewma_log.ewma, input1) - input2 = torch.tensor(3.) - ewma_log(input2, torch.tensor(0.)) - assert_equal(ewma_log.ewma, torch.tensor(8./3)) + input2 = torch.tensor(3.0) + ewma_log(input2, torch.tensor(0.0)) + assert_equal(ewma_log.ewma, torch.tensor(8.0 / 3)) def test_ewma_log_with_s(): ewma_log = EwmaLog(alpha=0.5) - input1 = torch.tensor(-1.) - s1 = torch.tensor(210.) + input1 = torch.tensor(-1.0) + s1 = torch.tensor(210.0) ewma_log(input1, s1) assert_equal(ewma_log.ewma, input1) - input2 = torch.tensor(-1.) 
+ input2 = torch.tensor(-1.0) s2 = torch.tensor(210.5) ewma_log(input2, s2) - true_ewma = (1./3)*(torch.exp(s1 - s2)*input1 + 2*input2) + true_ewma = (1.0 / 3) * (torch.exp(s1 - s2) * input1 + 2 * input2) assert_equal(ewma_log.ewma, true_ewma) diff --git a/tests/contrib/oed/test_finite_spaces_eig.py b/tests/contrib/oed/test_finite_spaces_eig.py index bc012b7fba..dcac09cdd3 100644 --- a/tests/contrib/oed/test_finite_spaces_eig.py +++ b/tests/contrib/oed/test_finite_spaces_eig.py @@ -32,15 +32,16 @@ def model(design): with ExitStack() as stack: for plate in iter_plates_to_shape(batch_shape): stack.enter_context(plate) - theta = pyro.sample("theta", dist.Bernoulli(.4).expand(batch_shape)) - y = pyro.sample("y", dist.Bernoulli((design + theta) / 2.)) + theta = pyro.sample("theta", dist.Bernoulli(0.4).expand(batch_shape)) + y = pyro.sample("y", dist.Bernoulli((design + theta) / 2.0)) return y + return model @pytest.fixture def one_point_design(): - return torch.tensor(.5) + return torch.tensor(0.5) @pytest.fixture @@ -51,30 +52,32 @@ def true_eig(): def posterior_guide(y_dict, design, observation_labels, target_labels): y = torch.cat(list(y_dict.values()), dim=-1) - a, b = pyro.param("a", torch.tensor(0.)), pyro.param("b", torch.tensor(0.)) - pyro.sample("theta", dist.Bernoulli(logits=a + b*y)) + a, b = pyro.param("a", torch.tensor(0.0)), pyro.param("b", torch.tensor(0.0)) + pyro.sample("theta", dist.Bernoulli(logits=a + b * y)) def marginal_guide(design, observation_labels, target_labels): - logit_p = pyro.param("logit_p", torch.tensor(0.)) + logit_p = pyro.param("logit_p", torch.tensor(0.0)) pyro.sample("y", dist.Bernoulli(logits=logit_p)) def likelihood_guide(theta_dict, design, observation_labels, target_labels): theta = torch.cat(list(theta_dict.values()), dim=-1) - a, b = pyro.param("a", torch.tensor(0.)), pyro.param("b", torch.tensor(0.)) - pyro.sample("y", dist.Bernoulli(logits=a + b*theta)) + a, b = pyro.param("a", torch.tensor(0.0)), pyro.param("b", torch.tensor(0.0)) + pyro.sample("y", dist.Bernoulli(logits=a + b * theta)) def make_lfire_classifier(n_theta_samples): def lfire_classifier(design, trace, observation_labels, target_labels): y_dict = {l: trace.nodes[l]["value"] for l in observation_labels} y = torch.cat(list(y_dict.values()), dim=-1) - a, b = pyro.param("a", torch.zeros(n_theta_samples)), pyro.param("b", torch.zeros(n_theta_samples)) + a, b = pyro.param("a", torch.zeros(n_theta_samples)), pyro.param( + "b", torch.zeros(n_theta_samples) + ) - return a + b*y + return a + b * y return lfire_classifier @@ -85,11 +88,11 @@ def dv_critic(design, trace, observation_labels, target_labels): theta_dict = {l: trace.nodes[l]["value"] for l in target_labels} theta = torch.cat(list(theta_dict.values()), dim=-1) - w_y = pyro.param("w_y", torch.tensor(0.)) - w_theta = pyro.param("w_theta", torch.tensor(0.)) - w_ytheta = pyro.param("w_ytheta", torch.tensor(0.)) + w_y = pyro.param("w_y", torch.tensor(0.0)) + w_theta = pyro.param("w_theta", torch.tensor(0.0)) + w_ytheta = pyro.param("w_ytheta", torch.tensor(0.0)) - return y*w_y + theta*w_theta + y*theta*w_ytheta + return y * w_y + theta * w_theta + y * theta * w_ytheta ######################################################################################################################## @@ -101,13 +104,28 @@ def test_posterior_finite_space_model(finite_space_model, one_point_design, true pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - posterior_eig(finite_space_model, one_point_design, "y", "theta", 
num_samples=10, - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.1})) + posterior_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = posterior_eig(finite_space_model, one_point_design, "y", "theta", num_samples=10, - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=1000) + estimated_eig = posterior_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=1000, + ) assert_equal(estimated_eig, true_eig, prec=1e-2) @@ -115,68 +133,146 @@ def test_marginal_finite_space_model(finite_space_model, one_point_design, true_ pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - marginal_eig(finite_space_model, one_point_design, "y", "theta", num_samples=10, - num_steps=250, guide=marginal_guide, - optim=optim.Adam({"lr": 0.1})) + marginal_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + guide=marginal_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = marginal_eig(finite_space_model, one_point_design, "y", "theta", num_samples=10, - num_steps=250, guide=marginal_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=1000) + estimated_eig = marginal_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + guide=marginal_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=1000, + ) assert_equal(estimated_eig, true_eig, prec=1e-2) -def test_marginal_likelihood_finite_space_model(finite_space_model, one_point_design, true_eig): +def test_marginal_likelihood_finite_space_model( + finite_space_model, one_point_design, true_eig +): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - marginal_likelihood_eig(finite_space_model, one_point_design, "y", "theta", num_samples=10, - num_steps=250, marginal_guide=marginal_guide, cond_guide=likelihood_guide, - optim=optim.Adam({"lr": 0.1})) + marginal_likelihood_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + marginal_guide=marginal_guide, + cond_guide=likelihood_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = marginal_likelihood_eig(finite_space_model, one_point_design, "y", "theta", num_samples=10, - num_steps=250, marginal_guide=marginal_guide, cond_guide=likelihood_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=1000) + estimated_eig = marginal_likelihood_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=10, + num_steps=250, + marginal_guide=marginal_guide, + cond_guide=likelihood_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=1000, + ) assert_equal(estimated_eig, true_eig, prec=1e-2) -@pytest.mark.xfail(reason="Bernoullis are not reparametrizable and current VNMC implementation " - "assumes reparametrization") +@pytest.mark.xfail( + reason="Bernoullis are not reparametrizable and current VNMC implementation " + "assumes reparametrization" +) def test_vnmc_finite_space_model(finite_space_model, one_point_design, true_eig): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - vnmc_eig(finite_space_model, one_point_design, "y", "theta", 
num_samples=[9, 3], - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.1})) + vnmc_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=[9, 3], + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = vnmc_eig(finite_space_model, one_point_design, "y", "theta", num_samples=[9, 3], - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=[1000, 100]) + estimated_eig = vnmc_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=[9, 3], + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=[1000, 100], + ) assert_equal(estimated_eig, true_eig, prec=1e-2) def test_nmc_eig_finite_space_model(finite_space_model, one_point_design, true_eig): pyro.set_rng_seed(42) pyro.clear_param_store() - estimated_eig = nmc_eig(finite_space_model, one_point_design, "y", "theta", M=40, N=40 * 40) + estimated_eig = nmc_eig( + finite_space_model, one_point_design, "y", "theta", M=40, N=40 * 40 + ) assert_equal(estimated_eig, true_eig, prec=1e-2) def test_lfire_finite_space_model(finite_space_model, one_point_design, true_eig): pyro.set_rng_seed(42) pyro.clear_param_store() - estimated_eig = lfire_eig(finite_space_model, one_point_design, "y", "theta", num_y_samples=5, - num_theta_samples=50, num_steps=1000, classifier=make_lfire_classifier(50), - optim=optim.Adam({"lr": 0.0025}), final_num_samples=500) + estimated_eig = lfire_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_y_samples=5, + num_theta_samples=50, + num_steps=1000, + classifier=make_lfire_classifier(50), + optim=optim.Adam({"lr": 0.0025}), + final_num_samples=500, + ) assert_equal(estimated_eig, true_eig, prec=1e-2) def test_dv_finite_space_model(finite_space_model, one_point_design, true_eig): pyro.set_rng_seed(42) pyro.clear_param_store() - donsker_varadhan_eig(finite_space_model, one_point_design, "y", "theta", num_samples=100, - num_steps=250, T=dv_critic, optim=optim.Adam({"lr": 0.1})) - estimated_eig = donsker_varadhan_eig(finite_space_model, one_point_design, "y", "theta", num_samples=100, - num_steps=250, T=dv_critic, optim=optim.Adam({"lr": 0.01}), - final_num_samples=2000) + donsker_varadhan_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=100, + num_steps=250, + T=dv_critic, + optim=optim.Adam({"lr": 0.1}), + ) + estimated_eig = donsker_varadhan_eig( + finite_space_model, + one_point_design, + "y", + "theta", + num_samples=100, + num_steps=250, + T=dv_critic, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=2000, + ) assert_equal(estimated_eig, true_eig, prec=1e-2) diff --git a/tests/contrib/oed/test_glmm.py b/tests/contrib/oed/test_glmm.py index d07a851ed7..f8078a44bc 100644 --- a/tests/contrib/oed/test_glmm.py +++ b/tests/contrib/oed/test_glmm.py @@ -20,45 +20,56 @@ def lm_2p_10_10_1(design): - w = pyro.sample("w", dist.Normal(torch.tensor(0.), - torch.tensor([10., 10.])).to_event(1)) + w = pyro.sample( + "w", dist.Normal(torch.tensor(0.0), torch.tensor([10.0, 10.0])).to_event(1) + ) mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1) - y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.)).to_event(1)) + y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.0)).to_event(1)) return y def lm_2p_10_10_1_w12(design): - w1 = pyro.sample("w1", dist.Normal(torch.tensor([0.]), - torch.tensor(10.)).to_event(1)) - w2 = pyro.sample("w2", dist.Normal(torch.tensor([0.]), - 
torch.tensor(10.)).to_event(1)) + w1 = pyro.sample( + "w1", dist.Normal(torch.tensor([0.0]), torch.tensor(10.0)).to_event(1) + ) + w2 = pyro.sample( + "w2", dist.Normal(torch.tensor([0.0]), torch.tensor(10.0)).to_event(1) + ) w = torch.cat([w1, w2], dim=-1) mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1) - y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.)).to_event(1)) + y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.0)).to_event(1)) return y def nz_lm_2p_10_10_1(design): - w = pyro.sample("w", dist.Normal(torch.tensor([1., -1.]), - torch.tensor([10., 10.])).to_event(1)) + w = pyro.sample( + "w", + dist.Normal(torch.tensor([1.0, -1.0]), torch.tensor([10.0, 10.0])).to_event(1), + ) mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1) - y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.)).to_event(1)) + y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.0)).to_event(1)) return y def normal_inv_gamma_2_2_10_10(design): - tau = pyro.sample("tau", dist.Gamma(torch.tensor(2.), torch.tensor(2.))) - obs_sd = 1./torch.sqrt(tau) - w = pyro.sample("w", dist.Normal(torch.tensor([1., -1.]), - obs_sd*torch.tensor([10., 10.])).to_event(1)) + tau = pyro.sample("tau", dist.Gamma(torch.tensor(2.0), torch.tensor(2.0))) + obs_sd = 1.0 / torch.sqrt(tau) + w = pyro.sample( + "w", + dist.Normal( + torch.tensor([1.0, -1.0]), obs_sd * torch.tensor([10.0, 10.0]) + ).to_event(1), + ) mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1) - y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.)).to_event(1)) + y = pyro.sample("y", dist.Normal(mean, torch.tensor(1.0)).to_event(1)) return y def lr_10_10(design): - w = pyro.sample("w", dist.Normal(torch.tensor([1., -1.]), - torch.tensor([10., 10.])).to_event(1)) + w = pyro.sample( + "w", + dist.Normal(torch.tensor([1.0, -1.0]), torch.tensor([10.0, 10.0])).to_event(1), + ) mean = torch.matmul(design, w.unsqueeze(-1)).squeeze(-1) y = pyro.sample("y", dist.Bernoulli(logits=mean).to_event(1)) return y @@ -66,65 +77,91 @@ def lr_10_10(design): def sigmoid_example(design): n = design.shape[-2] - random_effect_k = pyro.sample("k", dist.Gamma(2.*torch.ones(n), torch.tensor(2.))) - random_effect_offset = pyro.sample("w2", dist.Normal(torch.tensor(0.), torch.ones(n))) - w1 = pyro.sample("w1", dist.Normal(torch.tensor([1., -1.]), - torch.tensor([10., 10.])).to_event(1)) + random_effect_k = pyro.sample( + "k", dist.Gamma(2.0 * torch.ones(n), torch.tensor(2.0)) + ) + random_effect_offset = pyro.sample( + "w2", dist.Normal(torch.tensor(0.0), torch.ones(n)) + ) + w1 = pyro.sample( + "w1", + dist.Normal(torch.tensor([1.0, -1.0]), torch.tensor([10.0, 10.0])).to_event(1), + ) mean = torch.matmul(design[..., :-2], w1.unsqueeze(-1)).squeeze(-1) offset_mean = mean + random_effect_offset - base_dist = dist.Normal(offset_mean, torch.tensor(1.)).to_event(1) + base_dist = dist.Normal(offset_mean, torch.tensor(1.0)).to_event(1) transforms = [ - AffineTransform(loc=torch.tensor(0.), scale=random_effect_k), - SigmoidTransform() + AffineTransform(loc=torch.tensor(0.0), scale=random_effect_k), + SigmoidTransform(), ] response_dist = dist.TransformedDistribution(base_dist, transforms) y = pyro.sample("y", response_dist) return y -@pytest.mark.parametrize("model1,model2,design", [ - ( - zero_mean_unit_obs_sd_lm(torch.tensor([10., 10.]))[0], - lm_2p_10_10_1, - torch.tensor([[1., -1.]]) - ), - ( - lm_2p_10_10_1, - zero_mean_unit_obs_sd_lm(torch.tensor([10., 10.]))[0], - torch.tensor([[100., -100.]]) - ), - ( - group_linear_model(torch.tensor(0.), torch.tensor([10.]), 
torch.tensor(0.), - torch.tensor([10.]), torch.tensor(1.)), - lm_2p_10_10_1_w12, - torch.tensor([[-1.5, 0.5], [1.5, 0.]]) - ), - ( - known_covariance_linear_model(torch.tensor([1., -1.]), torch.tensor([10., 10.]), torch.tensor(1.)), - nz_lm_2p_10_10_1, - torch.tensor([[-1., 0.5], [2.5, -2.]]) - ), - ( - normal_inverse_gamma_linear_model(torch.tensor([1., -1.]), torch.tensor(.1), - torch.tensor(2.), torch.tensor(2.)), - normal_inv_gamma_2_2_10_10, - torch.tensor([[1., -0.5], [1.5, 2.]]) - ), - ( - logistic_regression_model(torch.tensor([1., -1.]), torch.tensor(10.)), - lr_10_10, - torch.tensor([[6., -1.5], [.5, 0.]]) - ), - ( - sigmoid_model(torch.tensor([1., -1.]), torch.tensor([10., 10.]), - torch.tensor(0.), torch.tensor([1., 1.]), - torch.tensor(1.), - torch.tensor(2.), torch.tensor(2.), torch.eye(2)), - sigmoid_example, - torch.cat([torch.tensor([[1., 1.], [.5, -2.5]]), torch.eye(2)], dim=-1) - ) -]) +@pytest.mark.parametrize( + "model1,model2,design", + [ + ( + zero_mean_unit_obs_sd_lm(torch.tensor([10.0, 10.0]))[0], + lm_2p_10_10_1, + torch.tensor([[1.0, -1.0]]), + ), + ( + lm_2p_10_10_1, + zero_mean_unit_obs_sd_lm(torch.tensor([10.0, 10.0]))[0], + torch.tensor([[100.0, -100.0]]), + ), + ( + group_linear_model( + torch.tensor(0.0), + torch.tensor([10.0]), + torch.tensor(0.0), + torch.tensor([10.0]), + torch.tensor(1.0), + ), + lm_2p_10_10_1_w12, + torch.tensor([[-1.5, 0.5], [1.5, 0.0]]), + ), + ( + known_covariance_linear_model( + torch.tensor([1.0, -1.0]), torch.tensor([10.0, 10.0]), torch.tensor(1.0) + ), + nz_lm_2p_10_10_1, + torch.tensor([[-1.0, 0.5], [2.5, -2.0]]), + ), + ( + normal_inverse_gamma_linear_model( + torch.tensor([1.0, -1.0]), + torch.tensor(0.1), + torch.tensor(2.0), + torch.tensor(2.0), + ), + normal_inv_gamma_2_2_10_10, + torch.tensor([[1.0, -0.5], [1.5, 2.0]]), + ), + ( + logistic_regression_model(torch.tensor([1.0, -1.0]), torch.tensor(10.0)), + lr_10_10, + torch.tensor([[6.0, -1.5], [0.5, 0.0]]), + ), + ( + sigmoid_model( + torch.tensor([1.0, -1.0]), + torch.tensor([10.0, 10.0]), + torch.tensor(0.0), + torch.tensor([1.0, 1.0]), + torch.tensor(1.0), + torch.tensor(2.0), + torch.tensor(2.0), + torch.eye(2), + ), + sigmoid_example, + torch.cat([torch.tensor([[1.0, 1.0], [0.5, -2.5]]), torch.eye(2)], dim=-1), + ), + ], +) def test_log_prob_matches(model1, model2, design): trace = poutine.trace(model1).get_trace(design) trace.compute_log_prob() diff --git a/tests/contrib/oed/test_linear_models_eig.py b/tests/contrib/oed/test_linear_models_eig.py index 1680c4ec86..c657b4a777 100644 --- a/tests/contrib/oed/test_linear_models_eig.py +++ b/tests/contrib/oed/test_linear_models_eig.py @@ -27,15 +27,17 @@ @pytest.fixture def linear_model(): - return known_covariance_linear_model(coef_means=torch.tensor(0.), - coef_sds=torch.tensor([1., 1.5]), - observation_sd=torch.tensor(1.)) + return known_covariance_linear_model( + coef_means=torch.tensor(0.0), + coef_sds=torch.tensor([1.0, 1.5]), + observation_sd=torch.tensor(1.0), + ) @pytest.fixture def one_point_design(): X = torch.zeros(3, 2) - X[0, 0] = X[1, 1] = X[2, 1] = 1. 
+ X[0, 0] = X[1, 1] = X[2, 1] = 1.0 return X @@ -43,8 +45,11 @@ def posterior_guide(y_dict, design, observation_labels, target_labels): y = torch.cat(list(y_dict.values()), dim=-1) A = pyro.param("A", torch.zeros(2, 3)) - scale_tril = pyro.param("scale_tril", torch.tensor([[1., 0.], [0., 1.5]]), - constraint=torch.distributions.constraints.lower_cholesky) + scale_tril = pyro.param( + "scale_tril", + torch.tensor([[1.0, 0.0], [0.0, 1.5]]), + constraint=torch.distributions.constraints.lower_cholesky, + ) mu = rmv(A, y) pyro.sample("w", dist.MultivariateNormal(mu, scale_tril=scale_tril)) @@ -52,8 +57,11 @@ def posterior_guide(y_dict, design, observation_labels, target_labels): def marginal_guide(design, observation_labels, target_labels): mu = pyro.param("mu", torch.zeros(3)) - scale_tril = pyro.param("scale_tril", torch.eye(3), - constraint=torch.distributions.constraints.lower_cholesky) + scale_tril = pyro.param( + "scale_tril", + torch.eye(3), + constraint=torch.distributions.constraints.lower_cholesky, + ) pyro.sample("y", dist.MultivariateNormal(mu, scale_tril)) @@ -64,8 +72,11 @@ def likelihood_guide(theta_dict, design, observation_labels, target_labels): # Need to avoid name collision here mu = pyro.param("mu_l", torch.zeros(3)) - scale_tril = pyro.param("scale_tril_l", torch.eye(3), - constraint=torch.distributions.constraints.lower_cholesky) + scale_tril = pyro.param( + "scale_tril_l", + torch.eye(3), + constraint=torch.distributions.constraints.lower_cholesky, + ) pyro.sample("y", dist.MultivariateNormal(centre + mu, scale_tril=scale_tril)) @@ -79,12 +90,18 @@ def lfire_classifier(design, trace, observation_labels, target_labels): y_dict = {l: trace.nodes[l]["value"] for l in observation_labels} y = torch.cat(list(y_dict.values()), dim=-1) - quadratic_coef = pyro.param("quadratic_coef", torch.zeros(n_theta_samples, 3, 3)) + quadratic_coef = pyro.param( + "quadratic_coef", torch.zeros(n_theta_samples, 3, 3) + ) linear_coef = pyro.param("linear_coef", torch.zeros(n_theta_samples, 3)) bias = pyro.param("bias", torch.zeros(n_theta_samples)) y_quadratic = y.unsqueeze(-1) * y.unsqueeze(-2) - return (quadratic_coef * y_quadratic).sum(-1).sum(-1) + (linear_coef * y).sum(-1) + bias + return ( + (quadratic_coef * y_quadratic).sum(-1).sum(-1) + + (linear_coef * y).sum(-1) + + bias + ) return lfire_classifier @@ -107,13 +124,28 @@ def test_posterior_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - posterior_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.1})) + posterior_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = posterior_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=500) + estimated_eig = posterior_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=500, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -122,13 +154,28 @@ def test_marginal_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - 
marginal_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, guide=marginal_guide, - optim=optim.Adam({"lr": 0.1})) + marginal_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + guide=marginal_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = marginal_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, guide=marginal_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=500) + estimated_eig = marginal_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + guide=marginal_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=500, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -137,13 +184,30 @@ def test_marginal_likelihood_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - marginal_likelihood_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, marginal_guide=marginal_guide, cond_guide=likelihood_guide, - optim=optim.Adam({"lr": 0.1})) + marginal_likelihood_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + marginal_guide=marginal_guide, + cond_guide=likelihood_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = marginal_likelihood_eig(linear_model, one_point_design, "y", "w", num_samples=10, - num_steps=250, marginal_guide=marginal_guide, cond_guide=likelihood_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=500) + estimated_eig = marginal_likelihood_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=10, + num_steps=250, + marginal_guide=marginal_guide, + cond_guide=likelihood_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=500, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -152,13 +216,28 @@ def test_vnmc_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() # Pre-train (large learning rate) - vnmc_eig(linear_model, one_point_design, "y", "w", num_samples=[9, 3], - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.1})) + vnmc_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=[9, 3], + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.1}), + ) # Finesse (small learning rate) - estimated_eig = vnmc_eig(linear_model, one_point_design, "y", "w", num_samples=[9, 3], - num_steps=250, guide=posterior_guide, - optim=optim.Adam({"lr": 0.01}), final_num_samples=[500, 100]) + estimated_eig = vnmc_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=[9, 3], + num_steps=250, + guide=posterior_guide, + optim=optim.Adam({"lr": 0.01}), + final_num_samples=[500, 100], + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -175,10 +254,17 @@ def test_laplace_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() # You can use 1 final sample here because linear models have a posterior entropy that is independent of `y` - estimated_eig = laplace_eig(linear_model, one_point_design, "y", "w", - guide=laplace_guide, num_steps=250, final_num_samples=1, - optim=optim.Adam({"lr": 0.05}), 
- loss=Trace_ELBO().differentiable_loss) + estimated_eig = laplace_eig( + linear_model, + one_point_design, + "y", + "w", + guide=laplace_guide, + num_steps=250, + final_num_samples=1, + optim=optim.Adam({"lr": 0.05}), + loss=Trace_ELBO().differentiable_loss, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -186,9 +272,18 @@ def test_laplace_linear_model(linear_model, one_point_design): def test_lfire_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() - estimated_eig = lfire_eig(linear_model, one_point_design, "y", "w", num_y_samples=2, num_theta_samples=50, - num_steps=1200, classifier=make_lfire_classifier(50), optim=optim.Adam({"lr": 0.0025}), - final_num_samples=100) + estimated_eig = lfire_eig( + linear_model, + one_point_design, + "y", + "w", + num_y_samples=2, + num_theta_samples=50, + num_steps=1200, + classifier=make_lfire_classifier(50), + optim=optim.Adam({"lr": 0.0025}), + final_num_samples=100, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) @@ -196,10 +291,26 @@ def test_lfire_linear_model(linear_model, one_point_design): def test_dv_linear_model(linear_model, one_point_design): pyro.set_rng_seed(42) pyro.clear_param_store() - donsker_varadhan_eig(linear_model, one_point_design, "y", "w", num_samples=100, num_steps=500, T=dv_critic, - optim=optim.Adam({"lr": 0.1})) - estimated_eig = donsker_varadhan_eig(linear_model, one_point_design, "y", "w", num_samples=100, - num_steps=650, T=dv_critic, optim=optim.Adam({"lr": 0.001}), - final_num_samples=2000) + donsker_varadhan_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=100, + num_steps=500, + T=dv_critic, + optim=optim.Adam({"lr": 0.1}), + ) + estimated_eig = donsker_varadhan_eig( + linear_model, + one_point_design, + "y", + "w", + num_samples=100, + num_steps=650, + T=dv_critic, + optim=optim.Adam({"lr": 0.001}), + final_num_samples=2000, + ) expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w") assert_equal(estimated_eig, expected_eig, prec=5e-2) diff --git a/tests/contrib/oed/test_xexpx.py b/tests/contrib/oed/test_xexpx.py index e786ecd646..4d22515286 100644 --- a/tests/contrib/oed/test_xexpx.py +++ b/tests/contrib/oed/test_xexpx.py @@ -7,10 +7,13 @@ from pyro.contrib.oed.eig import xexpx -@pytest.mark.parametrize("argument,output", [ - (torch.tensor([float('-inf')]), torch.tensor([0.])), - (torch.tensor([0.]), torch.tensor([0.])), - (torch.tensor([1.]), torch.exp(torch.tensor([1.]))) -]) +@pytest.mark.parametrize( + "argument,output", + [ + (torch.tensor([float("-inf")]), torch.tensor([0.0])), + (torch.tensor([0.0]), torch.tensor([0.0])), + (torch.tensor([1.0]), torch.exp(torch.tensor([1.0]))), + ], +) def test_xexpx(argument, output): assert xexpx(argument) == output diff --git a/tests/contrib/randomvariable/test_random_variable.py b/tests/contrib/randomvariable/test_random_variable.py index 067393dc8d..c6363ab8af 100644 --- a/tests/contrib/randomvariable/test_random_variable.py +++ b/tests/contrib/randomvariable/test_random_variable.py @@ -38,7 +38,7 @@ def test_multiply_divide(): def test_abs(): X = Uniform(0, 1).rv # (0, 1) - X = 2*(X - 0.5) # (-1, 1) + X = 2 * (X - 0.5) # (-1, 1) X = abs(X) # (0, 1) x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 1)).all().item() @@ -53,7 +53,7 @@ def test_neg(): def test_pow(): X = Uniform(0, 1).rv # (0, 1) 
- X = X**2 # (0, 1) + X = X ** 2 # (0, 1) x = X.dist.sample([N_SAMPLES]) assert ((0 <= x) & (x <= 1)).all().item() @@ -63,7 +63,7 @@ def test_tensor_ops(): X = Uniform(0, 1).expand([5, 5]).rv a = torch.tensor([[1, 2, 3, 4, 5]]) b = a.T - X = abs(pi*(-X + a - 3*b)) + X = abs(pi * (-X + a - 3 * b)) x = X.dist.sample() assert x.shape == (5, 5) assert (x >= 0).all().item() @@ -71,8 +71,8 @@ def test_tensor_ops(): def test_chaining(): X = ( - Uniform(0, 1).rv # (0, 1) - .add(1) # (1, 2) + Uniform(0, 1) + .rv.add(1) # (0, 1) # (1, 2) .pow(2) # (1, 4) .mul(2) # (2, 8) .sub(5) # (-3, 3) @@ -80,4 +80,4 @@ def test_chaining(): .exp() # (1/e, e) ) x = X.dist.sample([N_SAMPLES]) - assert ((1/math.e <= x) & (x <= math.e)).all().item() + assert ((1 / math.e <= x) & (x <= math.e)).all().item() diff --git a/tests/contrib/test_minipyro.py b/tests/contrib/test_minipyro.py index b9c6324011..c6e509cd57 100644 --- a/tests/contrib/test_minipyro.py +++ b/tests/contrib/test_minipyro.py @@ -36,8 +36,10 @@ def assert_error(model, guide, elbo, match=None): Assert that inference fails with an error. """ inference = build_svi(model, guide, elbo) - with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError), - match=match): + with pytest.raises( + (NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError), + match=match, + ): inference.step() @@ -49,14 +51,16 @@ def assert_warning(model, guide, elbo): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") inference.step() - assert len(w), 'No warnings were raised' + assert len(w), "No warnings were raised" for warning in w: print(warning) def constrained_model(data): locs = pyro.param("locs", torch.randn(3), constraint=constraints.real) - scales = pyro.param("scales", ops.exp(torch.randn(3)), constraint=constraints.positive) + scales = pyro.param( + "scales", ops.exp(torch.randn(3)), constraint=constraints.positive + ) p = torch.tensor([0.5, 0.3, 0.2]) x = pyro.sample("x", dist.Categorical(p)) pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) @@ -69,7 +73,6 @@ def guide_constrained_model(data): @pytest.mark.parametrize("backend", ["pyro", "minipyro"]) def test_generate_data(backend): - def model(data=None): loc = pyro.param("loc", torch.tensor(2.0)) scale = pyro.param("scale", torch.tensor(1.0)) @@ -103,15 +106,14 @@ def model(data=None): @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) @pytest.mark.parametrize("backend", ["pyro", "minipyro"]) def test_nonempty_model_empty_guide_ok(backend, jit): - def model(data): loc = pyro.param("loc", torch.tensor(0.0)) - pyro.sample("x", dist.Normal(loc, 1.), obs=data) + pyro.sample("x", dist.Normal(loc, 1.0), obs=data) def guide(data): pass - data = torch.tensor(2.) 
+ data = torch.tensor(2.0) with pyro_backend(backend): Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO elbo = Elbo(ignore_jit_warnings=True) @@ -129,7 +131,7 @@ def model(): p = torch.tensor([0.2, 0.3, 0.5]) with pyro.plate("plate", len(data), dim=-1): x = pyro.sample("x", dist.Categorical(p)) - pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) + pyro.sample("obs", dist.Normal(locs[x], 1.0), obs=data) def guide(): p = pyro.param("p", torch.tensor([0.5, 0.3, 0.2])) @@ -150,13 +152,13 @@ def test_nested_plate_plate_ok(backend, jit): def model(): loc = torch.tensor(3.0) with pyro.plate("plate_outer", data.size(-1), dim=-1): - x = pyro.sample("x", dist.Normal(loc, 1.)) + x = pyro.sample("x", dist.Normal(loc, 1.0)) with pyro.plate("plate_inner", data.size(-2), dim=-2): - pyro.sample("y", dist.Normal(x, 1.), obs=data) + pyro.sample("y", dist.Normal(x, 1.0), obs=data) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - scale = pyro.param("scale", torch.tensor(1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0)) with pyro.plate("plate_outer", data.size(-1), dim=-1): pyro.sample("x", dist.Normal(loc, scale)) @@ -167,18 +169,21 @@ def guide(): @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) -@pytest.mark.parametrize("backend", [ - "pyro", - xfail_param("minipyro", reason="not implemented"), -]) +@pytest.mark.parametrize( + "backend", + [ + "pyro", + xfail_param("minipyro", reason="not implemented"), + ], +) def test_local_param_ok(backend, jit): data = torch.randn(10) def model(): - locs = pyro.param("locs", torch.tensor([-1., 0., 1.])) + locs = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0])) with pyro.plate("plate", len(data), dim=-1): x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3)) - pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) + pyro.sample("obs", dist.Normal(locs[x], 1.0), obs=data) def guide(): with pyro.plate("plate", len(data), dim=-1): @@ -226,7 +231,9 @@ def test_elbo_jit(backend): elbo_test_case(backend, jit=True, expected_elbo=0.4780, data=data, steps=15) -@pytest.mark.parametrize(["backend", "jit"], [("pyro", True), ("pyro", False), ("minipyro", False)]) +@pytest.mark.parametrize( + ["backend", "jit"], [("pyro", True), ("pyro", False), ("minipyro", False)] +) def test_elbo_equivalence(backend, jit): """ Given model and guide diff --git a/tests/contrib/test_util.py b/tests/contrib/test_util.py index a0688fcdbb..bb122f021d 100644 --- a/tests/contrib/test_util.py +++ b/tests/contrib/test_util.py @@ -23,24 +23,32 @@ def test_get_indices_sizes(): sizes = OrderedDict([("a", 2), ("b", 2), ("c", 2)]) assert_equal(get_indices(["b"], sizes=sizes), torch.tensor([2, 3])) assert_equal(get_indices(["b", "c"], sizes=sizes), torch.tensor([2, 3, 4, 5])) - tensors = OrderedDict([("a", torch.ones(2)), ("b", torch.ones(2)), ("c", torch.ones(2))]) + tensors = OrderedDict( + [("a", torch.ones(2)), ("b", torch.ones(2)), ("c", torch.ones(2))] + ) assert_equal(get_indices(["b"], tensors=tensors), torch.tensor([2, 3])) assert_equal(get_indices(["b", "c"], tensors=tensors), torch.tensor([2, 3, 4, 5])) def test_tensor_to_dict(): sizes = OrderedDict([("a", 2), ("b", 2), ("c", 2)]) - vector = torch.tensor([1., 2, 3, 4, 5, 6]) - assert_equal(tensor_to_dict(sizes, vector), {"a": torch.tensor([1., 2.]), - "b": torch.tensor([3., 4.]), - "c": torch.tensor([5., 6.])}) - assert_equal(tensor_to_dict(sizes, vector, subset=["b"]), - {"b": torch.tensor([3., 4.])}) - - -@pytest.mark.parametrize("A,b", [ - 
(torch.tensor([[1., 2.], [2., -3.]]), torch.tensor([-1., 2.])) - ]) + vector = torch.tensor([1.0, 2, 3, 4, 5, 6]) + assert_equal( + tensor_to_dict(sizes, vector), + { + "a": torch.tensor([1.0, 2.0]), + "b": torch.tensor([3.0, 4.0]), + "c": torch.tensor([5.0, 6.0]), + }, + ) + assert_equal( + tensor_to_dict(sizes, vector, subset=["b"]), {"b": torch.tensor([3.0, 4.0])} + ) + + +@pytest.mark.parametrize( + "A,b", [(torch.tensor([[1.0, 2.0], [2.0, -3.0]]), torch.tensor([-1.0, 2.0]))] +) def test_rmv(A, b): assert_equal(rmv(A, b), A.mv(b), prec=1e-8) batched_A = lexpand(A, 5, 4) @@ -49,9 +57,7 @@ def test_rmv(A, b): assert_equal(rmv(batched_A, batched_b), expected_Ab, prec=1e-8) -@pytest.mark.parametrize("a,b", [ - (torch.tensor([1., 2.]), torch.tensor([-1., 2.])) - ]) +@pytest.mark.parametrize("a,b", [(torch.tensor([1.0, 2.0]), torch.tensor([-1.0, 2.0]))]) def test_rvv(a, b): assert_equal(rvv(a, b), torch.dot(a, b), prec=1e-8) batched_a = lexpand(a, 5, 4) @@ -61,21 +67,23 @@ def test_rvv(a, b): def test_lexpand(): - A = torch.tensor([[1., 2.], [-2., 0]]) + A = torch.tensor([[1.0, 2.0], [-2.0, 0]]) assert_equal(lexpand(A), A, prec=1e-8) assert_equal(lexpand(A, 4), A.expand(4, 2, 2), prec=1e-8) assert_equal(lexpand(A, 4, 2), A.expand(4, 2, 2, 2), prec=1e-8) def test_rexpand(): - A = torch.tensor([[1., 2.], [-2., 0]]) + A = torch.tensor([[1.0, 2.0], [-2.0, 0]]) assert_equal(rexpand(A), A, prec=1e-8) assert_equal(rexpand(A, 4), A.unsqueeze(-1).expand(2, 2, 4), prec=1e-8) - assert_equal(rexpand(A, 4, 2), A.unsqueeze(-1).unsqueeze(-1).expand(2, 2, 4, 2), prec=1e-8) + assert_equal( + rexpand(A, 4, 2), A.unsqueeze(-1).unsqueeze(-1).expand(2, 2, 4, 2), prec=1e-8 + ) def test_rtril(): - A = torch.tensor([[1., 2.], [-2., 0]]) + A = torch.tensor([[1.0, 2.0], [-2.0, 0]]) assert_equal(rtril(A), torch.tril(A), prec=1e-8) expanded = lexpand(A, 5, 4) expected = lexpand(torch.tril(A), 5, 4) @@ -83,7 +91,7 @@ def test_rtril(): def test_rdiag(): - v = torch.tensor([1., 2., -1.]) + v = torch.tensor([1.0, 2.0, -1.0]) assert_equal(rdiag(v), torch.diag(v), prec=1e-8) expanded = lexpand(v, 5, 4) expeceted = lexpand(torch.diag(v), 5, 4) diff --git a/tests/contrib/timeseries/test_gp.py b/tests/contrib/timeseries/test_gp.py index fe520b6c74..2a6c73bf85 100644 --- a/tests/contrib/timeseries/test_gp.py +++ b/tests/contrib/timeseries/test_gp.py @@ -18,60 +18,95 @@ from tests.common import assert_equal -@pytest.mark.parametrize('model,obs_dim,nu_statedim', [('ssmgp', 3, 1.5), ('ssmgp', 2, 2.5), - ('lcmgp', 3, 1.5), ('lcmgp', 2, 2.5), - ('imgp', 1, 0.5), ('imgp', 2, 0.5), - ('imgp', 1, 1.5), ('imgp', 3, 1.5), - ('imgp', 1, 2.5), ('imgp', 3, 2.5), - ('dmgp', 1, 1.5), ('dmgp', 2, 1.5), - ('dmgp', 3, 1.5), - ('glgssm', 1, 3), ('glgssm', 3, 1)]) -@pytest.mark.parametrize('T', [11, 37]) +@pytest.mark.parametrize( + "model,obs_dim,nu_statedim", + [ + ("ssmgp", 3, 1.5), + ("ssmgp", 2, 2.5), + ("lcmgp", 3, 1.5), + ("lcmgp", 2, 2.5), + ("imgp", 1, 0.5), + ("imgp", 2, 0.5), + ("imgp", 1, 1.5), + ("imgp", 3, 1.5), + ("imgp", 1, 2.5), + ("imgp", 3, 2.5), + ("dmgp", 1, 1.5), + ("dmgp", 2, 1.5), + ("dmgp", 3, 1.5), + ("glgssm", 1, 3), + ("glgssm", 3, 1), + ], +) +@pytest.mark.parametrize("T", [11, 37]) def test_timeseries_models(model, nu_statedim, obs_dim, T): - torch.set_default_tensor_type('torch.DoubleTensor') + torch.set_default_tensor_type("torch.DoubleTensor") dt = 0.1 + torch.rand(1).item() - if model == 'lcmgp': + if model == "lcmgp": num_gps = 2 - gp = LinearlyCoupledMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt, 
num_gps=num_gps, - length_scale_init=0.5 + torch.rand(num_gps), - kernel_scale_init=0.5 + torch.rand(num_gps), - obs_noise_scale_init=0.5 + torch.rand(obs_dim)) - elif model == 'imgp': - gp = IndependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt, - length_scale_init=0.5 + torch.rand(obs_dim), - kernel_scale_init=0.5 + torch.rand(obs_dim), - obs_noise_scale_init=0.5 + torch.rand(obs_dim)) - elif model == 'glgssm': - gp = GenericLGSSM(state_dim=nu_statedim, obs_dim=obs_dim, - obs_noise_scale_init=0.5 + torch.rand(obs_dim)) - elif model == 'ssmgp': + gp = LinearlyCoupledMaternGP( + nu=nu_statedim, + obs_dim=obs_dim, + dt=dt, + num_gps=num_gps, + length_scale_init=0.5 + torch.rand(num_gps), + kernel_scale_init=0.5 + torch.rand(num_gps), + obs_noise_scale_init=0.5 + torch.rand(obs_dim), + ) + elif model == "imgp": + gp = IndependentMaternGP( + nu=nu_statedim, + obs_dim=obs_dim, + dt=dt, + length_scale_init=0.5 + torch.rand(obs_dim), + kernel_scale_init=0.5 + torch.rand(obs_dim), + obs_noise_scale_init=0.5 + torch.rand(obs_dim), + ) + elif model == "glgssm": + gp = GenericLGSSM( + state_dim=nu_statedim, + obs_dim=obs_dim, + obs_noise_scale_init=0.5 + torch.rand(obs_dim), + ) + elif model == "ssmgp": state_dim = {0.5: 4, 1.5: 3, 2.5: 2}[nu_statedim] - gp = GenericLGSSMWithGPNoiseModel(nu=nu_statedim, state_dim=state_dim, obs_dim=obs_dim, - obs_noise_scale_init=0.5 + torch.rand(obs_dim)) - elif model == 'dmgp': + gp = GenericLGSSMWithGPNoiseModel( + nu=nu_statedim, + state_dim=state_dim, + obs_dim=obs_dim, + obs_noise_scale_init=0.5 + torch.rand(obs_dim), + ) + elif model == "dmgp": linearly_coupled = bool(torch.rand(1).item() > 0.5) - gp = DependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt, linearly_coupled=linearly_coupled, - length_scale_init=0.5 + torch.rand(obs_dim)) + gp = DependentMaternGP( + nu=nu_statedim, + obs_dim=obs_dim, + dt=dt, + linearly_coupled=linearly_coupled, + length_scale_init=0.5 + torch.rand(obs_dim), + ) targets = torch.randn(T, obs_dim) gp_log_prob = gp.log_prob(targets) - if model == 'imgp': + if model == "imgp": assert gp_log_prob.shape == (obs_dim,) else: assert gp_log_prob.dim() == 0 # compare matern log probs to vanilla GP result via multivariate normal - if model == 'imgp': + if model == "imgp": times = dt * torch.arange(T).double() for dim in range(obs_dim): lengthscale = gp.kernel.length_scale[dim] variance = gp.kernel.kernel_scale.pow(2)[dim] obs_noise = gp.obs_noise_scale.pow(2)[dim] - kernel = {0.5: pyro.contrib.gp.kernels.Exponential, - 1.5: pyro.contrib.gp.kernels.Matern32, - 2.5: pyro.contrib.gp.kernels.Matern52}[nu_statedim] + kernel = { + 0.5: pyro.contrib.gp.kernels.Exponential, + 1.5: pyro.contrib.gp.kernels.Matern32, + 2.5: pyro.contrib.gp.kernels.Matern52, + }[nu_statedim] kernel = kernel(input_dim=1, lengthscale=lengthscale, variance=variance) # XXX kernel(times) loads old parameters from param store kernel = kernel.forward(times) + obs_noise * torch.eye(T) @@ -81,27 +116,27 @@ def test_timeseries_models(model, nu_statedim, obs_dim, T): assert_equal(mvn_log_prob, gp_log_prob[dim], prec=1e-4) for S in [1, 5]: - if model in ['imgp', 'lcmgp', 'dmgp', 'lcdgp']: + if model in ["imgp", "lcmgp", "dmgp", "lcdgp"]: dts = torch.rand(S).cumsum(dim=-1) predictive = gp.forecast(targets, dts) else: predictive = gp.forecast(targets, S) assert predictive.loc.shape == (S, obs_dim) - if model == 'imgp': + if model == "imgp": assert predictive.scale.shape == (S, obs_dim) # assert monotonic increase of predictive noise if S > 1: - delta = predictive.scale[1:S, 
:] - predictive.scale[0:S-1, :] + delta = predictive.scale[1:S, :] - predictive.scale[0 : S - 1, :] assert (delta > 0.0).sum() == (S - 1) * obs_dim else: assert predictive.covariance_matrix.shape == (S, obs_dim, obs_dim) # assert monotonic increase of predictive noise if S > 1: dets = predictive.covariance_matrix.det() - delta = dets[1:S] - dets[0:S-1] + delta = dets[1:S] - dets[0 : S - 1] assert (delta > 0.0).sum() == (S - 1) - if model in ['imgp', 'lcmgp', 'dmgp', 'lcdgp']: + if model in ["imgp", "lcmgp", "dmgp", "lcdgp"]: # the distant future dts = torch.tensor([500.0]) predictive = gp.forecast(targets, dts) @@ -109,11 +144,12 @@ def test_timeseries_models(model, nu_statedim, obs_dim, T): assert_equal(predictive.loc, torch.zeros(1, obs_dim)) -@pytest.mark.parametrize('obs_dim', [1, 3]) +@pytest.mark.parametrize("obs_dim", [1, 3]) def test_dependent_matern_gp(obs_dim): dt = 0.5 + torch.rand(1).item() - gp = DependentMaternGP(nu=1.5, obs_dim=obs_dim, dt=dt, - length_scale_init=0.5 + torch.rand(obs_dim)) + gp = DependentMaternGP( + nu=1.5, obs_dim=obs_dim, dt=dt, length_scale_init=0.5 + torch.rand(obs_dim) + ) # make sure stationary covariance matrix satisfies the relevant # matrix riccati equation @@ -127,5 +163,9 @@ def test_dependent_matern_gp(obs_dim): wiener_cov = gp._get_wiener_cov() wiener_cov *= torch.tensor([[0.0, 0.0], [0.0, 1.0]]).repeat(obs_dim, obs_dim) - expected_zero = torch.matmul(F, stat_cov) + torch.matmul(stat_cov, F.transpose(-1, -2)) + wiener_cov + expected_zero = ( + torch.matmul(F, stat_cov) + + torch.matmul(stat_cov, F.transpose(-1, -2)) + + wiener_cov + ) assert_equal(expected_zero, torch.zeros(gp.full_state_dim, gp.full_state_dim)) diff --git a/tests/contrib/timeseries/test_lgssm.py b/tests/contrib/timeseries/test_lgssm.py index 5b5ed9d339..8bfcef31a2 100644 --- a/tests/contrib/timeseries/test_lgssm.py +++ b/tests/contrib/timeseries/test_lgssm.py @@ -8,19 +8,26 @@ from tests.common import assert_equal -@pytest.mark.parametrize('model_class', ['lgssm', 'lgssmgp']) -@pytest.mark.parametrize('state_dim', [2, 3]) -@pytest.mark.parametrize('obs_dim', [2, 4]) -@pytest.mark.parametrize('T', [11, 17]) +@pytest.mark.parametrize("model_class", ["lgssm", "lgssmgp"]) +@pytest.mark.parametrize("state_dim", [2, 3]) +@pytest.mark.parametrize("obs_dim", [2, 4]) +@pytest.mark.parametrize("T", [11, 17]) def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T): - torch.set_default_tensor_type('torch.DoubleTensor') - - if model_class == 'lgssm': - model = GenericLGSSM(state_dim=state_dim, obs_dim=obs_dim, - obs_noise_scale_init=0.1 + torch.rand(obs_dim)) - elif model_class == 'lgssmgp': - model = GenericLGSSMWithGPNoiseModel(state_dim=state_dim, obs_dim=obs_dim, nu=1.5, - obs_noise_scale_init=0.1 + torch.rand(obs_dim)) + torch.set_default_tensor_type("torch.DoubleTensor") + + if model_class == "lgssm": + model = GenericLGSSM( + state_dim=state_dim, + obs_dim=obs_dim, + obs_noise_scale_init=0.1 + torch.rand(obs_dim), + ) + elif model_class == "lgssmgp": + model = GenericLGSSMWithGPNoiseModel( + state_dim=state_dim, + obs_dim=obs_dim, + nu=1.5, + obs_noise_scale_init=0.1 + torch.rand(obs_dim), + ) # with these hyperparameters we essentially turn off the GP contributions model.kernel.length_scale = 1.0e-6 * torch.ones(obs_dim) model.kernel.kernel_scale = 1.0e-6 * torch.ones(obs_dim) @@ -28,10 +35,14 @@ def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T): targets = torch.randn(T, obs_dim) filtering_state = model._filter(targets) - actual_loc, actual_cov = 
model._forecast(3, filtering_state, include_observation_noise=False) + actual_loc, actual_cov = model._forecast( + 3, filtering_state, include_observation_noise=False + ) - obs_matrix = model.obs_matrix if model_class == 'lgssm' else model.z_obs_matrix - trans_matrix = model.trans_matrix if model_class == 'lgssm' else model.z_trans_matrix + obs_matrix = model.obs_matrix if model_class == "lgssm" else model.z_obs_matrix + trans_matrix = ( + model.trans_matrix if model_class == "lgssm" else model.z_trans_matrix + ) trans_matrix_sq = torch.mm(trans_matrix, trans_matrix) trans_matrix_cubed = torch.mm(trans_matrix_sq, trans_matrix) @@ -40,7 +51,11 @@ def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T): trans_trans_trans_obs = torch.mm(trans_matrix_cubed, obs_matrix) # we only compute contributions for the state space portion for lgssmgp - fs_loc = filtering_state.loc if model_class == 'lgssm' else filtering_state.loc[-state_dim:] + fs_loc = ( + filtering_state.loc + if model_class == "lgssm" + else filtering_state.loc[-state_dim:] + ) predicted_mean1 = torch.mm(fs_loc.unsqueeze(-2), trans_obs).squeeze(-2) predicted_mean2 = torch.mm(fs_loc.unsqueeze(-2), trans_trans_obs).squeeze(-2) @@ -53,25 +68,30 @@ def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T): # check predicted covariances for 3 timesteps fs_covar, process_covar = None, None - if model_class == 'lgssm': + if model_class == "lgssm": process_covar = model._get_trans_dist().covariance_matrix fs_covar = filtering_state.covariance_matrix - elif model_class == 'lgssmgp': + elif model_class == "lgssmgp": # we only compute contributions for the state space portion process_covar = model.trans_noise_scale_sq.diag_embed() fs_covar = filtering_state.covariance_matrix[-state_dim:, -state_dim:] - predicted_covar1 = torch.mm(trans_obs.t(), torch.mm(fs_covar, trans_obs)) + \ - torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) - - predicted_covar2 = torch.mm(trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_obs)) + \ - torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \ - torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) - - predicted_covar3 = torch.mm(trans_trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_trans_obs)) + \ - torch.mm(trans_trans_obs.t(), torch.mm(process_covar, trans_trans_obs)) + \ - torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \ - torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) + predicted_covar1 = torch.mm( + trans_obs.t(), torch.mm(fs_covar, trans_obs) + ) + torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) + + predicted_covar2 = ( + torch.mm(trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_obs)) + + torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + + torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) + ) + + predicted_covar3 = ( + torch.mm(trans_trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_trans_obs)) + + torch.mm(trans_trans_obs.t(), torch.mm(process_covar, trans_trans_obs)) + + torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + + torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix)) + ) assert_equal(actual_cov[0], predicted_covar1) assert_equal(actual_cov[1], predicted_covar2) diff --git a/tests/contrib/tracking/test_assignment.py b/tests/contrib/tracking/test_assignment.py index b8457b6797..ae92ea85af 100644 --- a/tests/contrib/tracking/test_assignment.py +++ b/tests/contrib/tracking/test_assignment.py @@ -16,12 +16,12 @@ ) from tests.common import assert_equal 
-INF = float('inf') +INF = float("inf") logger = logging.getLogger(__name__) def assert_finite(tensor, name): - assert ((tensor - tensor) == 0).all(), 'bad {}: {}'.format(tensor, name) + assert ((tensor - tensor) == 0).all(), "bad {}: {}".format(tensor, name) def logit(p): @@ -30,8 +30,10 @@ def logit(p): def dense_to_sparse(assign_logits): num_detections, num_objects = assign_logits.shape - edges = assign_logits.new_tensor([[j, i] for j in range(num_detections) for i in range(num_objects)], - dtype=torch.long).t() + edges = assign_logits.new_tensor( + [[j, i] for j in range(num_detections) for i in range(num_objects)], + dtype=torch.long, + ).t() assign_logits = assign_logits[edges[0], edges[1]] return edges, assign_logits @@ -47,10 +49,14 @@ def test_dense_smoke(): num_detections = 2 pyro.set_rng_seed(0) exists_logits = torch.zeros(num_objects) - assign_logits = logit(torch.tensor([ - [0.5, 0.5, 0.0, 0.0], - [0.0, 0.5, 0.5, 0.5], - ])) + assign_logits = logit( + torch.tensor( + [ + [0.5, 0.5, 0.0, 0.0], + [0.0, 0.5, 0.5, 0.5], + ] + ) + ) assert assign_logits.shape == (num_detections, num_objects) solver = MarginalAssignment(exists_logits, assign_logits, bp_iters=5) @@ -63,7 +69,9 @@ def test_dense_smoke(): # test dense matches sparse edges, assign_logits = dense_to_sparse(assign_logits) - other = MarginalAssignmentSparse(num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters=5) + other = MarginalAssignmentSparse( + num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters=5 + ) assert_equal(other.exists_dist.probs, solver.exists_dist.probs, prec=1e-3) assert_equal(other.assign_dist.probs, solver.assign_dist.probs, prec=1e-3) @@ -73,15 +81,19 @@ def test_sparse_smoke(): num_detections = 2 pyro.set_rng_seed(0) exists_logits = torch.zeros(num_objects) - edges = exists_logits.new_tensor([ - [0, 0, 1, 0, 1, 0], - [0, 1, 1, 2, 2, 3], - ], dtype=torch.long) + edges = exists_logits.new_tensor( + [ + [0, 0, 1, 0, 1, 0], + [0, 1, 1, 2, 2, 3], + ], + dtype=torch.long, + ) assign_logits = logit(torch.tensor([0.99, 0.8, 0.2, 0.2, 0.8, 0.9])) assert assign_logits.shape == edges.shape[1:] - solver = MarginalAssignmentSparse(num_objects, num_detections, edges, - exists_logits, assign_logits, bp_iters=5) + solver = MarginalAssignmentSparse( + num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters=5 + ) assert solver.exists_dist.batch_shape == (num_objects,) assert solver.exists_dist.event_shape == () @@ -97,7 +109,6 @@ def test_sparse_smoke(): def test_sparse_grid_smoke(): - def my_existence_prior(ox, oy): return -0.5 @@ -105,12 +116,9 @@ def my_assign_prior(ox, oy, dx, dy): return 0.0 num_detections = 3 * 3 - detections = [[0, 1, 2], - [3, 4, 5], - [6, 7, 8]] + detections = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] num_objects = 2 * 2 - objects = [[0, 1], - [2, 3]] + objects = [[0, 1], [2, 3]] edges = [] edge_coords = [] for x in range(2): @@ -130,50 +138,59 @@ def my_assign_prior(ox, oy, dx, dy): for y in range(2): object_id = objects[x][y] exists_logits[object_id] = my_existence_prior(x, y) - assign_logits = exists_logits.new_tensor([my_assign_prior(ox, oy, dx, dy) - for ox, oy, dx, dy in edge_coords]) - assign = MarginalAssignmentSparse(num_objects, num_detections, edges, - exists_logits, assign_logits, bp_iters=10) + assign_logits = exists_logits.new_tensor( + [my_assign_prior(ox, oy, dx, dy) for ox, oy, dx, dy in edge_coords] + ) + assign = MarginalAssignmentSparse( + num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters=10 + ) 
assert isinstance(assign.assign_dist, dist.Categorical) -@pytest.mark.parametrize('bp_iters', [None, 10], ids=['enum', 'bp']) +@pytest.mark.parametrize("bp_iters", [None, 10], ids=["enum", "bp"]) def test_persistent_smoke(bp_iters): - exists_logits = torch.tensor([-1., -1., -2.], requires_grad=True) - assign_logits = torch.tensor([[[-1., -INF, -INF], - [-2., -2., -INF]], - [[-1., -2., -3.], - [-2., -2., -1.]], - [[-1., -2., -3.], - [-2., -2., -1.]], - [[-1., -1., 1.], - [1., 1., -1.]]], requires_grad=True) - - assignment = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters=bp_iters) + exists_logits = torch.tensor([-1.0, -1.0, -2.0], requires_grad=True) + assign_logits = torch.tensor( + [ + [[-1.0, -INF, -INF], [-2.0, -2.0, -INF]], + [[-1.0, -2.0, -3.0], [-2.0, -2.0, -1.0]], + [[-1.0, -2.0, -3.0], [-2.0, -2.0, -1.0]], + [[-1.0, -1.0, 1.0], [1.0, 1.0, -1.0]], + ], + requires_grad=True, + ) + + assignment = MarginalAssignmentPersistent( + exists_logits, assign_logits, bp_iters=bp_iters + ) assert assignment.num_frames == 4 assert assignment.num_detections == 2 assert assignment.num_objects == 3 assign_dist = assignment.assign_dist exists_dist = assignment.exists_dist - assert_finite(exists_dist.probs, 'exists_probs') - assert_finite(assign_dist.probs, 'assign_probs') + assert_finite(exists_dist.probs, "exists_probs") + assert_finite(assign_dist.probs, "assign_probs") for exists in exists_dist.enumerate_support(): log_prob = exists_dist.log_prob(exists).sum() - e_grad, a_grad = grad(log_prob, [exists_logits, assign_logits], create_graph=True) - assert_finite(e_grad, 'dexists_probs/dexists_logits') - assert_finite(a_grad, 'dexists_probs/dassign_logits') + e_grad, a_grad = grad( + log_prob, [exists_logits, assign_logits], create_graph=True + ) + assert_finite(e_grad, "dexists_probs/dexists_logits") + assert_finite(a_grad, "dexists_probs/dassign_logits") for assign in assign_dist.enumerate_support(): log_prob = assign_dist.log_prob(assign).sum() - e_grad, a_grad = grad(log_prob, [exists_logits, assign_logits], create_graph=True) - assert_finite(e_grad, 'dassign_probs/dexists_logits') - assert_finite(a_grad, 'dassign_probs/dassign_logits') + e_grad, a_grad = grad( + log_prob, [exists_logits, assign_logits], create_graph=True + ) + assert_finite(e_grad, "dassign_probs/dexists_logits") + assert_finite(a_grad, "dassign_probs/dassign_logits") -@pytest.mark.parametrize('e', [-1., 0., 1.]) -@pytest.mark.parametrize('a', [-1., 0., 1.]) +@pytest.mark.parametrize("e", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("a", [-1.0, 0.0, 1.0]) def test_flat_exact_1_1(e, a): exists_logits = torch.tensor([e]) assign_logits = torch.tensor([[a]]) @@ -183,9 +200,9 @@ def test_flat_exact_1_1(e, a): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs) -@pytest.mark.parametrize('e', [-1., 0., 1.]) -@pytest.mark.parametrize('a11', [-1., 0., 1.]) -@pytest.mark.parametrize('a21', [-1., 0., 1.]) +@pytest.mark.parametrize("e", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("a11", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("a21", [-1.0, 0.0, 1.0]) def test_flat_exact_2_1(e, a11, a21): exists_logits = torch.tensor([e]) assign_logits = torch.tensor([[a11], [a21]]) @@ -195,10 +212,10 @@ def test_flat_exact_2_1(e, a11, a21): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs) -@pytest.mark.parametrize('e1', [-1., 0., 1.]) -@pytest.mark.parametrize('e2', [-1., 0., 1.]) -@pytest.mark.parametrize('a11', [-1., 0., 1.]) -@pytest.mark.parametrize('a12', [-1., 0., 1.]) 
+@pytest.mark.parametrize("e1", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("e2", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("a11", [-1.0, 0.0, 1.0]) +@pytest.mark.parametrize("a12", [-1.0, 0.0, 1.0]) def test_flat_exact_1_2(e1, e2, a11, a12): exists_logits = torch.tensor([e1, e2]) assign_logits = torch.tensor([[a11, a12]]) @@ -208,11 +225,11 @@ def test_flat_exact_1_2(e1, e2, a11, a12): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs) -@pytest.mark.parametrize('e1', [-1., 1.]) -@pytest.mark.parametrize('e2', [-1., 1.]) -@pytest.mark.parametrize('a11', [-1., 1.]) -@pytest.mark.parametrize('a12', [-1., 1.]) -@pytest.mark.parametrize('a22', [-1., 1.]) +@pytest.mark.parametrize("e1", [-1.0, 1.0]) +@pytest.mark.parametrize("e2", [-1.0, 1.0]) +@pytest.mark.parametrize("a11", [-1.0, 1.0]) +@pytest.mark.parametrize("a12", [-1.0, 1.0]) +@pytest.mark.parametrize("a22", [-1.0, 1.0]) def test_flat_exact_2_2(e1, e2, a11, a12, a22): a21 = -INF exists_logits = torch.tensor([e1, e2]) @@ -223,8 +240,8 @@ def test_flat_exact_2_2(e1, e2, a11, a12, a22): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs) -@pytest.mark.parametrize('num_detections', [1, 2, 3, 4]) -@pytest.mark.parametrize('num_objects', [1, 2, 3, 4]) +@pytest.mark.parametrize("num_detections", [1, 2, 3, 4]) +@pytest.mark.parametrize("num_objects", [1, 2, 3, 4]) def test_flat_bp_vs_exact(num_objects, num_detections): exists_logits = -2 * torch.rand(num_objects) assign_logits = -2 * torch.rand(num_detections, num_objects) @@ -235,21 +252,23 @@ def test_flat_bp_vs_exact(num_objects, num_detections): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs, prec=0.01) -@pytest.mark.parametrize('num_frames', [1, 2, 3, 4]) -@pytest.mark.parametrize('num_objects', [1, 2, 3, 4]) -@pytest.mark.parametrize('bp_iters', [None, 30], ids=['enum', 'bp']) +@pytest.mark.parametrize("num_frames", [1, 2, 3, 4]) +@pytest.mark.parametrize("num_objects", [1, 2, 3, 4]) +@pytest.mark.parametrize("bp_iters", [None, 30], ids=["enum", "bp"]) def test_flat_vs_persistent(num_objects, num_frames, bp_iters): exists_logits = -2 * torch.rand(num_objects) assign_logits = -2 * torch.rand(num_frames, num_objects) flat = MarginalAssignment(exists_logits, assign_logits, bp_iters) - full = MarginalAssignmentPersistent(exists_logits, assign_logits.unsqueeze(1), bp_iters) + full = MarginalAssignmentPersistent( + exists_logits, assign_logits.unsqueeze(1), bp_iters + ) assert_equal(flat.exists_dist.probs, full.exists_dist.probs) assert_equal(flat.assign_dist.probs, full.assign_dist.probs.squeeze(1)) -@pytest.mark.parametrize('num_detections', [1, 2, 3, 4]) -@pytest.mark.parametrize('num_frames', [1, 2, 3, 4]) -@pytest.mark.parametrize('num_objects', [1, 2, 3, 4]) +@pytest.mark.parametrize("num_detections", [1, 2, 3, 4]) +@pytest.mark.parametrize("num_frames", [1, 2, 3, 4]) +@pytest.mark.parametrize("num_objects", [1, 2, 3, 4]) def test_persistent_bp_vs_exact(num_objects, num_frames, num_detections): exists_logits = -2 * torch.rand(num_objects) assign_logits = 2 * torch.rand(num_frames, num_detections, num_objects) - 1 @@ -260,44 +279,59 @@ def test_persistent_bp_vs_exact(num_objects, num_frames, num_detections): assert_equal(expected.assign_dist.probs, actual.assign_dist.probs, prec=0.05) -@pytest.mark.parametrize('e1', [-1., 1.]) -@pytest.mark.parametrize('e2', [-1., 1.]) -@pytest.mark.parametrize('e3', [-1., 1.]) -@pytest.mark.parametrize('bp_iters, bp_momentum', [(3, 0.), (30, 0.5)], ids=['momentum', 'none']) 
+@pytest.mark.parametrize("e1", [-1.0, 1.0]) +@pytest.mark.parametrize("e2", [-1.0, 1.0]) +@pytest.mark.parametrize("e3", [-1.0, 1.0]) +@pytest.mark.parametrize( + "bp_iters, bp_momentum", [(3, 0.0), (30, 0.5)], ids=["momentum", "none"] +) def test_persistent_exact_5_4_3(e1, e2, e3, bp_iters, bp_momentum): exists_logits = torch.tensor([e1, e2, e3]) assign_logits = 2 * torch.rand(5, 4, 3) - 1 # this has tree-shaped connectivity and should lead to exact inference - mask = torch.tensor([[[1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]], - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]], - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]], dtype=torch.bool) + mask = torch.tensor( + [ + [[1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]], + [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]], + [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], + ], + dtype=torch.bool, + ) assign_logits[~mask] = -INF expected = MarginalAssignmentPersistent(exists_logits, assign_logits, None) - actual = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters, bp_momentum) + actual = MarginalAssignmentPersistent( + exists_logits, assign_logits, bp_iters, bp_momentum + ) assert_equal(expected.exists_dist.probs, actual.exists_dist.probs) assert_equal(expected.assign_dist.probs, actual.assign_dist.probs) logger.debug(actual.exists_dist.probs) logger.debug(actual.assign_dist.probs) -@pytest.mark.parametrize('num_detections', [1, 2, 3]) -@pytest.mark.parametrize('num_frames', [1, 2, 3]) -@pytest.mark.parametrize('num_objects', [1, 2]) -@pytest.mark.parametrize('bp_iters', [None, 30], ids=['enum', 'bp']) -def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters): +@pytest.mark.parametrize("num_detections", [1, 2, 3]) +@pytest.mark.parametrize("num_frames", [1, 2, 3]) +@pytest.mark.parametrize("num_objects", [1, 2]) +@pytest.mark.parametrize("bp_iters", [None, 30], ids=["enum", "bp"]) +def test_persistent_independent_subproblems( + num_objects, num_frames, num_detections, bp_iters +): # solve a random assignment problem exists_logits_1 = -2 * torch.rand(num_objects) assign_logits_1 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1 - assignment_1 = MarginalAssignmentPersistent(exists_logits_1, assign_logits_1, bp_iters) + assignment_1 = MarginalAssignmentPersistent( + exists_logits_1, assign_logits_1, bp_iters + ) exists_probs_1 = assignment_1.exists_dist.probs assign_probs_1 = assignment_1.assign_dist.probs # solve another random assignment problem exists_logits_2 = -2 * torch.rand(num_objects) assign_logits_2 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1 - assignment_2 = MarginalAssignmentPersistent(exists_logits_2, assign_logits_2, bp_iters) + assignment_2 = MarginalAssignmentPersistent( + exists_logits_2, assign_logits_2, bp_iters + ) exists_probs_2 = assignment_2.exists_dist.probs assign_probs_2 = assignment_2.assign_dist.probs @@ -313,7 +347,11 @@ def test_persistent_independent_subproblems(num_objects, num_frames, num_detecti # check agreement assert_equal(exists_probs_1, exists_probs[:num_objects]) assert_equal(exists_probs_2, exists_probs[num_objects:]) - assert_equal(assign_probs_1[:, :, :-1], assign_probs[:, :num_detections, :num_objects]) + assert_equal( + assign_probs_1[:, :, :-1], assign_probs[:, :num_detections, :num_objects] + ) assert_equal(assign_probs_1[:, :, -1], 
assign_probs[:, :num_detections, -1]) - assert_equal(assign_probs_2[:, :, :-1], assign_probs[:, num_detections:, num_objects:-1]) + assert_equal( + assign_probs_2[:, :, :-1], assign_probs[:, num_detections:, num_objects:-1] + ) assert_equal(assign_probs_2[:, :, -1], assign_probs[:, num_detections:, -1]) diff --git a/tests/contrib/tracking/test_distributions.py b/tests/contrib/tracking/test_distributions.py index fe4c149b49..119a7251dc 100644 --- a/tests/contrib/tracking/test_distributions.py +++ b/tests/contrib/tracking/test_distributions.py @@ -8,15 +8,15 @@ from pyro.contrib.tracking.dynamic_models import NcpContinuous, NcvContinuous -@pytest.mark.parametrize('Model', [NcpContinuous, NcvContinuous]) -@pytest.mark.parametrize('dim', [2, 3]) -@pytest.mark.parametrize('time', [2, 3]) +@pytest.mark.parametrize("Model", [NcpContinuous, NcvContinuous]) +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("time", [2, 3]) def test_EKFDistribution_smoke(Model, dim, time): - x0 = torch.rand(2*dim) + x0 = torch.rand(2 * dim) ys = torch.randn(time, dim) - P0 = torch.eye(2*dim).requires_grad_() + P0 = torch.eye(2 * dim).requires_grad_() R = torch.eye(dim).requires_grad_() - model = Model(2*dim, 2.0) + model = Model(2 * dim, 2.0) dist = EKFDistribution(x0, P0, model, R, time_steps=time) log_prob = dist.log_prob(ys) assert log_prob.shape == torch.Size() diff --git a/tests/contrib/tracking/test_dynamic_models.py b/tests/contrib/tracking/test_dynamic_models.py index abda5ff56e..d307b3e67f 100644 --- a/tests/contrib/tracking/test_dynamic_models.py +++ b/tests/contrib/tracking/test_dynamic_models.py @@ -12,36 +12,36 @@ from tests.common import assert_equal, assert_not_equal -def assert_cov_validity(cov, eigenvalue_lbnd=0., condition_number_ubnd=1e6): - ''' +def assert_cov_validity(cov, eigenvalue_lbnd=0.0, condition_number_ubnd=1e6): + """ cov: covariance matrix eigenvalue_lbnd: eigenvalues should be at least this much greater than zero. Must be strictly positive. condition_number_ubnd: inclusive upper bound on matrix condition number. Must be greater or equal to 1.0. - ''' - assert eigenvalue_lbnd >= 0.0, \ - 'Covariance eigenvalue lower bound must be > 0.0!' - assert condition_number_ubnd >= 1.0, \ - 'Covariance condition number bound must be >= 1.0!' + """ + assert eigenvalue_lbnd >= 0.0, "Covariance eigenvalue lower bound must be > 0.0!" + assert ( + condition_number_ubnd >= 1.0 + ), "Covariance condition number bound must be >= 1.0!" # Symmetry - assert (cov.t() == cov).all(), 'Covariance must be symmetric!' + assert (cov.t() == cov).all(), "Covariance must be symmetric!" # Precompute eigenvalues for subsequent tests. ws = torch.linalg.eigvalsh(cov) # The eigenvalues of cov w_min = torch.min(ws) w_max = torch.max(ws) # Strict positivity - assert w_min > 0.0, 'Covariance must be strictly positive!' + assert w_min > 0.0, "Covariance must be strictly positive!" # Eigenvalue lower bound - assert w_min >= eigenvalue_lbnd, \ - 'Covariance eigenvalues must be >= lower bound!' + assert w_min >= eigenvalue_lbnd, "Covariance eigenvalues must be >= lower bound!" # Condition number upper bound - assert w_max/w_min <= condition_number_ubnd, \ - 'Condition number must be <= upper bound!' + assert ( + w_max / w_min <= condition_number_ubnd + ), "Condition number must be <= upper bound!" 
def test_NcpContinuous(): @@ -50,7 +50,7 @@ def test_NcpContinuous(): d = 3 ncp = NcpContinuous(dimension=d, sv2=2.0) assert ncp.dimension == d - assert ncp.dimension_pv == 2*d + assert ncp.dimension_pv == 2 * d assert ncp.num_process_noise_parameters == 1 x = torch.rand(d) @@ -67,8 +67,8 @@ def test_NcpContinuous(): P = torch.eye(d) P_pv = ncp.cov2pv(P) - assert P_pv.shape == (2*d, 2*d) - P_pv_ref = torch.zeros((2*d, 2*d)) + assert P_pv.shape == (2 * d, 2 * d) + P_pv_ref = torch.zeros((2 * d, 2 * d)) P_pv_ref[:d, :d] = P assert_equal(P_pv_ref, P_pv) @@ -84,7 +84,7 @@ def test_NcpContinuous(): def test_NcvContinuous(): framerate = 100 # Hz - dt = 1.0/framerate + dt = 1.0 / framerate d = 6 ncv = NcvContinuous(dimension=d, sa2=2.0) assert ncv.dimension == d @@ -93,7 +93,7 @@ def test_NcvContinuous(): x = torch.rand(d) y = ncv(x, dt) - assert_equal(y[0], x[0] + dt*x[d//2]) + assert_equal(y[0], x[0] + dt * x[d // 2]) dx = ncv.geodesic_difference(x, y) assert_not_equal(dx, torch.zeros(d)) @@ -119,11 +119,11 @@ def test_NcvContinuous(): def test_NcpDiscrete(): framerate = 100 # Hz - dt = 1.0/framerate + dt = 1.0 / framerate d = 3 ncp = NcpDiscrete(dimension=d, sv2=2.0) assert ncp.dimension == d - assert ncp.dimension_pv == 2*d + assert ncp.dimension_pv == 2 * d assert ncp.num_process_noise_parameters == 1 x = torch.rand(d) @@ -140,8 +140,8 @@ def test_NcpDiscrete(): P = torch.eye(d) P_pv = ncp.cov2pv(P) - assert P_pv.shape == (2*d, 2*d) - P_pv_ref = torch.zeros((2*d, 2*d)) + assert P_pv.shape == (2 * d, 2 * d) + P_pv_ref = torch.zeros((2 * d, 2 * d)) P_pv_ref[:d, :d] = P assert_equal(P_pv_ref, P_pv) @@ -157,7 +157,7 @@ def test_NcpDiscrete(): def test_NcvDiscrete(): framerate = 100 # Hz - dt = 1.0/framerate + dt = 1.0 / framerate dt = 100 d = 6 ncv = NcvDiscrete(dimension=d, sa2=2.0) @@ -167,7 +167,7 @@ def test_NcvDiscrete(): x = torch.rand(d) y = ncv(x, dt) - assert_equal(y[0], x[0] + dt*x[d//2]) + assert_equal(y[0], x[0] + dt * x[d // 2]) dx = ncv.geodesic_difference(x, y) assert_not_equal(dx, torch.zeros(d)) diff --git a/tests/contrib/tracking/test_ekf.py b/tests/contrib/tracking/test_ekf.py index 35db1544d1..c511cdbda5 100644 --- a/tests/contrib/tracking/test_ekf.py +++ b/tests/contrib/tracking/test_ekf.py @@ -20,7 +20,7 @@ def test_EKFState_with_NcpContinuous(): assert ekf_state.dynamic_model.__class__ == NcpContinuous assert ekf_state.dimension == d - assert ekf_state.dimension_pv == 2*d + assert ekf_state.dimension_pv == 2 * d assert_equal(x, ekf_state.mean, prec=1e-5) assert_equal(P, ekf_state.cov, prec=1e-5) @@ -28,16 +28,13 @@ def test_EKFState_with_NcpContinuous(): assert_equal(P, ekf_state.cov_pv[:d, :d], prec=1e-5) assert_equal(t, ekf_state.time, prec=1e-5) - ekf_state1 = EKFState(ncp, 2*x, 2*P, t) + ekf_state1 = EKFState(ncp, 2 * x, 2 * P, t) ekf_state2 = ekf_state1.predict(dt) assert ekf_state2.dynamic_model.__class__ == NcpContinuous - measurement = PositionMeasurement( - mean=torch.rand(d), - cov=torch.eye(d), - time=t + dt) + measurement = PositionMeasurement(mean=torch.rand(d), cov=torch.eye(d), time=t + dt) log_likelihood = ekf_state2.log_likelihood_of_update(measurement) - assert (log_likelihood < 0.).all() + assert (log_likelihood < 0.0).all() ekf_state3, (dz, S) = ekf_state2.update(measurement) assert dz.shape == (measurement.dimension,) assert S.shape == (measurement.dimension, measurement.dimension) @@ -51,8 +48,7 @@ def test_EKFState_with_NcvContinuous(): P = torch.eye(d) t = 0.0 dt = 2.0 - ekf_state = EKFState( - dynamic_model=ncv, mean=x, cov=P, time=t) + ekf_state = 
EKFState(dynamic_model=ncv, mean=x, cov=P, time=t) assert ekf_state.dynamic_model.__class__ == NcvContinuous assert ekf_state.dimension == d @@ -64,16 +60,13 @@ def test_EKFState_with_NcvContinuous(): assert_equal(P, ekf_state.cov_pv, prec=1e-5) assert_equal(t, ekf_state.time, prec=1e-5) - ekf_state1 = EKFState(ncv, 2*x, 2*P, t) + ekf_state1 = EKFState(ncv, 2 * x, 2 * P, t) ekf_state2 = ekf_state1.predict(dt) assert ekf_state2.dynamic_model.__class__ == NcvContinuous - measurement = PositionMeasurement( - mean=torch.rand(d), - cov=torch.eye(d), - time=t + dt) + measurement = PositionMeasurement(mean=torch.rand(d), cov=torch.eye(d), time=t + dt) log_likelihood = ekf_state2.log_likelihood_of_update(measurement) - assert (log_likelihood < 0.).all() + assert (log_likelihood < 0.0).all() ekf_state3, (dz, S) = ekf_state2.update(measurement) assert dz.shape == (measurement.dimension,) assert S.shape == (measurement.dimension, measurement.dimension) diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py index c3401f4114..aa6e12dea2 100644 --- a/tests/contrib/tracking/test_em.py +++ b/tests/contrib/tracking/test_em.py @@ -20,7 +20,7 @@ def make_args(): - args = type('Args', (), {}) # A fake ArgumentParser.parse_args() + args = type("Args", (), {}) # A fake ArgumentParser.parse_args() args.max_num_objects = 4 args.num_real_detections = 13 args.num_fake_detections = 3 @@ -37,38 +37,41 @@ def make_args(): def model(detections, args): - noise_scale = pyro.param('noise_scale') - objects = pyro.param('objects_loc').squeeze(-1) - num_detections, = detections.shape - max_num_objects, = objects.shape + noise_scale = pyro.param("noise_scale") + objects = pyro.param("objects_loc").squeeze(-1) + (num_detections,) = detections.shape + (max_num_objects,) = objects.shape # Existence part. p_exists = args.expected_num_objects / max_num_objects - with pyro.plate('objects_plate', max_num_objects): - exists = pyro.sample('exists', dist.Bernoulli(p_exists)) + with pyro.plate("objects_plate", max_num_objects): + exists = pyro.sample("exists", dist.Bernoulli(p_exists)) with poutine.mask(mask=exists.bool()): - pyro.sample('objects', dist.Normal(0., 1.), obs=objects) + pyro.sample("objects", dist.Normal(0.0, 1.0), obs=objects) # Assignment part. p_fake = args.num_fake_detections / num_detections - with pyro.plate('detections_plate', num_detections): + with pyro.plate("detections_plate", num_detections): assign_probs = torch.empty(max_num_objects + 1) assign_probs[:-1] = (1 - p_fake) / max_num_objects assign_probs[-1] = p_fake - assign = pyro.sample('assign', dist.Categorical(logits=assign_probs)) - is_fake = (assign == assign.shape[-1] - 1) + assign = pyro.sample("assign", dist.Categorical(logits=assign_probs)) + is_fake = assign == assign.shape[-1] - 1 objects_plus_bogus = torch.zeros(max_num_objects + 1) objects_plus_bogus[:max_num_objects] = objects real_dist = dist.Normal(objects_plus_bogus[assign], noise_scale) - fake_dist = dist.Normal(0., 1.) - pyro.sample('detections', dist.MaskedMixture(is_fake, real_dist, fake_dist), - obs=detections) + fake_dist = dist.Normal(0.0, 1.0) + pyro.sample( + "detections", + dist.MaskedMixture(is_fake, real_dist, fake_dist), + obs=detections, + ) # This should match detection_model's existence part. 
def compute_exists_logits(objects, args): p_exists = args.expected_num_objects / args.max_num_objects - real_part = dist.Normal(0., 1.).log_prob(objects) + real_part = dist.Normal(0.0, 1.0).log_prob(objects) real_part = real_part + math.log(p_exists) spurious_part = torch.full(real_part.shape, math.log(1 - p_exists)) return real_part - spurious_part @@ -80,98 +83,115 @@ def compute_assign_logits(objects, detections, noise_scale, args): p_fake = args.num_fake_detections / num_detections real_part = dist.Normal(objects, noise_scale).log_prob(detections) real_part = real_part + math.log((1 - p_fake) / args.max_num_objects) - fake_part = dist.Normal(0., 1.).log_prob(detections) + fake_part = dist.Normal(0.0, 1.0).log_prob(detections) fake_part = fake_part + math.log(p_fake) return real_part - fake_part def guide(detections, args): - noise_scale = pyro.param('noise_scale') # trained by SVI - objects = pyro.param('objects_loc').squeeze(-1) # trained by M-step of EM - num_detections, = detections.shape - max_num_objects, = objects.shape + noise_scale = pyro.param("noise_scale") # trained by SVI + objects = pyro.param("objects_loc").squeeze(-1) # trained by M-step of EM + (num_detections,) = detections.shape + (max_num_objects,) = objects.shape with torch.set_grad_enabled(args.assignment_grad): # Evaluate log likelihoods. TODO make this more pyronic. exists_logits = compute_exists_logits(objects, args) - assign_logits = compute_assign_logits(objects, detections.unsqueeze(-1), noise_scale, args) + assign_logits = compute_assign_logits( + objects, detections.unsqueeze(-1), noise_scale, args + ) assert exists_logits.shape == (max_num_objects,) assert assign_logits.shape == (num_detections, max_num_objects) # Compute soft assignments. assignment = MarginalAssignment(exists_logits, assign_logits, bp_iters=10) - with pyro.plate('objects_plate', max_num_objects): - pyro.sample('exists', assignment.exists_dist, - infer={'enumerate': 'parallel'}) - with pyro.plate('detections_plate', num_detections): - pyro.sample('assign', assignment.assign_dist, - infer={'enumerate': 'parallel'}) + with pyro.plate("objects_plate", max_num_objects): + pyro.sample("exists", assignment.exists_dist, infer={"enumerate": "parallel"}) + with pyro.plate("detections_plate", num_detections): + pyro.sample("assign", assignment.assign_dist, infer={"enumerate": "parallel"}) def generate_data(args): num_objects = args.expected_num_objects true_objects = torch.randn(num_objects) - true_assign = dist.Categorical(torch.ones(args.num_real_detections, num_objects)).sample() + true_assign = dist.Categorical( + torch.ones(args.num_real_detections, num_objects) + ).sample() real_detections = true_objects[true_assign] - real_detections = real_detections + args.init_noise_scale * torch.randn(real_detections.shape) + real_detections = real_detections + args.init_noise_scale * torch.randn( + real_detections.shape + ) fake_detections = torch.randn(args.num_fake_detections) detections = torch.cat([real_detections, fake_detections]) assert detections.shape == (args.num_real_detections + args.num_fake_detections,) return detections -@pytest.mark.parametrize('assignment_grad', [False, True]) +@pytest.mark.parametrize("assignment_grad", [False, True]) def test_em(assignment_grad): args = make_args() args.assignment_grad = assignment_grad detections = generate_data(args) pyro.clear_param_store() - pyro.param('noise_scale', torch.tensor(args.init_noise_scale), - constraint=constraints.positive) - pyro.param('objects_loc', 
torch.randn(args.max_num_objects, 1)) + pyro.param( + "noise_scale", + torch.tensor(args.init_noise_scale), + constraint=constraints.positive, + ) + pyro.param("objects_loc", torch.randn(args.max_num_objects, 1)) # Learn object_loc via EM algorithm. elbo = TraceEnum_ELBO(max_plate_nesting=2) - newton = Newton(trust_radii={'objects_loc': 1.0}) + newton = Newton(trust_radii={"objects_loc": 1.0}) for step in range(10): # Detach previous iterations. - objects_loc = pyro.param('objects_loc').detach_().requires_grad_() + objects_loc = pyro.param("objects_loc").detach_().requires_grad_() loss = elbo.differentiable_loss(model, guide, detections, args) # E-step - newton.step(loss, {'objects_loc': objects_loc}) # M-step - logger.debug('step {}, loss = {}'.format(step, loss.item())) + newton.step(loss, {"objects_loc": objects_loc}) # M-step + logger.debug("step {}, loss = {}".format(step, loss.item())) -@pytest.mark.parametrize('assignment_grad', [False, True]) +@pytest.mark.parametrize("assignment_grad", [False, True]) def test_em_nested_in_svi(assignment_grad): args = make_args() args.assignment_grad = assignment_grad detections = generate_data(args) pyro.clear_param_store() - pyro.param('noise_scale', torch.tensor(args.init_noise_scale), - constraint=constraints.positive) - pyro.param('objects_loc', torch.randn(args.max_num_objects, 1)) + pyro.param( + "noise_scale", + torch.tensor(args.init_noise_scale), + constraint=constraints.positive, + ) + pyro.param("objects_loc", torch.randn(args.max_num_objects, 1)) # Learn object_loc via EM and noise_scale via SVI. - optim = Adam({'lr': 0.1}) + optim = Adam({"lr": 0.1}) elbo = TraceEnum_ELBO(max_plate_nesting=2) - newton = Newton(trust_radii={'objects_loc': 1.0}) - svi = SVI(poutine.block(model, hide=['objects_loc']), - poutine.block(guide, hide=['objects_loc']), optim, elbo) + newton = Newton(trust_radii={"objects_loc": 1.0}) + svi = SVI( + poutine.block(model, hide=["objects_loc"]), + poutine.block(guide, hide=["objects_loc"]), + optim, + elbo, + ) for svi_step in range(50): for em_step in range(2): - objects_loc = pyro.param('objects_loc').detach_().requires_grad_() - assert pyro.param('objects_loc').grad_fn is None + objects_loc = pyro.param("objects_loc").detach_().requires_grad_() + assert pyro.param("objects_loc").grad_fn is None loss = elbo.differentiable_loss(model, guide, detections, args) # E-step - updated = newton.get_step(loss, {'objects_loc': objects_loc}) # M-step - assert updated['objects_loc'].grad_fn is not None - pyro.get_param_store()['objects_loc'] = updated['objects_loc'] - assert pyro.param('objects_loc').grad_fn is not None + updated = newton.get_step(loss, {"objects_loc": objects_loc}) # M-step + assert updated["objects_loc"].grad_fn is not None + pyro.get_param_store()["objects_loc"] = updated["objects_loc"] + assert pyro.param("objects_loc").grad_fn is not None loss = svi.step(detections, args) - logger.debug('step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format( - svi_step, loss, pyro.param('noise_scale').item())) + logger.debug( + "step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}".format( + svi_step, loss, pyro.param("noise_scale").item() + ) + ) def test_svi_multi(): @@ -180,21 +200,28 @@ def test_svi_multi(): detections = generate_data(args) pyro.clear_param_store() - pyro.param('noise_scale', torch.tensor(args.init_noise_scale), - constraint=constraints.positive) - pyro.param('objects_loc', torch.randn(args.max_num_objects, 1)) + pyro.param( + "noise_scale", + torch.tensor(args.init_noise_scale), + 
constraint=constraints.positive, + ) + pyro.param("objects_loc", torch.randn(args.max_num_objects, 1)) # Learn object_loc via Newton and noise_scale via Adam. elbo = TraceEnum_ELBO(max_plate_nesting=2) - adam = Adam({'lr': 0.1}) - newton = Newton(trust_radii={'objects_loc': 1.0}) - optim = MixedMultiOptimizer([(['noise_scale'], adam), - (['objects_loc'], newton)]) + adam = Adam({"lr": 0.1}) + newton = Newton(trust_radii={"objects_loc": 1.0}) + optim = MixedMultiOptimizer([(["noise_scale"], adam), (["objects_loc"], newton)]) for svi_step in range(50): with poutine.trace(param_only=True) as param_capture: loss = elbo.differentiable_loss(model, guide, detections, args) - params = {name: pyro.param(name).unconstrained() - for name in param_capture.trace.nodes.keys()} + params = { + name: pyro.param(name).unconstrained() + for name in param_capture.trace.nodes.keys() + } optim.step(loss, params) - logger.debug('step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format( - svi_step, loss.item(), pyro.param('noise_scale').item())) + logger.debug( + "step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}".format( + svi_step, loss.item(), pyro.param("noise_scale").item() + ) + ) diff --git a/tests/contrib/tracking/test_hashing.py b/tests/contrib/tracking/test_hashing.py index 39d22d2690..b714163636 100644 --- a/tests/contrib/tracking/test_hashing.py +++ b/tests/contrib/tracking/test_hashing.py @@ -12,21 +12,21 @@ logger = logging.getLogger(__name__) -@pytest.mark.parametrize('scale', [-1., 0., -1 * torch.ones(2, 2)]) +@pytest.mark.parametrize("scale", [-1.0, 0.0, -1 * torch.ones(2, 2)]) def test_lsh_init(scale): with pytest.raises(ValueError): LSH(scale) -@pytest.mark.parametrize('scale', [0.1, 1, 10, 100]) +@pytest.mark.parametrize("scale", [0.1, 1, 10, 100]) def test_lsh_add(scale): lsh = LSH(scale) a = torch.rand(10) - lsh.add('a', a) - assert lsh._hash_to_key[lsh._key_to_hash['a']] == {'a'} + lsh.add("a", a) + assert lsh._hash_to_key[lsh._key_to_hash["a"]] == {"a"} -@pytest.mark.parametrize('scale', [0.1, 1, 10, 100]) +@pytest.mark.parametrize("scale", [0.1, 1, 10, 100]) def test_lsh_hash_nearby(scale): k = 5 lsh = LSH(scale) @@ -44,51 +44,51 @@ def test_lsh_hash_nearby(scale): assert_equal(lsh._hash(e), (2,) * k) assert_equal(lsh._hash(f), (4,) * k) - lsh.add('a', a) - lsh.add('b', b) - lsh.add('c', c) - lsh.add('d', d) - lsh.add('e', e) - lsh.add('f', f) + lsh.add("a", a) + lsh.add("b", b) + lsh.add("c", c) + lsh.add("d", d) + lsh.add("e", e) + lsh.add("f", f) - assert lsh.nearby('a') == {'b'} - assert lsh.nearby('b') == {'a', 'c'} - assert lsh.nearby('c') == {'b', 'd'} - assert lsh.nearby('d') == {'c', 'e'} - assert lsh.nearby('e') == {'d'} - assert lsh.nearby('f') == set() + assert lsh.nearby("a") == {"b"} + assert lsh.nearby("b") == {"a", "c"} + assert lsh.nearby("c") == {"b", "d"} + assert lsh.nearby("d") == {"c", "e"} + assert lsh.nearby("e") == {"d"} + assert lsh.nearby("f") == set() def test_lsh_overwrite(): lsh = LSH(1) a = torch.zeros(2) b = torch.ones(2) - lsh.add('a', a) - lsh.add('b', b) - assert lsh.nearby('a') == {'b'} + lsh.add("a", a) + lsh.add("b", b) + assert lsh.nearby("a") == {"b"} b = torch.ones(2) * 4 - lsh.add('b', b) - assert lsh.nearby('a') == set() + lsh.add("b", b) + assert lsh.nearby("a") == set() def test_lsh_remove(): lsh = LSH(1) a = torch.zeros(2) b = torch.ones(2) - lsh.add('a', a) - lsh.add('b', b) - assert lsh.nearby('a') == {'b'} - lsh.remove('b') - assert lsh.nearby('a') == set() + lsh.add("a", a) + lsh.add("b", b) + assert lsh.nearby("a") == {"b"} + 
lsh.remove("b") + assert lsh.nearby("a") == set() -@pytest.mark.parametrize('scale', [-1., 0., -1 * torch.ones(2, 2)]) +@pytest.mark.parametrize("scale", [-1.0, 0.0, -1 * torch.ones(2, 2)]) def test_aps_init(scale): with pytest.raises(ValueError): ApproxSet(scale) -@pytest.mark.parametrize('scale', [0.1, 1, 10, 100]) +@pytest.mark.parametrize("scale", [0.1, 1, 10, 100]) def test_aps_hash(scale): k = 10 aps = ApproxSet(scale) @@ -107,7 +107,7 @@ def test_aps_hash(scale): assert_equal(aps._hash(f), (4,) * k) -@pytest.mark.parametrize('scale', [0.1, 1, 10, 100]) +@pytest.mark.parametrize("scale", [0.1, 1, 10, 100]) def test_aps_try_add(scale): k = 10 aps = ApproxSet(scale) @@ -123,13 +123,15 @@ def test_aps_try_add(scale): def test_merge_points_small(): - points = torch.tensor([ - [0., 0.], - [0., 1.], - [2., 0.], - [2., 0.5], - [2., 1.0], - ]) + points = torch.tensor( + [ + [0.0, 0.0], + [0.0, 1.0], + [2.0, 0.0], + [2.0, 0.5], + [2.0, 1.0], + ] + ) merged_points, groups = merge_points(points, radius=1.0) assert len(merged_points) == 3 @@ -140,12 +142,12 @@ def test_merge_points_small(): assert 0.325 <= merged_points[2, 1] <= 0.625 -@pytest.mark.parametrize('radius', [0.01, 0.1, 1., 10., 100.]) -@pytest.mark.parametrize('dim', [1, 2, 3]) +@pytest.mark.parametrize("radius", [0.01, 0.1, 1.0, 10.0, 100.0]) +@pytest.mark.parametrize("dim", [1, 2, 3]) def test_merge_points_large(dim, radius): points = 10 * torch.randn(200, dim) merged_points, groups = merge_points(points, radius) - logger.debug('merged {} -> {}'.format(len(points), len(merged_points))) + logger.debug("merged {} -> {}".format(len(points), len(merged_points))) assert merged_points.dim() == 2 assert merged_points.shape[-1] == dim diff --git a/tests/contrib/tracking/test_measurements.py b/tests/contrib/tracking/test_measurements.py index 373cad0e79..f946311906 100644 --- a/tests/contrib/tracking/test_measurements.py +++ b/tests/contrib/tracking/test_measurements.py @@ -12,15 +12,18 @@ def test_PositionMeasurement(): frame_num = 5 measurement = PositionMeasurement( mean=torch.rand(dimension), - cov=torch.eye(dimension), time=time, frame_num=frame_num) + cov=torch.eye(dimension), + time=time, + frame_num=frame_num, + ) assert measurement.dimension == dimension - x = torch.rand(2*dimension) + x = torch.rand(2 * dimension) assert measurement(x).shape == (dimension,) assert measurement.mean.shape == (dimension,) assert measurement.cov.shape == (dimension, dimension) assert measurement.time == time assert measurement.frame_num == frame_num assert measurement.geodesic_difference( - torch.rand(dimension), torch.rand(dimension)).shape \ - == (dimension,) - assert measurement.jacobian().shape == (dimension, 2*dimension) + torch.rand(dimension), torch.rand(dimension) + ).shape == (dimension,) + assert measurement.jacobian().shape == (dimension, 2 * dimension) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index fb7d1ef7c1..91d0c3d8da 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -52,615 +52,968 @@ def __init__(self, von_loc, von_conc, skewness): continuous_dists = [ - Fixture(pyro_dist=dist.Uniform, - scipy_dist=sp.uniform, - examples=[ - {'low': [2.], 'high': [2.5], - 'test_data': [2.2]}, - {'low': [2., 4.], 'high': [3., 5.], - 'test_data': [[[2.5, 4.5]], [[2.5, 4.5]], [[2.5, 4.5]]]}, - {'low': [[2.], [-3.], [0.]], - 'high': [[2.5], [0.], [1.]], - 'test_data': [[2.2], [-2], [0.7]]}, - ], - scipy_arg_fn=lambda low, high: ((), {"loc": np.array(low), - "scale": np.array(high) 
- np.array(low)})), - Fixture(pyro_dist=dist.Exponential, - scipy_dist=sp.expon, - examples=[ - {'rate': [2.4], - 'test_data': [5.5]}, - {'rate': [2.4, 5.5], - 'test_data': [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]]}, - {'rate': [[2.4, 5.5]], - 'test_data': [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]]}, - {'rate': [[2.4], [5.5]], - 'test_data': [[5.5], [3.2]]}, - ], - scipy_arg_fn=lambda rate: ((), {"scale": 1.0 / np.array(rate)})), - Fixture(pyro_dist=RejectionExponential, - scipy_dist=sp.expon, - examples=[ - {'rate': [2.4], 'factor': [0.5], - 'test_data': [5.5]}, - {'rate': [2.4, 5.5], 'factor': [0.5], - 'test_data': [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]]}, - {'rate': [[2.4, 5.5]], 'factor': [0.5], - 'test_data': [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]]}, - {'rate': [[2.4], [5.5]], 'factor': [0.5], - 'test_data': [[5.5], [3.2]]}, - ], - scipy_arg_fn=lambda rate, factor: ((), {"scale": 1.0 / np.array(rate)})), - Fixture(pyro_dist=dist.Gamma, - scipy_dist=sp.gamma, - examples=[ - {'concentration': [2.4], 'rate': [3.2], - 'test_data': [5.5]}, - {'concentration': [[2.4, 2.4], [3.2, 3.2]], 'rate': [[2.4, 2.4], [3.2, 3.2]], - 'test_data': [[[5.5, 4.4], [5.5, 4.4]]]}, - {'concentration': [[2.4], [2.4]], 'rate': [[3.2], [3.2]], 'test_data': [[5.5], [4.4]]} - ], - scipy_arg_fn=lambda concentration, rate: ((np.array(concentration),), - {"scale": 1.0 / np.array(rate)})), - Fixture(pyro_dist=ShapeAugmentedGamma, - scipy_dist=sp.gamma, - examples=[ - {'concentration': [2.4], 'rate': [3.2], - 'test_data': [5.5]}, - {'concentration': [[2.4, 2.4], [3.2, 3.2]], 'rate': [[2.4, 2.4], [3.2, 3.2]], - 'test_data': [[[5.5, 4.4], [5.5, 4.4]]]}, - {'concentration': [[2.4], [2.4]], 'rate': [[3.2], [3.2]], 'test_data': [[5.5], [4.4]]} - ], - scipy_arg_fn=lambda concentration, rate: ((np.array(concentration),), - {"scale": 1.0 / np.array(rate)})), - Fixture(pyro_dist=dist.Beta, - scipy_dist=sp.beta, - examples=[ - {'concentration1': [2.4], 'concentration0': [3.6], - 'test_data': [0.4]}, - {'concentration1': [[2.4, 2.4], [3.6, 3.6]], 'concentration0': [[2.5, 2.5], [2.5, 2.5]], - 'test_data': [[[0.5, 0.4], [0.5, 0.4]]]}, - {'concentration1': [[2.4], [3.7]], 'concentration0': [[3.6], [2.5]], - 'test_data': [[0.4], [0.6]]} - ], - scipy_arg_fn=lambda concentration1, concentration0: - ((np.array(concentration1), np.array(concentration0)), {})), - - Fixture(pyro_dist=NaiveBeta, - scipy_dist=sp.beta, - examples=[ - {'concentration1': [2.4], 'concentration0': [3.6], - 'test_data': [0.4]}, - {'concentration1': [[2.4, 2.4], [3.6, 3.6]], 'concentration0': [[2.5, 2.5], [2.5, 2.5]], - 'test_data': [[[0.5, 0.4], [0.5, 0.4]]]}, - {'concentration1': [[2.4], [3.7]], 'concentration0': [[3.6], [2.5]], - 'test_data': [[0.4], [0.6]]} - ], - scipy_arg_fn=lambda concentration1, concentration0: - ((np.array(concentration1), np.array(concentration0)), {})), - Fixture(pyro_dist=ShapeAugmentedBeta, - scipy_dist=sp.beta, - examples=[ - {'concentration1': [2.4], 'concentration0': [3.6], - 'test_data': [0.4]}, - {'concentration1': [[2.4, 2.4], [3.6, 3.6]], 'concentration0': [[2.5, 2.5], [2.5, 2.5]], - 'test_data': [[[0.5, 0.4], [0.5, 0.4]]]}, - {'concentration1': [[2.4], [3.7]], 'concentration0': [[3.6], [2.5]], - 'test_data': [[0.4], [0.6]]} - ], - scipy_arg_fn=lambda concentration1, concentration0: - ((np.array(concentration1), np.array(concentration0)), {})), - Fixture(pyro_dist=dist.LogNormal, - scipy_dist=sp.lognorm, - examples=[ - {'loc': [1.4], 'scale': [0.4], - 'test_data': [5.5]}, - {'loc': [1.4], 'scale': [0.4], - 'test_data': 
[[5.5]]}, - {'loc': [[1.4, 0.4], [1.4, 0.4]], 'scale': [[2.6, 0.5], [2.6, 0.5]], - 'test_data': [[5.5, 6.4], [5.5, 6.4]]}, - {'loc': [[1.4], [0.4]], 'scale': [[2.6], [0.5]], - 'test_data': [[5.5], [6.4]]} - ], - scipy_arg_fn=lambda loc, scale: ((np.array(scale),), {"scale": np.exp(np.array(loc))})), - Fixture(pyro_dist=dist.AffineBeta, - scipy_dist=sp.beta, - examples=[ - {'concentration1': [2.4], 'concentration0': [3.6], 'loc': [-1.0], 'scale': [2.0], - 'test_data': [-0.4]}, - {'concentration1': [[2.4, 2.4], [3.6, 3.6]], 'concentration0': [[2.5, 2.5], [2.5, 2.5]], - 'loc': [[-1.0, -1.0], [2.0, 2.0]], 'scale': [[2.0, 2.0], [1.0, 1.0]], - 'test_data': [[[-0.4, 0.4], [2.5, 2.6]]]}, - {'concentration1': [[2.4], [3.7]], 'concentration0': [[3.6], [2.5]], - 'loc': [[-1.0], [2.0]], 'scale': [[2.0], [2.0]], - 'test_data': [[0.0], [3.0]]} - ], - scipy_arg_fn=lambda concentration1, concentration0, loc, scale: - ((np.array(concentration1), np.array(concentration0), np.array(loc), np.array(scale)), {})), - Fixture(pyro_dist=dist.Normal, - scipy_dist=sp.norm, - examples=[ - {'loc': [2.0], 'scale': [4.0], - 'test_data': [2.0]}, - {'loc': [[2.0]], 'scale': [[4.0]], - 'test_data': [[2.0]]}, - {'loc': [[[2.0]]], 'scale': [[[4.0]]], - 'test_data': [[[2.0]]]}, - {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], - 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, - ], - scipy_arg_fn=lambda loc, scale: ((), {"loc": np.array(loc), "scale": np.array(scale)}), - prec=0.07, - min_samples=50000), - Fixture(pyro_dist=dist.MultivariateNormal, - scipy_dist=sp.multivariate_normal, - examples=[ - {'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]], - 'test_data': [[2.0, 1.0], [9.0, 3.4]]}, - ], - # This hack seems to be the best option right now, as 'scale' is not handled well by get_scipy_batch_logpdf - scipy_arg_fn=lambda loc, covariance_matrix=None: - ((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}), - prec=0.01, - min_samples=500000), - Fixture(pyro_dist=dist.LowRankMultivariateNormal, - scipy_dist=sp.multivariate_normal, - examples=[ - {'loc': [2.0, 1.0], 'cov_diag': [0.5, 0.5], 'cov_factor': [[1.0], [0.5]], - 'test_data': [[2.0, 1.0], [9.0, 3.4]]}, - ], - scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None: - ((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}), - prec=0.01, - min_samples=500000), - Fixture(pyro_dist=FoldedNormal, - examples=[ - {'loc': [2.0], 'scale': [4.0], - 'test_data': [2.0]}, - {'loc': [[2.0]], 'scale': [[4.0]], - 'test_data': [[2.0]]}, - {'loc': [[[2.0]]], 'scale': [[[4.0]]], - 'test_data': [[[2.0]]]}, - {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], - 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, - ]), - Fixture(pyro_dist=dist.Dirichlet, - scipy_dist=sp.dirichlet, - examples=[ - {'concentration': [2.4, 3, 6], - 'test_data': [0.2, 0.45, 0.35]}, - {'concentration': [2.4, 3, 6], - 'test_data': [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]]}, - {'concentration': [[2.4, 3, 6], [3.2, 1.2, 0.4]], - 'test_data': [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]]} - ], - scipy_arg_fn=lambda concentration: ((concentration,), {})), - Fixture(pyro_dist=NaiveDirichlet, - scipy_dist=sp.dirichlet, - examples=[ - {'concentration': [2.4, 3, 6], - 'test_data': [0.2, 0.45, 0.35]}, - {'concentration': [2.4, 3, 6], - 'test_data': [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]]}, - {'concentration': [[2.4, 3, 6], [3.2, 1.2, 0.4]], - 'test_data': [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]]} - ], - scipy_arg_fn=lambda concentration: ((concentration,), {})), - Fixture(pyro_dist=ShapeAugmentedDirichlet, - 
scipy_dist=sp.dirichlet, - examples=[ - {'concentration': [2.4, 3, 6], - 'test_data': [0.2, 0.45, 0.35]}, - {'concentration': [2.4, 3, 6], - 'test_data': [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]]}, - {'concentration': [[2.4, 3, 6], [3.2, 1.2, 0.4]], - 'test_data': [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]]} - ], - scipy_arg_fn=lambda concentration: ((concentration,), {})), - Fixture(pyro_dist=dist.Cauchy, - scipy_dist=sp.cauchy, - examples=[ - {'loc': [0.5], 'scale': [1.2], - 'test_data': [1.0]}, - {'loc': [0.5, 0.5], 'scale': [1.2, 1.2], - 'test_data': [[1.0, 1.0], [1.0, 1.0]]}, - {'loc': [[0.5], [0.3]], 'scale': [[1.2], [1.0]], - 'test_data': [[0.4], [0.35]]} - ], - scipy_arg_fn=lambda loc, scale: ((), {"loc": np.array(loc), "scale": np.array(scale)})), - Fixture(pyro_dist=dist.HalfCauchy, - scipy_dist=sp.halfcauchy, - examples=[ - {'scale': [1.2], - 'test_data': [1.0]}, - {'scale': [1.2, 1.2], - 'test_data': [[1.0, 2.0], [1.0, 2.0]]}, - {'scale': [[1.2], [1.0]], - 'test_data': [[0.54], [0.35]]} - ], - scipy_arg_fn=lambda scale: ((), {"scale": np.array(scale)})), - Fixture(pyro_dist=dist.VonMises, - scipy_dist=sp.vonmises, - examples=[ - {'loc': [0.5], 'concentration': [1.2], - 'test_data': [1.0]}, - {'loc': [0.5, 3.0], 'concentration': [2.0, 0.5], - 'test_data': [[1.0, 2.0], [1.0, 2.0]]}, - {'loc': [[0.5], [0.3]], 'concentration': [[2.0], [0.5]], - 'test_data': [[1.0], [2.0]]} - ], - scipy_arg_fn=lambda loc, concentration: ((), {"loc": np.array(loc), "kappa": np.array(concentration)})), - Fixture(pyro_dist=dist.LKJ, - examples=[ - {'dim': 3, 'concentration': 1., 'test_data': - [[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]], - [[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]], - [[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]], - [[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]], - [[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]}, - ]), - Fixture(pyro_dist=dist.LKJCholesky, - examples=[ - { - 'dim': 3, - 'concentration': 1., - 'test_data': [ - [[1.0, 0.0, 0.0], - [-0.17332135, 0.98486533, 0.0], - [0.43106407, -0.54767312, 0.71710384]], - [[1.0, 0.0, 0.0], - [-0.31391555, 0.94945091, 0.0], - [-0.31391296, -0.29767500, 0.90158097]], + Fixture( + pyro_dist=dist.Uniform, + scipy_dist=sp.uniform, + examples=[ + {"low": [2.0], "high": [2.5], "test_data": [2.2]}, + { + "low": [2.0, 4.0], + "high": [3.0, 5.0], + "test_data": [[[2.5, 4.5]], [[2.5, 4.5]], [[2.5, 4.5]]], + }, + { + "low": [[2.0], [-3.0], [0.0]], + "high": [[2.5], [0.0], [1.0]], + "test_data": [[2.2], [-2], [0.7]], + }, + ], + scipy_arg_fn=lambda low, high: ( + (), + {"loc": np.array(low), "scale": np.array(high) - np.array(low)}, + ), + ), + Fixture( + pyro_dist=dist.Exponential, + scipy_dist=sp.expon, + examples=[ + {"rate": [2.4], "test_data": [5.5]}, + { + "rate": [2.4, 5.5], + "test_data": [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]], + }, + { + "rate": [[2.4, 5.5]], + "test_data": [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]], + }, + {"rate": [[2.4], [5.5]], "test_data": [[5.5], [3.2]]}, + ], + scipy_arg_fn=lambda rate: ((), {"scale": 1.0 / np.array(rate)}), + ), + Fixture( + pyro_dist=RejectionExponential, + scipy_dist=sp.expon, + examples=[ + {"rate": [2.4], "factor": [0.5], "test_data": [5.5]}, + { + "rate": [2.4, 5.5], + "factor": [0.5], + "test_data": [[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]], + }, + { + "rate": [[2.4, 5.5]], + "factor": [0.5], + "test_data": 
[[[5.5, 3.2]], [[5.5, 3.2]], [[5.5, 3.2]]], + }, + {"rate": [[2.4], [5.5]], "factor": [0.5], "test_data": [[5.5], [3.2]]}, + ], + scipy_arg_fn=lambda rate, factor: ((), {"scale": 1.0 / np.array(rate)}), + ), + Fixture( + pyro_dist=dist.Gamma, + scipy_dist=sp.gamma, + examples=[ + {"concentration": [2.4], "rate": [3.2], "test_data": [5.5]}, + { + "concentration": [[2.4, 2.4], [3.2, 3.2]], + "rate": [[2.4, 2.4], [3.2, 3.2]], + "test_data": [[[5.5, 4.4], [5.5, 4.4]]], + }, + { + "concentration": [[2.4], [2.4]], + "rate": [[3.2], [3.2]], + "test_data": [[5.5], [4.4]], + }, + ], + scipy_arg_fn=lambda concentration, rate: ( + (np.array(concentration),), + {"scale": 1.0 / np.array(rate)}, + ), + ), + Fixture( + pyro_dist=ShapeAugmentedGamma, + scipy_dist=sp.gamma, + examples=[ + {"concentration": [2.4], "rate": [3.2], "test_data": [5.5]}, + { + "concentration": [[2.4, 2.4], [3.2, 3.2]], + "rate": [[2.4, 2.4], [3.2, 3.2]], + "test_data": [[[5.5, 4.4], [5.5, 4.4]]], + }, + { + "concentration": [[2.4], [2.4]], + "rate": [[3.2], [3.2]], + "test_data": [[5.5], [4.4]], + }, + ], + scipy_arg_fn=lambda concentration, rate: ( + (np.array(concentration),), + {"scale": 1.0 / np.array(rate)}, + ), + ), + Fixture( + pyro_dist=dist.Beta, + scipy_dist=sp.beta, + examples=[ + {"concentration1": [2.4], "concentration0": [3.6], "test_data": [0.4]}, + { + "concentration1": [[2.4, 2.4], [3.6, 3.6]], + "concentration0": [[2.5, 2.5], [2.5, 2.5]], + "test_data": [[[0.5, 0.4], [0.5, 0.4]]], + }, + { + "concentration1": [[2.4], [3.7]], + "concentration0": [[3.6], [2.5]], + "test_data": [[0.4], [0.6]], + }, + ], + scipy_arg_fn=lambda concentration1, concentration0: ( + (np.array(concentration1), np.array(concentration0)), + {}, + ), + ), + Fixture( + pyro_dist=NaiveBeta, + scipy_dist=sp.beta, + examples=[ + {"concentration1": [2.4], "concentration0": [3.6], "test_data": [0.4]}, + { + "concentration1": [[2.4, 2.4], [3.6, 3.6]], + "concentration0": [[2.5, 2.5], [2.5, 2.5]], + "test_data": [[[0.5, 0.4], [0.5, 0.4]]], + }, + { + "concentration1": [[2.4], [3.7]], + "concentration0": [[3.6], [2.5]], + "test_data": [[0.4], [0.6]], + }, + ], + scipy_arg_fn=lambda concentration1, concentration0: ( + (np.array(concentration1), np.array(concentration0)), + {}, + ), + ), + Fixture( + pyro_dist=ShapeAugmentedBeta, + scipy_dist=sp.beta, + examples=[ + {"concentration1": [2.4], "concentration0": [3.6], "test_data": [0.4]}, + { + "concentration1": [[2.4, 2.4], [3.6, 3.6]], + "concentration0": [[2.5, 2.5], [2.5, 2.5]], + "test_data": [[[0.5, 0.4], [0.5, 0.4]]], + }, + { + "concentration1": [[2.4], [3.7]], + "concentration0": [[3.6], [2.5]], + "test_data": [[0.4], [0.6]], + }, + ], + scipy_arg_fn=lambda concentration1, concentration0: ( + (np.array(concentration1), np.array(concentration0)), + {}, + ), + ), + Fixture( + pyro_dist=dist.LogNormal, + scipy_dist=sp.lognorm, + examples=[ + {"loc": [1.4], "scale": [0.4], "test_data": [5.5]}, + {"loc": [1.4], "scale": [0.4], "test_data": [[5.5]]}, + { + "loc": [[1.4, 0.4], [1.4, 0.4]], + "scale": [[2.6, 0.5], [2.6, 0.5]], + "test_data": [[5.5, 6.4], [5.5, 6.4]], + }, + { + "loc": [[1.4], [0.4]], + "scale": [[2.6], [0.5]], + "test_data": [[5.5], [6.4]], + }, + ], + scipy_arg_fn=lambda loc, scale: ( + (np.array(scale),), + {"scale": np.exp(np.array(loc))}, + ), + ), + Fixture( + pyro_dist=dist.AffineBeta, + scipy_dist=sp.beta, + examples=[ + { + "concentration1": [2.4], + "concentration0": [3.6], + "loc": [-1.0], + "scale": [2.0], + "test_data": [-0.4], + }, + { + "concentration1": [[2.4, 2.4], 
[3.6, 3.6]], + "concentration0": [[2.5, 2.5], [2.5, 2.5]], + "loc": [[-1.0, -1.0], [2.0, 2.0]], + "scale": [[2.0, 2.0], [1.0, 1.0]], + "test_data": [[[-0.4, 0.4], [2.5, 2.6]]], + }, + { + "concentration1": [[2.4], [3.7]], + "concentration0": [[3.6], [2.5]], + "loc": [[-1.0], [2.0]], + "scale": [[2.0], [2.0]], + "test_data": [[0.0], [3.0]], + }, + ], + scipy_arg_fn=lambda concentration1, concentration0, loc, scale: ( + ( + np.array(concentration1), + np.array(concentration0), + np.array(loc), + np.array(scale), + ), + {}, + ), + ), + Fixture( + pyro_dist=dist.Normal, + scipy_dist=sp.norm, + examples=[ + {"loc": [2.0], "scale": [4.0], "test_data": [2.0]}, + {"loc": [[2.0]], "scale": [[4.0]], "test_data": [[2.0]]}, + {"loc": [[[2.0]]], "scale": [[[4.0]]], "test_data": [[[2.0]]]}, + { + "loc": [2.0, 50.0], + "scale": [4.0, 100.0], + "test_data": [[2.0, 50.0], [2.0, 50.0]], + }, + ], + scipy_arg_fn=lambda loc, scale: ( + (), + {"loc": np.array(loc), "scale": np.array(scale)}, + ), + prec=0.07, + min_samples=50000, + ), + Fixture( + pyro_dist=dist.MultivariateNormal, + scipy_dist=sp.multivariate_normal, + examples=[ + { + "loc": [2.0, 1.0], + "covariance_matrix": [[1.0, 0.5], [0.5, 1.0]], + "test_data": [[2.0, 1.0], [9.0, 3.4]], + }, + ], + # This hack seems to be the best option right now, as 'scale' is not handled well by get_scipy_batch_logpdf + scipy_arg_fn=lambda loc, covariance_matrix=None: ( + (), + {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}, + ), + prec=0.01, + min_samples=500000, + ), + Fixture( + pyro_dist=dist.LowRankMultivariateNormal, + scipy_dist=sp.multivariate_normal, + examples=[ + { + "loc": [2.0, 1.0], + "cov_diag": [0.5, 0.5], + "cov_factor": [[1.0], [0.5]], + "test_data": [[2.0, 1.0], [9.0, 3.4]], + }, + ], + scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None: ( + (), + {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}, + ), + prec=0.01, + min_samples=500000, + ), + Fixture( + pyro_dist=FoldedNormal, + examples=[ + {"loc": [2.0], "scale": [4.0], "test_data": [2.0]}, + {"loc": [[2.0]], "scale": [[4.0]], "test_data": [[2.0]]}, + {"loc": [[[2.0]]], "scale": [[[4.0]]], "test_data": [[[2.0]]]}, + { + "loc": [2.0, 50.0], + "scale": [4.0, 100.0], + "test_data": [[2.0, 50.0], [2.0, 50.0]], + }, + ], + ), + Fixture( + pyro_dist=dist.Dirichlet, + scipy_dist=sp.dirichlet, + examples=[ + {"concentration": [2.4, 3, 6], "test_data": [0.2, 0.45, 0.35]}, + { + "concentration": [2.4, 3, 6], + "test_data": [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]], + }, + { + "concentration": [[2.4, 3, 6], [3.2, 1.2, 0.4]], + "test_data": [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]], + }, + ], + scipy_arg_fn=lambda concentration: ((concentration,), {}), + ), + Fixture( + pyro_dist=NaiveDirichlet, + scipy_dist=sp.dirichlet, + examples=[ + {"concentration": [2.4, 3, 6], "test_data": [0.2, 0.45, 0.35]}, + { + "concentration": [2.4, 3, 6], + "test_data": [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]], + }, + { + "concentration": [[2.4, 3, 6], [3.2, 1.2, 0.4]], + "test_data": [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]], + }, + ], + scipy_arg_fn=lambda concentration: ((concentration,), {}), + ), + Fixture( + pyro_dist=ShapeAugmentedDirichlet, + scipy_dist=sp.dirichlet, + examples=[ + {"concentration": [2.4, 3, 6], "test_data": [0.2, 0.45, 0.35]}, + { + "concentration": [2.4, 3, 6], + "test_data": [[0.2, 0.45, 0.35], [0.2, 0.45, 0.35]], + }, + { + "concentration": [[2.4, 3, 6], [3.2, 1.2, 0.4]], + "test_data": [[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]], + }, + ], + scipy_arg_fn=lambda concentration: 
((concentration,), {}), + ), + Fixture( + pyro_dist=dist.Cauchy, + scipy_dist=sp.cauchy, + examples=[ + {"loc": [0.5], "scale": [1.2], "test_data": [1.0]}, + { + "loc": [0.5, 0.5], + "scale": [1.2, 1.2], + "test_data": [[1.0, 1.0], [1.0, 1.0]], + }, + { + "loc": [[0.5], [0.3]], + "scale": [[1.2], [1.0]], + "test_data": [[0.4], [0.35]], + }, + ], + scipy_arg_fn=lambda loc, scale: ( + (), + {"loc": np.array(loc), "scale": np.array(scale)}, + ), + ), + Fixture( + pyro_dist=dist.HalfCauchy, + scipy_dist=sp.halfcauchy, + examples=[ + {"scale": [1.2], "test_data": [1.0]}, + {"scale": [1.2, 1.2], "test_data": [[1.0, 2.0], [1.0, 2.0]]}, + {"scale": [[1.2], [1.0]], "test_data": [[0.54], [0.35]]}, + ], + scipy_arg_fn=lambda scale: ((), {"scale": np.array(scale)}), + ), + Fixture( + pyro_dist=dist.VonMises, + scipy_dist=sp.vonmises, + examples=[ + {"loc": [0.5], "concentration": [1.2], "test_data": [1.0]}, + { + "loc": [0.5, 3.0], + "concentration": [2.0, 0.5], + "test_data": [[1.0, 2.0], [1.0, 2.0]], + }, + { + "loc": [[0.5], [0.3]], + "concentration": [[2.0], [0.5]], + "test_data": [[1.0], [2.0]], + }, + ], + scipy_arg_fn=lambda loc, concentration: ( + (), + {"loc": np.array(loc), "kappa": np.array(concentration)}, + ), + ), + Fixture( + pyro_dist=dist.LKJ, + examples=[ + { + "dim": 3, + "concentration": 1.0, + "test_data": [ + [ + [1.0000, -0.8221, 0.7655], + [-0.8221, 1.0000, -0.5293], + [0.7655, -0.5293, 1.0000], + ], + [ + [1.0000, -0.5345, -0.5459], + [-0.5345, 1.0000, -0.0333], + [-0.5459, -0.0333, 1.0000], + ], + [ + [1.0000, -0.3758, -0.2409], + [-0.3758, 1.0000, 0.4653], + [-0.2409, 0.4653, 1.0000], + ], + [ + [1.0000, -0.8800, -0.9493], + [-0.8800, 1.0000, 0.9088], + [-0.9493, 0.9088, 1.0000], + ], + [ + [1.0000, 0.2284, -0.1283], + [0.2284, 1.0000, 0.0146], + [-0.1283, 0.0146, 1.0000], ], - }, - ]), - Fixture(pyro_dist=dist.Stable, - examples=[ - {'stability': [1.5], 'skew': 0.1, 'test_data': [-10.]}, - {'stability': [1.5], 'skew': 0.1, 'scale': 2.0, 'loc': -2.0, 'test_data': [10.]}, - ]), - Fixture(pyro_dist=dist.MultivariateStudentT, - examples=[ - {'df': 1.5, 'loc': [0.2, 0.3], 'scale_tril': [[0.8, 0.0], [1.3, 0.4]], - 'test_data': [-3., 2]}, - ]), - Fixture(pyro_dist=dist.ProjectedNormal, - examples=[ - {'concentration': [0., 0.], 'test_data': [1., 0.]}, - {'concentration': [2., 3.], 'test_data': [0., 1.]}, - {'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]}, - {'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]}, - ]), - Fixture(pyro_dist=dist.SineBivariateVonMises, - examples=[ - {'phi_loc': [0.], 'psi_loc': [0.], 'phi_concentration': [5.], 'psi_concentration': [6.], - 'correlation': [2.], 'test_data': [[0., 0.]]}, - {'phi_loc': [3.003], 'psi_loc': [-1.343], 'phi_concentration': [5.], 'psi_concentration': [6.], - 'correlation': [2.], 'test_data': [[0., 1.]]}, - {'phi_loc': [-math.pi / 3], 'psi_loc': -1., 'phi_concentration': .5, 'psi_concentration': 10., - 'correlation': .9, 'test_data': [[1., 0.555]]}, - {'phi_loc': [math.pi - .2, 1.], 'psi_loc': [0., 1.], - 'phi_concentration': [5., 5.], 'psi_concentration': [7., .5], - 'weighted_correlation': [.5, .1], 'test_data': [[[1., -3.], [1., 59.]]]}, - ]), - Fixture(pyro_dist=dist.SoftLaplace, - examples=[ - {'loc': [2.0], 'scale': [4.0], - 'test_data': [2.0]}, - {'loc': [[2.0]], 'scale': [[4.0]], - 'test_data': [[2.0]]}, - {'loc': [[[2.0]]], 'scale': [[[4.0]]], - 'test_data': [[[2.0]]]}, - {'loc': [2.0, 50.0], 'scale': [4.0, 100.0], - 'test_data': [[2.0, 50.0], [2.0, 50.0]]}, - ]), - Fixture(pyro_dist=SineSkewedUniform, 
- examples=[ - {'lower': [-pi, -pi], - 'upper':[pi, pi], - 'skewness': [-pi / 4, .1], - 'test_data': [pi / 2, -2 * pi / 3]} - ]), - Fixture(pyro_dist=SineSkewedVonMises, - examples=[ - {'von_loc': [0.], - 'von_conc': [1.], - 'skewness': [.342355], - 'test_data': [.1]}, - {'von_loc': [0., -1.234], - 'von_conc': [1., 10.], - 'skewness': [[.342355, -.0001], [.91, 0.09]], - 'test_data': [[.1, -3.2], [-2., 0.]]} - ]), - Fixture(pyro_dist=dist.AsymmetricLaplace, - examples=[ - {'loc': [1.0], 'scale': [1.0], 'asymmetry': [2.0], - 'test_data': [2.0]}, - {'loc': [2.0, -50.0], 'scale': [2.0, 10.0], - 'asymmetry': [0.5, 2.5], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]}, - ]), - Fixture(pyro_dist=dist.SoftAsymmetricLaplace, - examples=[ - {'loc': [1.0], 'scale': [1.0], 'asymmetry': [2.0], - 'test_data': [2.0]}, - {'loc': [2.0, -50.0], 'scale': [2.0, 10.0], 'asymmetry': [0.5, 2.5], - 'softness': [0.7, 1.4], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]}, - ]), - Fixture(pyro_dist=dist.SkewLogistic, - examples=[ - {'loc': [1.0], 'scale': [1.0], 'asymmetry': [2.0], - 'test_data': [2.0]}, - {'loc': [2.0, -50.0], 'scale': [2.0, 10.0], 'asymmetry': [0.5, 2.5], - 'test_data': [[2.0, 10.0], [-1.0, -50.0]]}, - ]), + ], + }, + ], + ), + Fixture( + pyro_dist=dist.LKJCholesky, + examples=[ + { + "dim": 3, + "concentration": 1.0, + "test_data": [ + [ + [1.0, 0.0, 0.0], + [-0.17332135, 0.98486533, 0.0], + [0.43106407, -0.54767312, 0.71710384], + ], + [ + [1.0, 0.0, 0.0], + [-0.31391555, 0.94945091, 0.0], + [-0.31391296, -0.29767500, 0.90158097], + ], + ], + }, + ], + ), + Fixture( + pyro_dist=dist.Stable, + examples=[ + {"stability": [1.5], "skew": 0.1, "test_data": [-10.0]}, + { + "stability": [1.5], + "skew": 0.1, + "scale": 2.0, + "loc": -2.0, + "test_data": [10.0], + }, + ], + ), + Fixture( + pyro_dist=dist.MultivariateStudentT, + examples=[ + { + "df": 1.5, + "loc": [0.2, 0.3], + "scale_tril": [[0.8, 0.0], [1.3, 0.4]], + "test_data": [-3.0, 2], + }, + ], + ), + Fixture( + pyro_dist=dist.ProjectedNormal, + examples=[ + {"concentration": [0.0, 0.0], "test_data": [1.0, 0.0]}, + {"concentration": [2.0, 3.0], "test_data": [0.0, 1.0]}, + {"concentration": [0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]}, + {"concentration": [-1.0, 2.0, 3.0], "test_data": [0.0, 0.0, 1.0]}, + ], + ), + Fixture( + pyro_dist=dist.SineBivariateVonMises, + examples=[ + { + "phi_loc": [0.0], + "psi_loc": [0.0], + "phi_concentration": [5.0], + "psi_concentration": [6.0], + "correlation": [2.0], + "test_data": [[0.0, 0.0]], + }, + { + "phi_loc": [3.003], + "psi_loc": [-1.343], + "phi_concentration": [5.0], + "psi_concentration": [6.0], + "correlation": [2.0], + "test_data": [[0.0, 1.0]], + }, + { + "phi_loc": [-math.pi / 3], + "psi_loc": -1.0, + "phi_concentration": 0.5, + "psi_concentration": 10.0, + "correlation": 0.9, + "test_data": [[1.0, 0.555]], + }, + { + "phi_loc": [math.pi - 0.2, 1.0], + "psi_loc": [0.0, 1.0], + "phi_concentration": [5.0, 5.0], + "psi_concentration": [7.0, 0.5], + "weighted_correlation": [0.5, 0.1], + "test_data": [[[1.0, -3.0], [1.0, 59.0]]], + }, + ], + ), + Fixture( + pyro_dist=dist.SoftLaplace, + examples=[ + {"loc": [2.0], "scale": [4.0], "test_data": [2.0]}, + {"loc": [[2.0]], "scale": [[4.0]], "test_data": [[2.0]]}, + {"loc": [[[2.0]]], "scale": [[[4.0]]], "test_data": [[[2.0]]]}, + { + "loc": [2.0, 50.0], + "scale": [4.0, 100.0], + "test_data": [[2.0, 50.0], [2.0, 50.0]], + }, + ], + ), + Fixture( + pyro_dist=SineSkewedUniform, + examples=[ + { + "lower": [-pi, -pi], + "upper": [pi, pi], + "skewness": [-pi / 4, 
0.1], + "test_data": [pi / 2, -2 * pi / 3], + } + ], + ), + Fixture( + pyro_dist=SineSkewedVonMises, + examples=[ + { + "von_loc": [0.0], + "von_conc": [1.0], + "skewness": [0.342355], + "test_data": [0.1], + }, + { + "von_loc": [0.0, -1.234], + "von_conc": [1.0, 10.0], + "skewness": [[0.342355, -0.0001], [0.91, 0.09]], + "test_data": [[0.1, -3.2], [-2.0, 0.0]], + }, + ], + ), + Fixture( + pyro_dist=dist.AsymmetricLaplace, + examples=[ + {"loc": [1.0], "scale": [1.0], "asymmetry": [2.0], "test_data": [2.0]}, + { + "loc": [2.0, -50.0], + "scale": [2.0, 10.0], + "asymmetry": [0.5, 2.5], + "test_data": [[2.0, 10.0], [-1.0, -50.0]], + }, + ], + ), + Fixture( + pyro_dist=dist.SoftAsymmetricLaplace, + examples=[ + {"loc": [1.0], "scale": [1.0], "asymmetry": [2.0], "test_data": [2.0]}, + { + "loc": [2.0, -50.0], + "scale": [2.0, 10.0], + "asymmetry": [0.5, 2.5], + "softness": [0.7, 1.4], + "test_data": [[2.0, 10.0], [-1.0, -50.0]], + }, + ], + ), + Fixture( + pyro_dist=dist.SkewLogistic, + examples=[ + {"loc": [1.0], "scale": [1.0], "asymmetry": [2.0], "test_data": [2.0]}, + { + "loc": [2.0, -50.0], + "scale": [2.0, 10.0], + "asymmetry": [0.5, 2.5], + "test_data": [[2.0, 10.0], [-1.0, -50.0]], + }, + ], + ), ] discrete_dists = [ - Fixture(pyro_dist=dist.OrderedLogistic, - examples=[ - {'cutpoints': [0., 1., 2.], - 'predictor': [1.], - 'test_data': [1]}, - {'cutpoints': [0., 1., 2.], - 'predictor': [-0.5, 0.5, 1.5, 2.5], - 'test_data': [0, 1, 2, 3]}, - {'cutpoints': [0., 1.], - 'predictor': [[-0.5, 0.5, 1.5], [-0.5, 0.5, 1.5]], - 'test_data': [[0, 1, 2], [0, 1, 2]]}, - ], - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.Multinomial, - scipy_dist=sp.multinomial, - examples=[ - {'probs': [0.1, 0.6, 0.3], - 'test_data': [0., 1., 0.]}, - {'probs': [0.1, 0.6, 0.3], 'total_count': 8, - 'test_data': [2., 4., 2.]}, - {'probs': [0.1, 0.6, 0.3], 'total_count': 8, - 'test_data': [[2., 4., 2.], [2., 4., 2.]]}, - {'probs': [[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]], 'total_count': 8, - 'test_data': [[2., 4., 2.], [1., 4., 3.]]} - ], - scipy_arg_fn=lambda probs, total_count=[1]: ((total_count[0], np.array(probs)), {}), - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.Bernoulli, - scipy_dist=sp.bernoulli, - examples=[ - {'probs': [0.25], - 'test_data': [1.]}, - {'probs': [0.25, 0.25], - 'test_data': [[[0., 1.]], [[1., 0.]], [[0., 0.]]]}, - {'logits': [math.log(p / (1 - p)) for p in (0.25, 0.25)], - 'test_data': [[[0., 1.]], [[1., 0.]], [[0., 0.]]]}, - # for now, avoid tests on infinite logits - # {'logits': [-float('inf'), 0], - # 'test_data': [[0, 1], [0, 1], [0, 1]]}, - {'logits': [[math.log(p / (1 - p)) for p in (0.25, 0.25)], - [math.log(p / (1 - p)) for p in (0.3, 0.3)]], - 'test_data': [[1., 1.], [0., 0.]]}, - {'probs': [[0.25, 0.25], [0.3, 0.3]], - 'test_data': [[1., 1.], [0., 0.]]} - ], + Fixture( + pyro_dist=dist.OrderedLogistic, + examples=[ + {"cutpoints": [0.0, 1.0, 2.0], "predictor": [1.0], "test_data": [1]}, + { + "cutpoints": [0.0, 1.0, 2.0], + "predictor": [-0.5, 0.5, 1.5, 2.5], + "test_data": [0, 1, 2, 3], + }, + { + "cutpoints": [0.0, 1.0], + "predictor": [[-0.5, 0.5, 1.5], [-0.5, 0.5, 1.5]], + "test_data": [[0, 1, 2], [0, 1, 2]], + }, + ], + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Multinomial, + scipy_dist=sp.multinomial, + examples=[ + {"probs": [0.1, 0.6, 0.3], "test_data": [0.0, 1.0, 0.0]}, + {"probs": [0.1, 0.6, 0.3], "total_count": 8, "test_data": [2.0, 4.0, 2.0]}, + { + "probs": [0.1, 0.6, 
0.3], + "total_count": 8, + "test_data": [[2.0, 4.0, 2.0], [2.0, 4.0, 2.0]], + }, + { + "probs": [[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]], + "total_count": 8, + "test_data": [[2.0, 4.0, 2.0], [1.0, 4.0, 3.0]], + }, + ], + scipy_arg_fn=lambda probs, total_count=[1]: ( + (total_count[0], np.array(probs)), + {}, + ), + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Bernoulli, + scipy_dist=sp.bernoulli, + examples=[ + {"probs": [0.25], "test_data": [1.0]}, + { + "probs": [0.25, 0.25], + "test_data": [[[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 0.0]]], + }, + { + "logits": [math.log(p / (1 - p)) for p in (0.25, 0.25)], + "test_data": [[[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 0.0]]], + }, # for now, avoid tests on infinite logits - # test_data_indices=[0, 1, 2, 3], - batch_data_indices=[-1, -2], - scipy_arg_fn=lambda **kwargs: ((), {'p': kwargs['probs']}), - prec=0.01, - min_samples=10000, - is_discrete=True, - expected_support_non_vec=[[0.], [1.]], - expected_support=[[[0., 0.], [0., 0.]], [[1., 1.], [1., 1.]]]), - Fixture(pyro_dist=dist.BetaBinomial, - examples=[ - {'concentration1': [2.], 'concentration0': [5.], 'total_count': 8, - 'test_data': [4.]}, - {'concentration1': [2.], 'concentration0': [5.], 'total_count': 8, - 'test_data': [[2.], [4.]]}, - {'concentration1': [[2.], [2.]], 'concentration0': [[5.], [5.]], 'total_count': 8, - 'test_data': [[4.], [3.]]}, - {'concentration1': [2., 2.], 'concentration0': [5., 5.], 'total_count': [0., 0.], - 'test_data': [[0., 0.], [0., 0.]]}, - {'concentration1': [2., 2.], 'concentration0': [5., 5.], 'total_count': [[8., 7.], [5., 9.]], - 'test_data': [[6., 3.], [2., 8.]]}, - ], - batch_data_indices=[-1, -2], - prec=0.01, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.Binomial, - scipy_dist=sp.binom, - examples=[ - {'probs': [0.6], 'total_count': 8, - 'test_data': [4.]}, - {'probs': [0.3], 'total_count': 8, - 'test_data': [[2.], [4.]]}, - {'probs': [[0.2], [0.4]], 'total_count': 8, - 'test_data': [[4.], [3.]]}, - {'probs': [0.2, 0.4], 'total_count': [0., 0.], - 'test_data': [[0., 0.], [0., 0.]]}, - {'probs': [0.2, 0.4], 'total_count': [[8., 7.], [5., 9.]], - 'test_data': [[6., 3.], [2., 8.]]}, - ], - scipy_arg_fn=lambda probs, total_count: ((total_count, probs), {}), - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.ExtendedBetaBinomial, - examples=[ - {'concentration1': [2.], 'concentration0': [5.], 'total_count': 8, - 'test_data': [4.]}, - {'concentration1': [2.], 'concentration0': [5.], 'total_count': 8, - 'test_data': [[2.], [4.]]}, - {'concentration1': [[2.], [2.]], 'concentration0': [[5.], [5.]], 'total_count': 8, - 'test_data': [[4.], [3.]]}, - {'concentration1': [2., 2.], 'concentration0': [5., 5.], 'total_count': [0., 0.], - 'test_data': [[0., 0.], [0., 0.]]}, - {'concentration1': [2., 2.], 'concentration0': [5., 5.], 'total_count': [[8., 7.], [5., 9.]], - 'test_data': [[6., 3.], [2., 8.]]}, - ], - batch_data_indices=[-1, -2], - prec=0.01, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.ExtendedBinomial, - scipy_dist=sp.binom, - examples=[ - {'probs': [0.6], 'total_count': 8, - 'test_data': [4.]}, - {'probs': [0.3], 'total_count': 8, - 'test_data': [[2.], [4.]]}, - {'probs': [[0.2], [0.4]], 'total_count': 8, - 'test_data': [[4.], [3.]]}, - {'probs': [0.2, 0.4], 'total_count': [0., 0.], - 'test_data': [[0., 0.], [0., 0.]]}, - {'probs': [0.2, 0.4], 'total_count': [[8., 7.], [5., 9.]], - 'test_data': [[6., 3.], [2., 8.]]}, - ], - scipy_arg_fn=lambda probs, 
total_count: ((total_count, probs), {}), - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.Categorical, - scipy_dist=sp.multinomial, - examples=[ - {'probs': [0.1, 0.6, 0.3], - 'test_data': [2]}, - {'logits': list(map(math.log, [0.1, 0.6, 0.3])), - 'test_data': [2]}, - {'logits': [list(map(math.log, [0.1, 0.6, 0.3])), - list(map(math.log, [0.2, 0.4, 0.4]))], - 'test_data': [2, 0]}, - {'probs': [[0.1, 0.6, 0.3], - [0.2, 0.4, 0.4]], - 'test_data': [2, 0]} - ], - test_data_indices=[0, 1, 2], - batch_data_indices=[-1, -2], - scipy_arg_fn=None, - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.DirichletMultinomial, - examples=[ - {'concentration': [0.1, 0.6, 0.3], - 'test_data': [0., 1., 0.]}, - {'concentration': [0.5, 1.0, 2.0], 'total_count': 8, - 'test_data': [0., 2., 6.]}, - {'concentration': [[0.5, 1.0, 2.0], [3., 3., 0.1]], 'total_count': 8, - 'test_data': [[0., 2., 6.], [5., 2., 1.]]}, - ], - prec=0.08, - is_discrete=True), - Fixture(pyro_dist=dist.GammaPoisson, - examples=[ - {'concentration': [1.], 'rate': [2.], - 'test_data': [0.]}, - {'concentration': [1.], 'rate': [2.], - 'test_data': [1.]}, - {'concentration': [1.], 'rate': [2.], - 'test_data': [4.]}, - {'concentration': [1., 1., 1.], 'rate': [2., 2., 3.], - 'test_data': [[0., 1., 4.], [0., 1., 4.]]}, - {'concentration': [[1.0], [1.0], [1.0]], 'rate': [[2.0], [2.0], [3.0]], - 'test_data': [[0.], [1.], [4.]]} - ], - prec=0.08, - is_discrete=True), - Fixture(pyro_dist=dist.OneHotCategorical, - scipy_dist=sp.multinomial, - examples=[ - {'probs': [0.1, 0.6, 0.3], - 'test_data': [0., 0., 1.]}, - {'logits': list(map(math.log, [0.1, 0.6, 0.3])), - 'test_data': [0., 0., 1.]}, - {'logits': [list(map(math.log, [0.1, 0.6, 0.3])), - list(map(math.log, [0.2, 0.4, 0.4]))], - 'test_data': [[0., 0., 1.], [1., 0., 0.]]}, - {'probs': [[0.1, 0.6, 0.3], - [0.2, 0.4, 0.4]], - 'test_data': [[0., 0., 1.], [1., 0., 0.]]} - ], - test_data_indices=[0, 1, 2], - batch_data_indices=[-1, -2], - scipy_arg_fn=lambda probs: ((1, np.array(probs)), {}), - prec=0.05, - min_samples=10000, - is_discrete=True), - Fixture(pyro_dist=dist.Poisson, - scipy_dist=sp.poisson, - examples=[ - {'rate': [2.0], - 'test_data': [0.]}, - {'rate': [3.0], - 'test_data': [1.]}, - {'rate': [6.0], - 'test_data': [4.]}, - {'rate': [2.0, 3.0, 6.0], - 'test_data': [[0., 1., 4.], [0., 1., 4.]]}, - {'rate': [[2.0], [3.0], [6.0]], - 'test_data': [[0.], [1.], [4.]]} - ], - scipy_arg_fn=lambda rate: ((np.array(rate),), {}), - prec=0.08, - is_discrete=True), - Fixture(pyro_dist=SparsePoisson, - scipy_dist=sp.poisson, - examples=[ - {'rate': [2.0], - 'test_data': [0.]}, - {'rate': [3.0], - 'test_data': [1.]}, - {'rate': [6.0], - 'test_data': [4.]}, - {'rate': [2.0, 3.0, 6.0], - 'test_data': [[0., 1., 4.], [0., 1., 4.]]}, - {'rate': [[2.0], [3.0], [6.0]], - 'test_data': [[0.], [1.], [4.]]} - ], - scipy_arg_fn=lambda rate: ((np.array(rate),), {}), - prec=0.08, - is_discrete=True), - Fixture(pyro_dist=dist.Geometric, - scipy_dist=sp.geom, - examples=[ - {'logits': [2.0], - 'test_data': [0.]}, - {'logits': [3.0], - 'test_data': [1.]}, - {'logits': [-6.0], - 'test_data': [4.]}, - {'logits': [2.0, 3.0, -6.0], - 'test_data': [[0., 1., 4.], [0., 1., 4.]]}, - {'logits': [[2.0], [3.0], [-6.0]], - 'test_data': [[0.], [1.], [4.]]} - ], - scipy_arg_fn=lambda probs: ((np.array(probs), -1), {}), - prec=0.08, - is_discrete=True), + # {'logits': [-float('inf'), 0], + # 'test_data': [[0, 1], [0, 1], [0, 1]]}, + { + "logits": [ + [math.log(p / (1 - p)) for p 
in (0.25, 0.25)], + [math.log(p / (1 - p)) for p in (0.3, 0.3)], + ], + "test_data": [[1.0, 1.0], [0.0, 0.0]], + }, + { + "probs": [[0.25, 0.25], [0.3, 0.3]], + "test_data": [[1.0, 1.0], [0.0, 0.0]], + }, + ], + # for now, avoid tests on infinite logits + # test_data_indices=[0, 1, 2, 3], + batch_data_indices=[-1, -2], + scipy_arg_fn=lambda **kwargs: ((), {"p": kwargs["probs"]}), + prec=0.01, + min_samples=10000, + is_discrete=True, + expected_support_non_vec=[[0.0], [1.0]], + expected_support=[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], + ), + Fixture( + pyro_dist=dist.BetaBinomial, + examples=[ + { + "concentration1": [2.0], + "concentration0": [5.0], + "total_count": 8, + "test_data": [4.0], + }, + { + "concentration1": [2.0], + "concentration0": [5.0], + "total_count": 8, + "test_data": [[2.0], [4.0]], + }, + { + "concentration1": [[2.0], [2.0]], + "concentration0": [[5.0], [5.0]], + "total_count": 8, + "test_data": [[4.0], [3.0]], + }, + { + "concentration1": [2.0, 2.0], + "concentration0": [5.0, 5.0], + "total_count": [0.0, 0.0], + "test_data": [[0.0, 0.0], [0.0, 0.0]], + }, + { + "concentration1": [2.0, 2.0], + "concentration0": [5.0, 5.0], + "total_count": [[8.0, 7.0], [5.0, 9.0]], + "test_data": [[6.0, 3.0], [2.0, 8.0]], + }, + ], + batch_data_indices=[-1, -2], + prec=0.01, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Binomial, + scipy_dist=sp.binom, + examples=[ + {"probs": [0.6], "total_count": 8, "test_data": [4.0]}, + {"probs": [0.3], "total_count": 8, "test_data": [[2.0], [4.0]]}, + {"probs": [[0.2], [0.4]], "total_count": 8, "test_data": [[4.0], [3.0]]}, + { + "probs": [0.2, 0.4], + "total_count": [0.0, 0.0], + "test_data": [[0.0, 0.0], [0.0, 0.0]], + }, + { + "probs": [0.2, 0.4], + "total_count": [[8.0, 7.0], [5.0, 9.0]], + "test_data": [[6.0, 3.0], [2.0, 8.0]], + }, + ], + scipy_arg_fn=lambda probs, total_count: ((total_count, probs), {}), + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.ExtendedBetaBinomial, + examples=[ + { + "concentration1": [2.0], + "concentration0": [5.0], + "total_count": 8, + "test_data": [4.0], + }, + { + "concentration1": [2.0], + "concentration0": [5.0], + "total_count": 8, + "test_data": [[2.0], [4.0]], + }, + { + "concentration1": [[2.0], [2.0]], + "concentration0": [[5.0], [5.0]], + "total_count": 8, + "test_data": [[4.0], [3.0]], + }, + { + "concentration1": [2.0, 2.0], + "concentration0": [5.0, 5.0], + "total_count": [0.0, 0.0], + "test_data": [[0.0, 0.0], [0.0, 0.0]], + }, + { + "concentration1": [2.0, 2.0], + "concentration0": [5.0, 5.0], + "total_count": [[8.0, 7.0], [5.0, 9.0]], + "test_data": [[6.0, 3.0], [2.0, 8.0]], + }, + ], + batch_data_indices=[-1, -2], + prec=0.01, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.ExtendedBinomial, + scipy_dist=sp.binom, + examples=[ + {"probs": [0.6], "total_count": 8, "test_data": [4.0]}, + {"probs": [0.3], "total_count": 8, "test_data": [[2.0], [4.0]]}, + {"probs": [[0.2], [0.4]], "total_count": 8, "test_data": [[4.0], [3.0]]}, + { + "probs": [0.2, 0.4], + "total_count": [0.0, 0.0], + "test_data": [[0.0, 0.0], [0.0, 0.0]], + }, + { + "probs": [0.2, 0.4], + "total_count": [[8.0, 7.0], [5.0, 9.0]], + "test_data": [[6.0, 3.0], [2.0, 8.0]], + }, + ], + scipy_arg_fn=lambda probs, total_count: ((total_count, probs), {}), + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Categorical, + scipy_dist=sp.multinomial, + examples=[ + {"probs": [0.1, 0.6, 0.3], "test_data": 
[2]}, + {"logits": list(map(math.log, [0.1, 0.6, 0.3])), "test_data": [2]}, + { + "logits": [ + list(map(math.log, [0.1, 0.6, 0.3])), + list(map(math.log, [0.2, 0.4, 0.4])), + ], + "test_data": [2, 0], + }, + {"probs": [[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]], "test_data": [2, 0]}, + ], + test_data_indices=[0, 1, 2], + batch_data_indices=[-1, -2], + scipy_arg_fn=None, + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.DirichletMultinomial, + examples=[ + {"concentration": [0.1, 0.6, 0.3], "test_data": [0.0, 1.0, 0.0]}, + { + "concentration": [0.5, 1.0, 2.0], + "total_count": 8, + "test_data": [0.0, 2.0, 6.0], + }, + { + "concentration": [[0.5, 1.0, 2.0], [3.0, 3.0, 0.1]], + "total_count": 8, + "test_data": [[0.0, 2.0, 6.0], [5.0, 2.0, 1.0]], + }, + ], + prec=0.08, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.GammaPoisson, + examples=[ + {"concentration": [1.0], "rate": [2.0], "test_data": [0.0]}, + {"concentration": [1.0], "rate": [2.0], "test_data": [1.0]}, + {"concentration": [1.0], "rate": [2.0], "test_data": [4.0]}, + { + "concentration": [1.0, 1.0, 1.0], + "rate": [2.0, 2.0, 3.0], + "test_data": [[0.0, 1.0, 4.0], [0.0, 1.0, 4.0]], + }, + { + "concentration": [[1.0], [1.0], [1.0]], + "rate": [[2.0], [2.0], [3.0]], + "test_data": [[0.0], [1.0], [4.0]], + }, + ], + prec=0.08, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.OneHotCategorical, + scipy_dist=sp.multinomial, + examples=[ + {"probs": [0.1, 0.6, 0.3], "test_data": [0.0, 0.0, 1.0]}, + { + "logits": list(map(math.log, [0.1, 0.6, 0.3])), + "test_data": [0.0, 0.0, 1.0], + }, + { + "logits": [ + list(map(math.log, [0.1, 0.6, 0.3])), + list(map(math.log, [0.2, 0.4, 0.4])), + ], + "test_data": [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], + }, + { + "probs": [[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]], + "test_data": [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], + }, + ], + test_data_indices=[0, 1, 2], + batch_data_indices=[-1, -2], + scipy_arg_fn=lambda probs: ((1, np.array(probs)), {}), + prec=0.05, + min_samples=10000, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Poisson, + scipy_dist=sp.poisson, + examples=[ + {"rate": [2.0], "test_data": [0.0]}, + {"rate": [3.0], "test_data": [1.0]}, + {"rate": [6.0], "test_data": [4.0]}, + {"rate": [2.0, 3.0, 6.0], "test_data": [[0.0, 1.0, 4.0], [0.0, 1.0, 4.0]]}, + {"rate": [[2.0], [3.0], [6.0]], "test_data": [[0.0], [1.0], [4.0]]}, + ], + scipy_arg_fn=lambda rate: ((np.array(rate),), {}), + prec=0.08, + is_discrete=True, + ), + Fixture( + pyro_dist=SparsePoisson, + scipy_dist=sp.poisson, + examples=[ + {"rate": [2.0], "test_data": [0.0]}, + {"rate": [3.0], "test_data": [1.0]}, + {"rate": [6.0], "test_data": [4.0]}, + {"rate": [2.0, 3.0, 6.0], "test_data": [[0.0, 1.0, 4.0], [0.0, 1.0, 4.0]]}, + {"rate": [[2.0], [3.0], [6.0]], "test_data": [[0.0], [1.0], [4.0]]}, + ], + scipy_arg_fn=lambda rate: ((np.array(rate),), {}), + prec=0.08, + is_discrete=True, + ), + Fixture( + pyro_dist=dist.Geometric, + scipy_dist=sp.geom, + examples=[ + {"logits": [2.0], "test_data": [0.0]}, + {"logits": [3.0], "test_data": [1.0]}, + {"logits": [-6.0], "test_data": [4.0]}, + { + "logits": [2.0, 3.0, -6.0], + "test_data": [[0.0, 1.0, 4.0], [0.0, 1.0, 4.0]], + }, + {"logits": [[2.0], [3.0], [-6.0]], "test_data": [[0.0], [1.0], [4.0]]}, + ], + scipy_arg_fn=lambda probs: ((np.array(probs), -1), {}), + prec=0.08, + is_discrete=True, + ), ] -@pytest.fixture(name='dist', - params=continuous_dists + discrete_dists, - ids=lambda x: x.get_test_distribution_name()) +@pytest.fixture( + name="dist", + 
params=continuous_dists + discrete_dists, + ids=lambda x: x.get_test_distribution_name(), +) def all_distributions(request): return request.param -@pytest.fixture(name='continuous_dist', - params=continuous_dists, - ids=lambda x: x.get_test_distribution_name()) +@pytest.fixture( + name="continuous_dist", + params=continuous_dists, + ids=lambda x: x.get_test_distribution_name(), +) def continuous_distributions(request): return request.param -@pytest.fixture(name='discrete_dist', - params=discrete_dists, - ids=lambda x: x.get_test_distribution_name()) +@pytest.fixture( + name="discrete_dist", + params=discrete_dists, + ids=lambda x: x.get_test_distribution_name(), +) def discrete_distributions(request): return request.param diff --git a/tests/distributions/dist_fixture.py b/tests/distributions/dist_fixture.py index d5fabbcb21..63ac0dfa11 100644 --- a/tests/distributions/dist_fixture.py +++ b/tests/distributions/dist_fixture.py @@ -14,18 +14,20 @@ class Fixture: - def __init__(self, - pyro_dist=None, - scipy_dist=None, - examples=None, - scipy_arg_fn=None, - prec=0.05, - min_samples=None, - is_discrete=False, - expected_support_non_vec=None, - expected_support=None, - test_data_indices=None, - batch_data_indices=None): + def __init__( + self, + pyro_dist=None, + scipy_dist=None, + examples=None, + scipy_arg_fn=None, + prec=0.05, + min_samples=None, + is_discrete=False, + expected_support_non_vec=None, + expected_support=None, + test_data_indices=None, + batch_data_indices=None, + ): self.pyro_dist = pyro_dist self.scipy_dist = scipy_dist self.dist_params, self.test_data = self._extract_fixture_data(examples) @@ -54,7 +56,7 @@ def get_test_data_indices(self): def _extract_fixture_data(self, examples): dist_params, test_data = [], [] for ex in examples: - test_data.append(ex.pop('test_data')) + test_data.append(ex.pop("test_data")) dist_params.append(ex) return dist_params, test_data @@ -62,7 +64,9 @@ def get_num_test_data(self): return len(self.test_data) def get_samples(self, num_samples, **dist_params): - return self.pyro_dist(**dist_params).sample(sample_shape=torch.Size((num_samples,))) + return self.pyro_dist(**dist_params).sample( + sample_shape=torch.Size((num_samples,)) + ) def get_test_data(self, idx, wrap_tensor=True): if not wrap_tensor: @@ -75,11 +79,14 @@ def get_dist_params(self, idx, wrap_tensor=True): return tensor_wrap(**self.dist_params[idx]) def _convert_logits_to_ps(self, dist_params): - if 'logits' in dist_params: - logits = torch.tensor(dist_params.pop('logits')) - is_multidimensional = self.get_test_distribution_name() not in ['Bernoulli', 'Geometric'] + if "logits" in dist_params: + logits = torch.tensor(dist_params.pop("logits")) + is_multidimensional = self.get_test_distribution_name() not in [ + "Bernoulli", + "Geometric", + ] probs = logits_to_probs(logits, is_binary=not is_multidimensional) - dist_params['probs'] = list(probs.detach().cpu().numpy()) + dist_params["probs"] = list(probs.detach().cpu().numpy()) return dist_params def get_scipy_logpdf(self, idx): @@ -89,9 +96,13 @@ def get_scipy_logpdf(self, idx): dist_params = self._convert_logits_to_ps(dist_params) args, kwargs = self.scipy_arg_fn(**dist_params) if self.is_discrete: - log_prob = self.scipy_dist.logpmf(self.get_test_data(idx, wrap_tensor=False), *args, **kwargs) + log_prob = self.scipy_dist.logpmf( + self.get_test_data(idx, wrap_tensor=False), *args, **kwargs + ) else: - log_prob = self.scipy_dist.logpdf(self.get_test_data(idx, wrap_tensor=False), *args, **kwargs) + log_prob = self.scipy_dist.logpdf( 
+ self.get_test_data(idx, wrap_tensor=False), *args, **kwargs + ) return np.sum(log_prob) def get_scipy_batch_logpdf(self, idx): @@ -102,7 +113,9 @@ def get_scipy_batch_logpdf(self, idx): dist_params = self._convert_logits_to_ps(dist_params) test_data = self.get_test_data(idx, wrap_tensor=False) test_data_wrapped = self.get_test_data(idx) - shape = broadcast_shape(self.pyro_dist(**dist_params_wrapped).shape(), test_data_wrapped.size()) + shape = broadcast_shape( + self.pyro_dist(**dist_params_wrapped).shape(), test_data_wrapped.size() + ) log_prob = [] for i in range(len(test_data)): batch_params = {} @@ -137,7 +150,9 @@ def get_num_samples(self, idx): try: fourth_moment = np.max(self.scipy_dist.moment(4, *args, **kwargs)) var = np.max(self.scipy_dist.var(*args, **kwargs)) - min_computed_samples = int(math.ceil((fourth_moment - math.pow(var, 2)) / required_precision)) + min_computed_samples = int( + math.ceil((fourth_moment - math.pow(var, 2)) / required_precision) + ) except (AttributeError, ValueError): return min_samples return max(min_samples, min_computed_samples) diff --git a/tests/distributions/test_binomial.py b/tests/distributions/test_binomial.py index 31a8f56f09..5849a3fad7 100644 --- a/tests/distributions/test_binomial.py +++ b/tests/distributions/test_binomial.py @@ -26,8 +26,8 @@ def test_binomial_approx_sample(total_count, prob): @pytest.mark.parametrize("total_count", [10, 100, 1000, 4000]) -@pytest.mark.parametrize("concentration1", [0.1, 1.0, 10.]) -@pytest.mark.parametrize("concentration0", [0.1, 1.0, 10.]) +@pytest.mark.parametrize("concentration1", [0.1, 1.0, 10.0]) +@pytest.mark.parametrize("concentration0", [0.1, 1.0, 10.0]) def test_beta_binomial_approx_sample(concentration1, concentration0, total_count): sample_shape = (10000,) d = dist.BetaBinomial(concentration1, concentration0, total_count) @@ -39,13 +39,25 @@ def test_beta_binomial_approx_sample(concentration1, concentration0, total_count assert_close(expected.std(), actual.std(), rtol=0.1) -@pytest.mark.parametrize("tol", [ - 1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.1, 1., -]) +@pytest.mark.parametrize( + "tol", + [ + 1e-8, + 1e-6, + 1e-4, + 1e-2, + 0.02, + 0.05, + 0.1, + 0.2, + 0.1, + 1.0, + ], +) def test_binomial_approx_log_prob(tol): - logits = torch.linspace(-10., 10., 100) - k = torch.arange(100.).unsqueeze(-1) - n_minus_k = torch.arange(100.).unsqueeze(-1).unsqueeze(-1) + logits = torch.linspace(-10.0, 10.0, 100) + k = torch.arange(100.0).unsqueeze(-1) + n_minus_k = torch.arange(100.0).unsqueeze(-1).unsqueeze(-1) n = k + n_minus_k expected = torch.distributions.Binomial(n, logits=logits).log_prob(k) diff --git a/tests/distributions/test_categorical.py b/tests/distributions/test_categorical.py index 3d8642d30f..e60b4f14c1 100644 --- a/tests/distributions/test_categorical.py +++ b/tests/distributions/test_categorical.py @@ -37,16 +37,26 @@ def setUp(self): self.support = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) def test_log_prob_sum(self): - log_px_torch = dist.Categorical(self.probs).log_prob(self.test_data).sum().item() - log_px_np = float(sp.multinomial.logpmf(np.array([0, 0, 1]), 1, self.probs.detach().cpu().numpy())) + log_px_torch = ( + dist.Categorical(self.probs).log_prob(self.test_data).sum().item() + ) + log_px_np = float( + sp.multinomial.logpmf( + np.array([0, 0, 1]), 1, self.probs.detach().cpu().numpy() + ) + ) assert_equal(log_px_torch, log_px_np, prec=1e-4) def test_mean_and_var(self): - torch_samples = [dist.Categorical(self.probs).sample().detach().cpu().numpy() - for _ 
in range(self.n_samples)] + torch_samples = [ + dist.Categorical(self.probs).sample().detach().cpu().numpy() + for _ in range(self.n_samples) + ] _, counts = np.unique(torch_samples, return_counts=True) computed_mean = float(counts[0]) / self.n_samples - assert_equal(computed_mean, self.analytic_mean.detach().cpu().numpy()[0], prec=0.05) + assert_equal( + computed_mean, self.analytic_mean.detach().cpu().numpy()[0], prec=0.05 + ) def test_support_non_vectorized(self): s = dist.Categorical(self.d_ps[0].squeeze(0)).enumerate_support() @@ -60,7 +70,7 @@ def test_support(self): def wrap_nested(x, dim): if dim == 0: return x - return wrap_nested([x], dim-1) + return wrap_nested([x], dim - 1) @pytest.fixture(params=[1, 2, 3], ids=lambda x: "dim=" + str(x)) @@ -74,7 +84,7 @@ def probs(request): def modify_params_using_dims(probs, dim): - return torch.tensor(wrap_nested(probs, dim-1)) + return torch.tensor(wrap_nested(probs, dim - 1)) def test_support_dims(dim, probs): diff --git a/tests/distributions/test_coalescent.py b/tests/distributions/test_coalescent.py index 6df7cafd57..2fe90ad8b5 100644 --- a/tests/distributions/test_coalescent.py +++ b/tests/distributions/test_coalescent.py @@ -44,7 +44,7 @@ def test_simple_smoke(num_leaves, num_steps, batch_shape, sample_shape): leaf_times = torch.rand(batch_shape + (num_leaves,)).pow(0.5) * num_steps d = CoalescentTimes(leaf_times) coal_times = d.sample(sample_shape) - assert coal_times.shape == sample_shape + batch_shape + (num_leaves-1,) + assert coal_times.shape == sample_shape + batch_shape + (num_leaves - 1,) actual = d.log_prob(coal_times) assert actual.shape == sample_shape + batch_shape @@ -55,14 +55,17 @@ def test_simple_smoke(num_leaves, num_steps, batch_shape, sample_shape): @pytest.mark.parametrize("rate_grid_shape", [(), (2,), (3, 1), (3, 2)], ids=str) @pytest.mark.parametrize("leaf_times_shape", [(), (2,), (3, 1), (3, 2)], ids=str) @pytest.mark.parametrize("num_leaves", [2, 7, 11]) -def test_with_rate_smoke(num_leaves, num_steps, leaf_times_shape, rate_grid_shape, sample_shape): +def test_with_rate_smoke( + num_leaves, num_steps, leaf_times_shape, rate_grid_shape, sample_shape +): batch_shape = broadcast_shape(leaf_times_shape, rate_grid_shape) leaf_times = torch.rand(leaf_times_shape + (num_leaves,)).pow(0.5) * num_steps rate_grid = torch.rand(rate_grid_shape + (num_steps,)) d = CoalescentTimesWithRate(leaf_times, rate_grid) coal_times = _sample_coalescent_times( - leaf_times.expand(sample_shape + batch_shape + (-1,))) - assert coal_times.shape == sample_shape + batch_shape + (num_leaves-1,) + leaf_times.expand(sample_shape + batch_shape + (-1,)) + ) + assert coal_times.shape == sample_shape + batch_shape + (num_leaves - 1,) actual = d.log_prob(coal_times) assert actual.shape == sample_shape + batch_shape @@ -188,8 +191,7 @@ def test_likelihood_sequential(num_leaves, num_steps, batch_shape, clamped): expected = d.log_prob(coal_times) likelihood = CoalescentRateLikelihood(leaf_times, coal_times, num_steps) - actual = sum(likelihood(rate_grid[..., t], t) - for t in range(num_steps)) + actual = sum(likelihood(rate_grid[..., t], t) for t in range(num_steps)) assert_close(actual, expected) diff --git a/tests/distributions/test_conjugate.py b/tests/distributions/test_conjugate.py index 940a1922c7..da3e29a8c2 100644 --- a/tests/distributions/test_conjugate.py +++ b/tests/distributions/test_conjugate.py @@ -11,14 +11,23 @@ from tests.common import assert_close -@pytest.mark.parametrize("dist", [ - BetaBinomial(2., 5., 10.), - 
BetaBinomial(torch.tensor([2., 4.]), torch.tensor([5., 8.]), torch.tensor([10., 12.])), - DirichletMultinomial(torch.tensor([0.5, 1.0, 2.0]), 5), - DirichletMultinomial(torch.tensor([[0.5, 1.0, 2.0], [0.2, 0.5, 0.8]]), torch.tensor(10.)), - GammaPoisson(2., 2.), - GammaPoisson(torch.tensor([6., 2]), torch.tensor([2., 8.])), -]) +@pytest.mark.parametrize( + "dist", + [ + BetaBinomial(2.0, 5.0, 10.0), + BetaBinomial( + torch.tensor([2.0, 4.0]), + torch.tensor([5.0, 8.0]), + torch.tensor([10.0, 12.0]), + ), + DirichletMultinomial(torch.tensor([0.5, 1.0, 2.0]), 5), + DirichletMultinomial( + torch.tensor([[0.5, 1.0, 2.0], [0.2, 0.5, 0.8]]), torch.tensor(10.0) + ), + GammaPoisson(2.0, 2.0), + GammaPoisson(torch.tensor([6.0, 2]), torch.tensor([2.0, 8.0])), + ], +) def test_mean(dist): analytic_mean = dist.mean num_samples = 500000 @@ -26,14 +35,23 @@ def test_mean(dist): assert_close(sample_mean, analytic_mean, atol=0.01) -@pytest.mark.parametrize("dist", [ - BetaBinomial(2., 5., 10.), - BetaBinomial(torch.tensor([2., 4.]), torch.tensor([5., 8.]), torch.tensor([10., 12.])), - DirichletMultinomial(torch.tensor([0.5, 1.0, 2.0]), 5), - DirichletMultinomial(torch.tensor([[0.5, 1.0, 2.0], [0.2, 0.5, 0.8]]), torch.tensor(10.)), - GammaPoisson(2., 2.), - GammaPoisson(torch.tensor([6., 2]), torch.tensor([2., 8.])), -]) +@pytest.mark.parametrize( + "dist", + [ + BetaBinomial(2.0, 5.0, 10.0), + BetaBinomial( + torch.tensor([2.0, 4.0]), + torch.tensor([5.0, 8.0]), + torch.tensor([10.0, 12.0]), + ), + DirichletMultinomial(torch.tensor([0.5, 1.0, 2.0]), 5), + DirichletMultinomial( + torch.tensor([[0.5, 1.0, 2.0], [0.2, 0.5, 0.8]]), torch.tensor(10.0) + ), + GammaPoisson(2.0, 2.0), + GammaPoisson(torch.tensor([6.0, 2]), torch.tensor([2.0, 8.0])), + ], +) def test_variance(dist): analytic_var = dist.variance num_samples = 500000 @@ -41,17 +59,20 @@ def test_variance(dist): assert_close(sample_var, analytic_var, rtol=0.01) -@pytest.mark.parametrize("dist, values", [ - (BetaBinomial(2., 5., 10), None), - (BetaBinomial(2., 5., 10), None), - (GammaPoisson(2., 2.), torch.arange(10.)), - (GammaPoisson(6., 2.), torch.arange(20.)), -]) +@pytest.mark.parametrize( + "dist, values", + [ + (BetaBinomial(2.0, 5.0, 10), None), + (BetaBinomial(2.0, 5.0, 10), None), + (GammaPoisson(2.0, 2.0), torch.arange(10.0)), + (GammaPoisson(6.0, 2.0), torch.arange(20.0)), + ], +) def test_log_prob_support(dist, values): if values is None: values = dist.enumerate_support() log_probs = dist.log_prob(values) - assert_close(log_probs.logsumexp(0), torch.tensor(0.), atol=0.01) + assert_close(log_probs.logsumexp(0), torch.tensor(0.0), atol=0.01) @pytest.mark.parametrize("total_count", [1, 2, 3, 10]) @@ -59,7 +80,7 @@ def test_log_prob_support(dist, values): def test_beta_binomial_log_prob(total_count, shape): concentration0 = torch.randn(shape).exp() concentration1 = torch.randn(shape).exp() - value = torch.arange(1. 
+ total_count) + value = torch.arange(1.0 + total_count) num_samples = 100000 probs = dist.Beta(concentration1, concentration0).sample((num_samples,)) @@ -77,7 +98,9 @@ def test_dirichlet_multinomial_log_prob(total_count, batch_shape, is_sparse): event_shape = (3,) concentration = torch.rand(batch_shape + event_shape).exp() # test on one-hots - value = total_count * torch.eye(3).reshape(event_shape + (1,) * len(batch_shape) + event_shape) + value = total_count * torch.eye(3).reshape( + event_shape + (1,) * len(batch_shape) + event_shape + ) num_samples = 100000 probs = dist.Dirichlet(concentration).sample((num_samples, 1)) @@ -93,7 +116,7 @@ def test_dirichlet_multinomial_log_prob(total_count, batch_shape, is_sparse): def test_gamma_poisson_log_prob(shape): gamma_conc = torch.randn(shape).exp() gamma_rate = torch.randn(shape).exp() - value = torch.arange(20.) + value = torch.arange(20.0) num_samples = 300000 poisson_rate = dist.Gamma(gamma_conc, gamma_rate).sample((num_samples,)) diff --git a/tests/distributions/test_conjugate_update.py b/tests/distributions/test_conjugate_update.py index 7d352045ec..d467979c89 100644 --- a/tests/distributions/test_conjugate_update.py +++ b/tests/distributions/test_conjugate_update.py @@ -46,7 +46,7 @@ def test_gamma_poisson(sample_shape, batch_shape): concentration = torch.randn(batch_shape).exp() rate = torch.randn(batch_shape).exp() nobs = 5 - obs = dist.Poisson(10.).sample((nobs,) + sample_shape + batch_shape).sum(0) + obs = dist.Poisson(10.0).sample((nobs,) + sample_shape + batch_shape).sum(0) f = dist.Gamma(concentration, rate) g = dist.Gamma(1 + obs, nobs) diff --git a/tests/distributions/test_constraints.py b/tests/distributions/test_constraints.py index 6a86273b48..0abe1c90a3 100644 --- a/tests/distributions/test_constraints.py +++ b/tests/distributions/test_constraints.py @@ -22,9 +22,12 @@ def test_sphere_check(dim): @pytest.mark.parametrize("batch_shape", [(), (3, 4)]) @pytest.mark.parametrize( "constraint, event_shape", - [(constraints.positive_ordered_vector, (5,)), - (constraints.corr_matrix, (6,)), - (constraints.positive_definite, (3, 3))]) + [ + (constraints.positive_ordered_vector, (5,)), + (constraints.corr_matrix, (6,)), + (constraints.positive_definite, (3, 3)), + ], +) def test_constraints(constraint, batch_shape, event_shape): x = torch.randn(batch_shape + event_shape) y = torch.distributions.transform_to(constraint)(x) diff --git a/tests/distributions/test_cuda.py b/tests/distributions/test_cuda.py index 9cb887b467..4e8bd70068 100644 --- a/tests/distributions/test_cuda.py +++ b/tests/distributions/test_cuda.py @@ -24,7 +24,7 @@ def test_sample(dist): with xfail_if_not_implemented(): cpu_value = dist.pyro_dist(**params).sample() except ValueError as e: - pytest.xfail('CPU version fails: {}'.format(e)) + pytest.xfail("CPU version fails: {}".format(e)) assert not cpu_value.is_cuda # Compute GPU value. @@ -45,8 +45,11 @@ def test_rsample(dist): # Compute CPU value. 
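# Illustrative aside, not part of the patch: the test_cuda.py hunks above show
# black's string normalization, which rewrites single-quoted literals to double
# quotes without touching their contents, e.g. pytest.xfail('CPU version fails:
# {}'.format(e)) becomes the double-quoted form. A minimal sketch of the
# resulting style (the helper name below is illustrative only):

import pytest


def _sketch_quote_normalization():
    try:
        raise ValueError("not implemented on CPU")
    except ValueError as e:
        pytest.xfail("CPU version fails: {}".format(e))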
with tensors_default_to("cpu"): params = dist.get_dist_params(idx) - grad_params = [key for key, val in params.items() - if torch.is_tensor(val) and val.dtype in (torch.float32, torch.float64)] + grad_params = [ + key + for key, val in params.items() + if torch.is_tensor(val) and val.dtype in (torch.float32, torch.float64) + ] for key in grad_params: val = params[key].clone() val.requires_grad = True @@ -56,7 +59,7 @@ def test_rsample(dist): cpu_value = dist.pyro_dist(**params).rsample() cpu_grads = grad(cpu_value.sum(), [params[key] for key in grad_params]) except ValueError as e: - pytest.xfail('CPU version fails: {}'.format(e)) + pytest.xfail("CPU version fails: {}".format(e)) assert not cpu_value.is_cuda # Compute GPU value. diff --git a/tests/distributions/test_delta.py b/tests/distributions/test_delta.py index 6adf33500b..932a923d1b 100644 --- a/tests/distributions/test_delta.py +++ b/tests/distributions/test_delta.py @@ -17,13 +17,13 @@ def setUp(self): self.vs = torch.tensor([[0.0], [1.0], [2.0], [3.0]]) self.vs_expanded = self.vs.expand(4, 3) self.test_data = torch.tensor([[3.0], [3.0], [3.0]]) - self.batch_test_data_1 = torch.arange(0., 4.).unsqueeze(1).expand(4, 3) - self.batch_test_data_2 = torch.arange(4., 8.).unsqueeze(1).expand(4, 3) - self.batch_test_data_3 = torch.Tensor([[3.], [3.], [3.], [3.]]) - self.expected_support = [[[0.], [1.], [2.], [3.]]] - self.expected_support_non_vec = [[3.]] - self.analytic_mean = 3. - self.analytic_var = 0. + self.batch_test_data_1 = torch.arange(0.0, 4.0).unsqueeze(1).expand(4, 3) + self.batch_test_data_2 = torch.arange(4.0, 8.0).unsqueeze(1).expand(4, 3) + self.batch_test_data_3 = torch.Tensor([[3.0], [3.0], [3.0], [3.0]]) + self.expected_support = [[[0.0], [1.0], [2.0], [3.0]]] + self.expected_support_non_vec = [[3.0]] + self.analytic_mean = 3.0 + self.analytic_var = 0.0 self.n_samples = 10 def test_log_prob_sum(self): @@ -31,27 +31,34 @@ def test_log_prob_sum(self): assert_equal(log_px_torch.item(), 0) def test_batch_log_prob(self): - log_px_torch = dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_1).data + log_px_torch = ( + dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_1).data + ) assert_equal(log_px_torch.sum().item(), 0) - log_px_torch = dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_2).data - assert_equal(log_px_torch.sum().item(), float('-inf')) + log_px_torch = ( + dist.Delta(self.vs_expanded).log_prob(self.batch_test_data_2).data + ) + assert_equal(log_px_torch.sum().item(), float("-inf")) def test_batch_log_prob_shape(self): assert dist.Delta(self.vs).log_prob(self.batch_test_data_3).size() == (4, 1) assert dist.Delta(self.v).log_prob(self.batch_test_data_3).size() == (4, 1) def test_mean_and_var(self): - torch_samples = [dist.Delta(self.v).sample().detach().cpu().numpy() - for _ in range(self.n_samples)] + torch_samples = [ + dist.Delta(self.v).sample().detach().cpu().numpy() + for _ in range(self.n_samples) + ] torch_mean = np.mean(torch_samples) torch_var = np.var(torch_samples) assert_equal(torch_mean, self.analytic_mean) assert_equal(torch_var, self.analytic_var) -@pytest.mark.parametrize('batch_dim,event_dim', - [(b, e) for b in range(4) for e in range(1+b)]) -@pytest.mark.parametrize('has_log_density', [False, True]) +@pytest.mark.parametrize( + "batch_dim,event_dim", [(b, e) for b in range(4) for e in range(1 + b)] +) +@pytest.mark.parametrize("has_log_density", [False, True]) def test_shapes(batch_dim, event_dim, has_log_density): shape = tuple(range(2, 2 + batch_dim + event_dim)) 
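# Illustrative aside, not part of the patch: the test_delta.py hunk above shows
# black's literal normalization -- floats written with a bare trailing dot gain
# an explicit zero (3. -> 3.0, torch.arange(0., 4.) -> torch.arange(0.0, 4.0)),
# and strings such as float('-inf') switch to double quotes. The values are
# unchanged, only their spelling; a small runnable sketch:

import torch

batch = torch.arange(0.0, 4.0).unsqueeze(1).expand(4, 3)  # was torch.arange(0., 4.)
neg_inf = float("-inf")  # was float('-inf')
assert batch.shape == (4, 3)
assert neg_inf < 0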
batch_shape = shape[:batch_dim] @@ -64,7 +71,7 @@ def test_shapes(batch_dim, event_dim, has_log_density): assert (d.log_prob(x) == log_density).all() -@pytest.mark.parametrize('batch_shape', [(), [], (2,), [2], torch.Size([2]), [2, 3]]) +@pytest.mark.parametrize("batch_shape", [(), [], (2,), [2], torch.Size([2]), [2, 3]]) def test_expand(batch_shape): d1 = dist.Delta(torch.tensor(1.234)) d2 = d1.expand(batch_shape) diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index f68047827a..cef9a3ce0e 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -28,6 +28,7 @@ def _log_prob_shape(dist, x_size=torch.Size()): # Distribution tests - all distributions + def test_support_shape(dist): for idx in range(dist.get_num_test_data()): dist_params = dist.get_dist_params(idx) @@ -35,7 +36,9 @@ def test_support_shape(dist): assert d.support.event_dim == d.event_dim x = dist.get_test_data(idx) ok = d.support.check(x) - assert ok.shape == broadcast_shape(d.batch_shape, x.shape[:x.dim() - d.event_dim]) + assert ok.shape == broadcast_shape( + d.batch_shape, x.shape[: x.dim() - d.event_dim] + ) assert ok.all() @@ -44,8 +47,10 @@ def test_infer_shapes(dist): pytest.xfail(reason="cannot statically compute shape") for idx in range(dist.get_num_test_data()): dist_params = dist.get_dist_params(idx) - arg_shapes = {k: v.shape if isinstance(v, torch.Tensor) else () - for k, v in dist_params.items()} + arg_shapes = { + k: v.shape if isinstance(v, torch.Tensor) else () + for k, v in dist_params.items() + } batch_shape, event_shape = dist.pyro_dist.infer_shapes(**arg_shapes) d = dist.pyro_dist(**dist_params) assert d.batch_shape == batch_shape @@ -54,7 +59,9 @@ def test_infer_shapes(dist): def test_batch_log_prob(dist): if dist.scipy_arg_fn is None: - pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__)) + pytest.skip( + "{}.log_prob_sum has no scipy equivalent".format(dist.pyro_dist.__name__) + ) for idx in dist.get_batch_data_indices(): dist_params = dist.get_dist_params(idx) d = dist.pyro_dist(**dist_params) @@ -93,10 +100,14 @@ def test_score_errors_event_dim_mismatch(dist): d = dist.pyro_dist(**dist_params) test_data_wrong_dims = torch.ones(d.shape() + (1,)) if len(d.event_shape) > 0: - if dist.get_test_distribution_name() == 'MultivariateNormal': - pytest.skip('MultivariateNormal does not do shape validation in log_prob.') - elif dist.get_test_distribution_name() == 'LowRankMultivariateNormal': - pytest.skip('LowRankMultivariateNormal does not do shape validation in log_prob.') + if dist.get_test_distribution_name() == "MultivariateNormal": + pytest.skip( + "MultivariateNormal does not do shape validation in log_prob." + ) + elif dist.get_test_distribution_name() == "LowRankMultivariateNormal": + pytest.skip( + "LowRankMultivariateNormal does not do shape validation in log_prob." 
+ ) with pytest.raises((ValueError, RuntimeError)): d.log_prob(test_data_wrong_dims) @@ -105,8 +116,8 @@ def test_score_errors_non_broadcastable_data_shape(dist): for idx in dist.get_batch_data_indices(): dist_params = dist.get_dist_params(idx) d = dist.pyro_dist(**dist_params) - if dist.get_test_distribution_name() == 'LKJCholesky': - pytest.skip('https://github.com/pytorch/pytorch/issues/52724') + if dist.get_test_distribution_name() == "LKJCholesky": + pytest.skip("https://github.com/pytorch/pytorch/issues/52724") shape = d.shape() non_broadcastable_shape = (shape[0] + 1,) + shape[1:] test_data_non_broadcastable = torch.ones(non_broadcastable_shape) @@ -116,6 +127,7 @@ def test_score_errors_non_broadcastable_data_shape(dist): # Distributions tests - continuous distributions + def test_support_is_not_discrete(continuous_dist): Dist = continuous_dist.pyro_dist for i in range(continuous_dist.get_num_test_data()): @@ -152,7 +164,13 @@ def test_gof(continuous_dist): def test_mean(continuous_dist): Dist = continuous_dist.pyro_dist - if Dist.__name__ in ["Cauchy", "HalfCauchy", "SineBivariateVonMises", "VonMises", "ProjectedNormal"]: + if Dist.__name__ in [ + "Cauchy", + "HalfCauchy", + "SineBivariateVonMises", + "VonMises", + "ProjectedNormal", + ]: pytest.xfail(reason="Euclidean mean is not defined") for i in range(continuous_dist.get_num_test_data()): d = Dist(**continuous_dist.get_dist_params(i)) @@ -201,6 +219,7 @@ def test_cdf_icdf(continuous_dist): # Distributions tests - discrete distributions + def test_support_is_discrete(discrete_dist): Dist = discrete_dist.pyro_dist for i in range(discrete_dist.get_num_test_data()): @@ -214,7 +233,9 @@ def test_enumerate_support(discrete_dist): if not expected_support: pytest.skip("enumerate_support not tested for distribution") Dist = discrete_dist.pyro_dist - actual_support_non_vec = Dist(**discrete_dist.get_dist_params(0)).enumerate_support() + actual_support_non_vec = Dist( + **discrete_dist.get_dist_params(0) + ).enumerate_support() actual_support = Dist(**discrete_dist.get_dist_params(-1)).enumerate_support() assert_equal(actual_support.data, torch.tensor(expected_support)) assert_equal(actual_support_non_vec.data, torch.tensor(expected_support_non_vec)) @@ -235,15 +256,21 @@ def test_enumerate_support_shape(dist): assert_equal(support, support_expanded) support_unexpanded = d.enumerate_support(expand=False) - assert support_unexpanded.shape == (n,) + (1,) * len(d.batch_shape) + d.event_shape + assert ( + support_unexpanded.shape + == (n,) + (1,) * len(d.batch_shape) + d.event_shape + ) assert (support_expanded == support_unexpanded).all() -@pytest.mark.parametrize("dist_class, args", [ - (dist.Normal, {"loc": torch.tensor(0.0), "scale": torch.tensor(-1.0)}), - (dist.Gamma, {"concentration": -1.0, "rate": 1.0}), - (dist.Exponential, {"rate": -2}) -]) +@pytest.mark.parametrize( + "dist_class, args", + [ + (dist.Normal, {"loc": torch.tensor(0.0), "scale": torch.tensor(-1.0)}), + (dist.Gamma, {"concentration": -1.0, "rate": 1.0}), + (dist.Exponential, {"rate": -2}), + ], +) @pytest.mark.parametrize("validate_args", [True, False]) def test_distribution_validate_args(dist_class, args, validate_args): with pyro.validation_enabled(validate_args): @@ -256,7 +283,9 @@ def test_distribution_validate_args(dist_class, args, validate_args): def check_sample_shapes(small, large): dist_instance = small - if isinstance(dist_instance, (dist.LogNormal, dist.LowRankMultivariateNormal, dist.VonMises)): + if isinstance( + dist_instance, (dist.LogNormal, 
dist.LowRankMultivariateNormal, dist.VonMises) + ): # Ignore broadcasting bug in LogNormal: # https://github.com/pytorch/pytorch/pull/7269 return @@ -266,37 +295,39 @@ def check_sample_shapes(small, large): assert_equal(small.log_prob(x), large.log_prob(x)) -@pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)]) -@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("shape_type", [torch.Size, tuple, list]) def test_expand_by(dist, sample_shape, shape_type): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) large = small.expand_by(shape_type(sample_shape)) assert large.batch_shape == sample_shape + small.batch_shape - if dist.get_test_distribution_name() == 'Stable': - pytest.skip('Stable does not implement a log_prob method.') + if dist.get_test_distribution_name() == "Stable": + pytest.skip("Stable does not implement a log_prob method.") check_sample_shapes(small, large) -@pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)]) -@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list]) -@pytest.mark.parametrize('default', [False, True]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("shape_type", [torch.Size, tuple, list]) +@pytest.mark.parametrize("default", [False, True]) def test_expand_new_dim(dist, sample_shape, shape_type, default): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) if default: - large = TorchDistribution.expand(small, shape_type(sample_shape + small.batch_shape)) + large = TorchDistribution.expand( + small, shape_type(sample_shape + small.batch_shape) + ) else: with xfail_if_not_implemented(): large = small.expand(shape_type(sample_shape + small.batch_shape)) assert large.batch_shape == sample_shape + small.batch_shape - if dist.get_test_distribution_name() == 'Stable': - pytest.skip('Stable does not implement a log_prob method.') + if dist.get_test_distribution_name() == "Stable": + pytest.skip("Stable does not implement a log_prob method.") check_sample_shapes(small, large) -@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list]) -@pytest.mark.parametrize('default', [False, True]) +@pytest.mark.parametrize("shape_type", [torch.Size, tuple, list]) +@pytest.mark.parametrize("default", [False, True]) def test_expand_existing_dim(dist, shape_type, default): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) @@ -312,16 +343,19 @@ def test_expand_existing_dim(dist, shape_type, default): with xfail_if_not_implemented(): large = small.expand(shape_type(batch_shape)) assert large.batch_shape == batch_shape - if dist.get_test_distribution_name() == 'Stable': - pytest.skip('Stable does not implement a log_prob method.') + if dist.get_test_distribution_name() == "Stable": + pytest.skip("Stable does not implement a log_prob method.") check_sample_shapes(small, large) -@pytest.mark.parametrize("sample_shapes", [ - [(2, 1), (2, 3)], - [(2, 1, 1), (2, 1, 3), (2, 5, 3)], -]) -@pytest.mark.parametrize('default', [False, True]) +@pytest.mark.parametrize( + "sample_shapes", + [ + [(2, 1), (2, 3)], + [(2, 1, 1), (2, 1, 3), (2, 5, 3)], + ], +) +@pytest.mark.parametrize("default", [False, True]) def test_subsequent_expands_ok(dist, sample_shapes, default): for idx in range(dist.get_num_test_data()): d = dist.pyro_dist(**dist.get_dist_params(idx)) @@ -339,11 +373,14 @@ def 
test_subsequent_expands_ok(dist, sample_shapes, default): d = n -@pytest.mark.parametrize("initial_shape, proposed_shape", [ - [(2, 1), (4, 3)], - [(2, 4), (2, 2, 1)], - [(1, 2, 1), (2, 1)], -]) +@pytest.mark.parametrize( + "initial_shape, proposed_shape", + [ + [(2, 1), (4, 3)], + [(2, 4), (2, 2, 1)], + [(1, 2, 1), (2, 1)], + ], +) @pytest.mark.parametrize("default", [False, True]) def test_expand_error(dist, initial_shape, proposed_shape, default): for idx in range(dist.get_num_test_data()): @@ -358,19 +395,24 @@ def test_expand_error(dist, initial_shape, proposed_shape, default): large.expand(proposed_batch_shape) -@pytest.mark.parametrize("extra_event_dims,expand_shape", [ - (0, [4, 3, 2, 1]), - (0, [4, 3, 2, 2]), - (1, [5, 4, 3, 2]), - (2, [5, 4, 3]), -]) -@pytest.mark.parametrize('default', [False, True]) +@pytest.mark.parametrize( + "extra_event_dims,expand_shape", + [ + (0, [4, 3, 2, 1]), + (0, [4, 3, 2, 2]), + (1, [5, 4, 3, 2]), + (2, [5, 4, 3]), + ], +) +@pytest.mark.parametrize("default", [False, True]) def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default): probs = torch.ones(1, 6) / 6 d = dist.OneHotCategorical(probs) full_shape = torch.Size([4, 1, 1, 1, 6]) if default: - reshaped_dist = TorchDistribution.expand(d, [4, 1, 1, 1]).to_event(extra_event_dims) + reshaped_dist = TorchDistribution.expand(d, [4, 1, 1, 1]).to_event( + extra_event_dims + ) else: reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims) cut = 4 - extra_event_dims @@ -393,5 +435,7 @@ def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default): def test_expand_enumerate_support(): probs = torch.ones(3, 6) / 6 d = dist.Categorical(probs) - actual_enum_shape = TorchDistribution.expand(d, (4, 3)).enumerate_support(expand=True).shape + actual_enum_shape = ( + TorchDistribution.expand(d, (4, 3)).enumerate_support(expand=True).shape + ) assert actual_enum_shape == (6, 4, 3) diff --git a/tests/distributions/test_empirical.py b/tests/distributions/test_empirical.py index 3f2d4435dd..2d64e831e3 100644 --- a/tests/distributions/test_empirical.py +++ b/tests/distributions/test_empirical.py @@ -22,13 +22,16 @@ def test_unweighted_mean_and_var(size, dtype): assert_equal(empirical_dist.variance, true_var) -@pytest.mark.parametrize("batch_shape, event_shape", [ - ([], []), - ([2], []), - ([2], [5]), - ([2], [5, 3]), - ([2, 5], [3]), -]) +@pytest.mark.parametrize( + "batch_shape, event_shape", + [ + ([], []), + ([2], []), + ([2], [5]), + ([2], [5, 3]), + ([2, 5], [3]), + ], +) @pytest.mark.parametrize("sample_shape", [[], [20], [20, 3, 4]]) @pytest.mark.parametrize("dtype", [torch.long, torch.float32, torch.float64]) def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype): @@ -36,9 +39,11 @@ def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype): # empirical samples with desired shape dim_ordering = list(range(len(batch_shape + event_shape) + 1)) # +1 for agg dim dim_ordering.insert(len(batch_shape), dim_ordering.pop()) - emp_samples = torch.arange(agg_dim_size, dtype=dtype)\ - .expand(batch_shape + event_shape + [agg_dim_size])\ + emp_samples = ( + torch.arange(agg_dim_size, dtype=dtype) + .expand(batch_shape + event_shape + [agg_dim_size]) .permute(dim_ordering) + ) # initial weight assignment weights = torch.ones(batch_shape + [agg_dim_size]) empirical_dist = Empirical(emp_samples, weights) @@ -46,18 +51,23 @@ def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype): assert_equal(samples.size(), 
torch.Size(sample_shape + batch_shape + event_shape)) -@pytest.mark.parametrize("sample, weights, expected_mean, expected_var", [( - torch.tensor([[0., 0., 0.], [1., 1., 1.]]), - torch.ones(2), - torch.tensor([0.5, 0.5, 0.5]), - torch.tensor([0.25, 0.25, 0.25]), - ), ( - torch.tensor([[0., 0., 0.], [1., 1., 1.]]), - torch.ones(2, 3), - torch.tensor([0., 1.]), - torch.tensor([0., 0.]), - ), -]) +@pytest.mark.parametrize( + "sample, weights, expected_mean, expected_var", + [ + ( + torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), + torch.ones(2), + torch.tensor([0.5, 0.5, 0.5]), + torch.tensor([0.25, 0.25, 0.25]), + ), + ( + torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), + torch.ones(2, 3), + torch.tensor([0.0, 1.0]), + torch.tensor([0.0, 0.0]), + ), + ], +) def test_sample_examples(sample, weights, expected_mean, expected_var): emp_dist = Empirical(sample, weights) num_samples = 10000 @@ -68,20 +78,23 @@ def test_sample_examples(sample, weights, expected_mean, expected_var): assert_close(emp_samples.var(0), emp_dist.variance, rtol=1e-2) -@pytest.mark.parametrize("batch_shape, event_shape", [ - ([], []), - ([1], []), - ([10], []), - ([10, 8], [3]), - ([10, 8], [3, 4]), -]) +@pytest.mark.parametrize( + "batch_shape, event_shape", + [ + ([], []), + ([1], []), + ([10], []), + ([10, 8], [3]), + ([10, 8], [3, 4]), + ], +) @pytest.mark.parametrize("dtype", [torch.long, torch.float32, torch.float64]) def test_log_prob(batch_shape, event_shape, dtype): samples = [] for i in range(5): samples.append(torch.ones(event_shape, dtype=dtype) * i) samples = torch.stack(samples).expand(batch_shape + [5] + event_shape) - weights = torch.tensor(1.).expand(batch_shape + [5]) + weights = torch.tensor(1.0).expand(batch_shape + [5]) empirical_dist = Empirical(samples, weights) sample_to_score = torch.tensor(1, dtype=dtype).expand(batch_shape + event_shape) log_prob = empirical_dist.log_prob(sample_to_score) @@ -116,7 +129,9 @@ def test_weighted_sample_coherence(event_shape, dtype): samples = empirical_dist.sample(sample_shape=torch.Size((1000,))) zeros = torch.zeros(event_shape, dtype=dtype) ones = torch.ones(event_shape, dtype=dtype) - num_zeros = samples.eq(zeros).contiguous().view(1000, -1).min(dim=-1)[0].float().sum() + num_zeros = ( + samples.eq(zeros).contiguous().view(1000, -1).min(dim=-1)[0].float().sum() + ) num_ones = samples.eq(ones).contiguous().view(1000, -1).min(dim=-1)[0].float().sum() assert_equal(num_zeros.item() / 1000, 0.75, prec=0.02) assert_equal(num_ones.item() / 1000, 0.25, prec=0.02) @@ -151,7 +166,7 @@ def test_mean_var_non_nan(): samples, weights = [], [] for i in range(10): samples.append(true_mean) - weights.append(torch.tensor(-1000.)) + weights.append(torch.tensor(-1000.0)) samples, weights = torch.stack(samples), torch.stack(weights) empirical_dist = Empirical(samples, weights) assert_equal(empirical_dist.mean, true_mean) diff --git a/tests/distributions/test_extended.py b/tests/distributions/test_extended.py index ca7952b8d9..dbb1c1eda1 100644 --- a/tests/distributions/test_extended.py +++ b/tests/distributions/test_extended.py @@ -17,10 +17,10 @@ def check_grad(value, *params): assert all(torch.isfinite(g).all() for g in grads) -@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1]) +@pytest.mark.parametrize("tol", [0.0, 0.02, 0.05, 0.1]) def test_extended_binomial(tol): with set_approx_log_prob_tol(tol): - total_count = torch.tensor([0., 1., 2., 10.]) + total_count = torch.tensor([0.0, 1.0, 2.0, 10.0]) probs = torch.tensor([0.5, 0.5, 0.4, 0.2]).requires_grad_() d1 = 
dist.Binomial(total_count, probs) @@ -30,7 +30,7 @@ def test_extended_binomial(tol): assert_equal(d1.log_prob(data), d2.log_prob(data)) # Check on extended data. - data = torch.arange(-10., 20.).unsqueeze(-1) + data = torch.arange(-10.0, 20.0).unsqueeze(-1) with pytest.raises(ValueError): d1.log_prob(data) log_prob = d2.log_prob(data) @@ -40,14 +40,14 @@ def test_extended_binomial(tol): # Check on shape error. with pytest.raises(ValueError): - d2.log_prob(torch.tensor([0., 0.])) + d2.log_prob(torch.tensor([0.0, 0.0])) # Check on value error. with pytest.raises(ValueError): d2.log_prob(torch.tensor(0.5)) # Check on negative total_count. - total_count = torch.arange(-10, 0.) + total_count = torch.arange(-10, 0.0) probs = torch.tensor(0.5).requires_grad_() d = dist.ExtendedBinomial(total_count, probs) log_prob = d.log_prob(data) @@ -55,12 +55,12 @@ def test_extended_binomial(tol): check_grad(log_prob, probs) -@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1]) +@pytest.mark.parametrize("tol", [0.0, 0.02, 0.05, 0.1]) def test_extended_beta_binomial(tol): with set_approx_log_prob_tol(tol): concentration1 = torch.tensor([0.2, 1.0, 2.0, 1.0]).requires_grad_() concentration0 = torch.tensor([0.2, 0.5, 1.0, 2.0]).requires_grad_() - total_count = torch.tensor([0., 1., 2., 10.]) + total_count = torch.tensor([0.0, 1.0, 2.0, 10.0]) d1 = dist.BetaBinomial(concentration1, concentration0, total_count) d2 = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count) @@ -70,7 +70,7 @@ def test_extended_beta_binomial(tol): assert_equal(d1.log_prob(data), d2.log_prob(data)) # Check on extended data. - data = torch.arange(-10., 20.).unsqueeze(-1) + data = torch.arange(-10.0, 20.0).unsqueeze(-1) with pytest.raises(ValueError): d1.log_prob(data) log_prob = d2.log_prob(data) @@ -80,7 +80,7 @@ def test_extended_beta_binomial(tol): # Check on shape error. with pytest.raises(ValueError): - d2.log_prob(torch.tensor([0., 0.])) + d2.log_prob(torch.tensor([0.0, 0.0])) # Check on value error. with pytest.raises(ValueError): @@ -89,7 +89,7 @@ def test_extended_beta_binomial(tol): # Check on negative total_count. concentration1 = torch.tensor(1.5).requires_grad_() concentration0 = torch.tensor(1.5).requires_grad_() - total_count = torch.arange(-10, 0.) 
+ total_count = torch.arange(-10, 0.0) d = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count) log_prob = d.log_prob(data) assert (log_prob == -math.inf).all() diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py index 3fcccd32f5..7623d4e14b 100644 --- a/tests/distributions/test_gaussian_mixtures.py +++ b/tests/distributions/test_gaussian_mixtures.py @@ -17,12 +17,15 @@ logger = logging.getLogger(__name__) -@pytest.mark.parametrize('mix_dist', [MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture]) -@pytest.mark.parametrize('K', [3]) -@pytest.mark.parametrize('D', [2, 4]) -@pytest.mark.parametrize('batch_mode', [True, False]) -@pytest.mark.parametrize('flat_logits', [True, False]) -@pytest.mark.parametrize('cost_function', ['quadratic']) +@pytest.mark.parametrize( + "mix_dist", + [MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance, GaussianScaleMixture], +) +@pytest.mark.parametrize("K", [3]) +@pytest.mark.parametrize("D", [2, 4]) +@pytest.mark.parametrize("batch_mode", [True, False]) +@pytest.mark.parametrize("flat_logits", [True, False]) +@pytest.mark.parametrize("cost_function", ["quadratic"]) def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): n_samples = 200000 if batch_mode: @@ -53,21 +56,28 @@ def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): _pis = torch.exp(component_logits) pis = _pis / _pis.sum() - if cost_function == 'cosine': + if cost_function == "cosine": analytic1 = torch.cos((omega * locs).sum(-1)) - analytic2 = torch.exp(-0.5 * torch.pow(omega * coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1)) + analytic2 = torch.exp( + -0.5 + * torch.pow(omega * coord_scale * component_scale.unsqueeze(-1), 2.0).sum( + -1 + ) + ) analytic = (pis * analytic1 * analytic2).sum() analytic.backward() - elif cost_function == 'quadratic': - analytic = torch.pow(coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1) + torch.pow(locs, 2.0).sum(-1) + elif cost_function == "quadratic": + analytic = torch.pow(coord_scale * component_scale.unsqueeze(-1), 2.0).sum( + -1 + ) + torch.pow(locs, 2.0).sum(-1) analytic = (pis * analytic).sum() analytic.backward() analytic_grads = {} - analytic_grads['locs'] = locs.grad.clone() - analytic_grads['coord_scale'] = coord_scale.grad.clone() - analytic_grads['component_logits'] = component_logits.grad.clone() - analytic_grads['component_scale'] = component_scale.grad.clone() + analytic_grads["locs"] = locs.grad.clone() + analytic_grads["coord_scale"] = coord_scale.grad.clone() + analytic_grads["component_logits"] = component_logits.grad.clone() + analytic_grads["component_scale"] = component_scale.grad.clone() assert locs.grad.shape == locs.shape assert coord_scale.grad.shape == coord_scale.shape @@ -80,25 +90,45 @@ def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): component_scale.grad.zero_() if mix_dist == MixtureOfDiagNormalsSharedCovariance: - params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits} + params = { + "locs": locs, + "coord_scale": coord_scale, + "component_logits": component_logits, + } if batch_mode: locs = locs.unsqueeze(0).expand(n_samples, K, D) coord_scale = coord_scale.unsqueeze(0).expand(n_samples, D) component_logits = component_logits.unsqueeze(0).expand(n_samples, K) - dist_params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits} + dist_params = { + "locs": locs, + 
"coord_scale": coord_scale, + "component_logits": component_logits, + } else: dist_params = params elif mix_dist == MixtureOfDiagNormals: - params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits} + params = { + "locs": locs, + "coord_scale": coord_scale, + "component_logits": component_logits, + } if batch_mode: locs = locs.unsqueeze(0).expand(n_samples, K, D) coord_scale = coord_scale.unsqueeze(0).expand(n_samples, K, D) component_logits = component_logits.unsqueeze(0).expand(n_samples, K) - dist_params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits} + dist_params = { + "locs": locs, + "coord_scale": coord_scale, + "component_logits": component_logits, + } else: dist_params = params elif mix_dist == GaussianScaleMixture: - params = {'coord_scale': coord_scale, 'component_logits': component_logits, 'component_scale': component_scale} + params = { + "coord_scale": coord_scale, + "component_logits": component_logits, + "component_scale": component_scale, + } if batch_mode: return # distribution does not support batched parameters else: @@ -107,25 +137,38 @@ def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): dist = mix_dist(**dist_params) z = dist.rsample(sample_shape=sample_shape) assert z.shape == (n_samples, D) - if cost_function == 'cosine': + if cost_function == "cosine": cost = torch.cos((omega * z).sum(-1)).sum() / float(n_samples) - elif cost_function == 'quadratic': + elif cost_function == "quadratic": cost = torch.pow(z, 2.0).sum() / float(n_samples) cost.backward() - assert_equal(analytic, cost, prec=0.1, - msg='bad cost function evaluation for {} test (expected {}, got {})'.format( - mix_dist.__name__, analytic.item(), cost.item())) - logger.debug("analytic_grads_logit: {}" - .format(analytic_grads['component_logits'].detach().cpu().numpy())) + assert_equal( + analytic, + cost, + prec=0.1, + msg="bad cost function evaluation for {} test (expected {}, got {})".format( + mix_dist.__name__, analytic.item(), cost.item() + ), + ) + logger.debug( + "analytic_grads_logit: {}".format( + analytic_grads["component_logits"].detach().cpu().numpy() + ) + ) for param_name, param in params.items(): - assert_equal(param.grad, analytic_grads[param_name], prec=0.1, - msg='bad {} grad for {} (expected {}, got {})'.format( - param_name, mix_dist.__name__, analytic_grads[param_name], param.grad)) + assert_equal( + param.grad, + analytic_grads[param_name], + prec=0.1, + msg="bad {} grad for {} (expected {}, got {})".format( + param_name, mix_dist.__name__, analytic_grads[param_name], param.grad + ), + ) -@pytest.mark.parametrize('batch_size', [1, 3]) +@pytest.mark.parametrize("batch_size", [1, 3]) def test_mix_of_diag_normals_shared_cov_log_prob(batch_size): locs = torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) sigmas = torch.tensor([2.0, 2.0]) @@ -138,14 +181,18 @@ def test_mix_of_diag_normals_shared_cov_log_prob(batch_size): value = value.unsqueeze(0).expand(batch_size, 2) dist = MixtureOfDiagNormalsSharedCovariance(locs, sigmas, logits) log_prob = dist.log_prob(value) - correct_log_prob = 0.25 * math.exp(- 2.25 / 4.0) - correct_log_prob += 0.75 * math.exp(- 0.25 / 4.0) + correct_log_prob = 0.25 * math.exp(-2.25 / 4.0) + correct_log_prob += 0.75 * math.exp(-0.25 / 4.0) correct_log_prob /= 8.0 * math.pi correct_log_prob = math.log(correct_log_prob) if batch_size > 1: correct_log_prob = [correct_log_prob] * batch_size correct_log_prob = torch.tensor(correct_log_prob) - assert_equal(log_prob, correct_log_prob, 
msg='bad log prob for MixtureOfDiagNormalsSharedCovariance') + assert_equal( + log_prob, + correct_log_prob, + msg="bad log prob for MixtureOfDiagNormalsSharedCovariance", + ) def test_gsm_log_prob(): @@ -159,10 +206,12 @@ def test_gsm_log_prob(): correct_log_prob += 0.75 * math.exp(-0.50 / (4.0 * 6.25)) / 6.25 correct_log_prob /= (2.0 * math.pi) * 4.0 correct_log_prob = math.log(correct_log_prob) - assert_equal(log_prob, correct_log_prob, msg='bad log prob for GaussianScaleMixture') + assert_equal( + log_prob, correct_log_prob, msg="bad log prob for GaussianScaleMixture" + ) -@pytest.mark.parametrize('batch_size', [1, 3]) +@pytest.mark.parametrize("batch_size", [1, 3]) def test_mix_of_diag_normals_log_prob(batch_size): sigmas = torch.tensor([[2.0, 1.5], [1.5, 2.0]]) locs = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) @@ -177,9 +226,11 @@ def test_mix_of_diag_normals_log_prob(batch_size): log_prob = dist.log_prob(value) correct_log_prob = 0.25 * math.exp(-0.5 * (0.25 / 4.0 + 0.5625 / 2.25)) / 3.0 correct_log_prob += 0.75 * math.exp(-0.5 * (2.25 / 2.25 + 0.0625 / 4.0)) / 3.0 - correct_log_prob /= (2.0 * math.pi) + correct_log_prob /= 2.0 * math.pi correct_log_prob = math.log(correct_log_prob) if batch_size > 1: correct_log_prob = [correct_log_prob] * batch_size correct_log_prob = torch.tensor(correct_log_prob) - assert_equal(log_prob, correct_log_prob, msg='bad log prob for MixtureOfDiagNormals') + assert_equal( + log_prob, correct_log_prob, msg="bad log prob for MixtureOfDiagNormals" + ) diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py index 53857c2791..8d6b9c06bb 100644 --- a/tests/distributions/test_haar.py +++ b/tests/distributions/test_haar.py @@ -8,7 +8,7 @@ from tests.common import assert_equal -@pytest.mark.parametrize('size', [1, 3, 4, 7, 8, 9]) +@pytest.mark.parametrize("size", [1, 3, 4, 7, 8, 9]) def test_haar_ortho(size): haar = HaarTransform() eye = torch.eye(size) diff --git a/tests/distributions/test_hmm.py b/tests/distributions/test_hmm.py index 5b5118da23..ac0e67919d 100644 --- a/tests/distributions/test_hmm.py +++ b/tests/distributions/test_hmm.py @@ -55,9 +55,9 @@ def check_expand(old_dist, old_data): assert new_dist.log_prob(new_data).shape == new_batch_shape -@pytest.mark.parametrize('num_steps', list(range(1, 20))) -@pytest.mark.parametrize('state_dim', [2, 3]) -@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 4)], ids=str) +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps): logits = torch.randn(batch_shape + (num_steps, state_dim, state_dim)) actual = _sequential_logmatmulexp(logits) @@ -66,18 +66,27 @@ def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps): # Check against einsum. 
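# Illustrative aside, not part of the patch: the surrounding test_hmm.py hunk
# shows how black wraps an over-long expression by leaning on its enclosing
# parentheses, placing one operand per line (the einsum equation assembled in
# test_sequential_logmatmulexp). A reduced, runnable version of that pattern:

import opt_einsum

symbol = (opt_einsum.get_symbol(i) for i in range(1000))
state_symbols = [next(symbol) for _ in range(4)]
equation = (
    ",".join(state_symbols[t] + state_symbols[t + 1] for t in range(3))
    + "->"
    + state_symbols[0]
    + state_symbols[-1]
)
assert equation.count(",") == 2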
operands = list(logits.unbind(-3)) symbol = (opt_einsum.get_symbol(i) for i in range(1000)) - batch_symbols = ''.join(next(symbol) for _ in batch_shape) + batch_symbols = "".join(next(symbol) for _ in batch_shape) state_symbols = [next(symbol) for _ in range(num_steps + 1)] - equation = (','.join(batch_symbols + state_symbols[t] + state_symbols[t + 1] - for t in range(num_steps)) + - '->' + batch_symbols + state_symbols[0] + state_symbols[-1]) - expected = opt_einsum.contract(equation, *operands, backend='pyro.ops.einsum.torch_log') + equation = ( + ",".join( + batch_symbols + state_symbols[t] + state_symbols[t + 1] + for t in range(num_steps) + ) + + "->" + + batch_symbols + + state_symbols[0] + + state_symbols[-1] + ) + expected = opt_einsum.contract( + equation, *operands, backend="pyro.ops.einsum.torch_log" + ) assert_close(actual, expected) -@pytest.mark.parametrize('num_steps', list(range(1, 20))) -@pytest.mark.parametrize('state_dim', [1, 2, 3]) -@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 4)], ids=str) +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): g = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) actual = _sequential_gaussian_tensordot(g) @@ -91,20 +100,22 @@ def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): assert_close_gaussian(actual, expected) -@pytest.mark.parametrize('num_steps', list(range(1, 20))) -@pytest.mark.parametrize('state_dim', [1, 2, 3]) -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize('sample_shape', [(), (4,), (3, 2)], ids=str) -def test_sequential_gaussian_filter_sample(sample_shape, batch_shape, state_dim, num_steps): +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) +def test_sequential_gaussian_filter_sample( + sample_shape, batch_shape, state_dim, num_steps +): init = random_gaussian(batch_shape, state_dim) trans = random_gaussian(batch_shape + (num_steps,), state_dim + state_dim) sample = _sequential_gaussian_filter_sample(init, trans, sample_shape) assert sample.shape == sample_shape + batch_shape + (num_steps, state_dim) -@pytest.mark.parametrize('num_steps', list(range(1, 20))) -@pytest.mark.parametrize('state_dim', [1, 2, 3]) -@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 4)], ids=str) +@pytest.mark.parametrize("num_steps", list(range(1, 20))) +@pytest.mark.parametrize("state_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim, num_steps): g = random_gamma_gaussian(batch_shape + (num_steps,), state_dim + state_dim) actual = _sequential_gamma_gaussian_tensordot(g) @@ -118,31 +129,37 @@ def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim, num_steps): assert_close_gamma_gaussian(actual, expected) -@pytest.mark.parametrize('state_dim', [2, 3]) -@pytest.mark.parametrize('event_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('ok,init_shape,trans_shape,obs_shape', [ - (True, (), (), (1,)), - (True, (), (1,), (1,)), - (True, (), (), (7,)), - (True, (), (7,), (7,)), - (True, (), (1,), (7,)), - (True, (), (7,), (11, 7)), 
- (True, (), (11, 7), (7,)), - (True, (), (11, 7), (11, 7)), - (True, (11,), (7,), (7,)), - (True, (11,), (7,), (11, 7)), - (True, (11,), (11, 7), (7,)), - (True, (11,), (11, 7), (11, 7)), - (True, (4, 1, 1), (3, 1, 7), (2, 7)), - (False, (), (1,), ()), - (False, (), (7,), ()), - (False, (), (7,), (1,)), - (False, (), (7,), (6,)), - (False, (3,), (4, 7), (7,)), - (False, (3,), (7,), (4, 7)), - (False, (), (3, 7), (4, 7)), -], ids=str) -def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape, event_shape, state_dim): +@pytest.mark.parametrize("state_dim", [2, 3]) +@pytest.mark.parametrize("event_shape", [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize( + "ok,init_shape,trans_shape,obs_shape", + [ + (True, (), (), (1,)), + (True, (), (1,), (1,)), + (True, (), (), (7,)), + (True, (), (7,), (7,)), + (True, (), (1,), (7,)), + (True, (), (7,), (11, 7)), + (True, (), (11, 7), (7,)), + (True, (), (11, 7), (11, 7)), + (True, (11,), (7,), (7,)), + (True, (11,), (7,), (11, 7)), + (True, (11,), (11, 7), (7,)), + (True, (11,), (11, 7), (11, 7)), + (True, (4, 1, 1), (3, 1, 7), (2, 7)), + (False, (), (1,), ()), + (False, (), (7,), ()), + (False, (), (7,), (1,)), + (False, (), (7,), (6,)), + (False, (3,), (4, 7), (7,)), + (False, (3,), (7,), (4, 7)), + (False, (), (3, 7), (4, 7)), + ], + ids=str, +) +def test_discrete_hmm_shape( + ok, init_shape, trans_shape, obs_shape, event_shape, state_dim +): init_logits = torch.randn(init_shape + (state_dim,)) trans_logits = torch.randn(trans_shape + (state_dim, state_dim)) obs_logits = torch.randn(obs_shape + (state_dim,) + event_shape) @@ -170,23 +187,29 @@ def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape, event_shape, assert final.support.upper_bound == state_dim - 1 -@pytest.mark.parametrize('event_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('state_dim', [2, 3]) -@pytest.mark.parametrize('num_steps', [1, 2, 3]) -@pytest.mark.parametrize('init_shape,trans_shape,obs_shape', [ - ((), (), ()), - ((), (1,), ()), - ((), (), (1,)), - ((), (1,), (7, 1)), - ((), (7, 1), (1,)), - ((), (7, 1), (7, 1)), - ((7,), (1,), (1,)), - ((7,), (1,), (7, 1)), - ((7,), (7, 1), (1,)), - ((7,), (7, 1), (7, 1)), - ((4, 1, 1), (3, 1, 1), (2, 1)), -], ids=str) -def test_discrete_hmm_homogeneous_trick(init_shape, trans_shape, obs_shape, event_shape, state_dim, num_steps): +@pytest.mark.parametrize("event_shape", [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize("state_dim", [2, 3]) +@pytest.mark.parametrize("num_steps", [1, 2, 3]) +@pytest.mark.parametrize( + "init_shape,trans_shape,obs_shape", + [ + ((), (), ()), + ((), (1,), ()), + ((), (), (1,)), + ((), (1,), (7, 1)), + ((), (7, 1), (1,)), + ((), (7, 1), (7, 1)), + ((7,), (1,), (1,)), + ((7,), (1,), (7, 1)), + ((7,), (7, 1), (1,)), + ((7,), (7, 1), (7, 1)), + ((4, 1, 1), (3, 1, 1), (2, 1)), + ], + ids=str, +) +def test_discrete_hmm_homogeneous_trick( + init_shape, trans_shape, obs_shape, event_shape, state_dim, num_steps +): batch_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1]) init_logits = torch.randn(init_shape + (state_dim,)) trans_logits = torch.randn(trans_shape + (state_dim, state_dim)) @@ -208,7 +231,7 @@ def empty_guide(*args, **kwargs): pass -@pytest.mark.parametrize('num_steps', list(range(1, 10))) +@pytest.mark.parametrize("num_steps", list(range(1, 10))) def test_discrete_hmm_categorical(num_steps): state_dim = 3 obs_dim = 4 @@ -226,18 +249,22 @@ def test_discrete_hmm_categorical(num_steps): def model(data): x = pyro.sample("x_init", 
dist.Categorical(logits=init_logits)) for t in range(num_steps): - x = pyro.sample("x_{}".format(t), - dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :])) - pyro.sample("obs_{}".format(t), - dist.Categorical(logits=Vindex(obs_dist.logits)[..., t, x, :]), - obs=data[..., t]) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]), + ) + pyro.sample( + "obs_{}".format(t), + dist.Categorical(logits=Vindex(obs_dist.logits)[..., t, x, :]), + obs=data[..., t], + ) expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data) actual_loss = -float(actual.sum()) assert_close(actual_loss, expected_loss) -@pytest.mark.parametrize('num_steps', list(range(1, 10))) +@pytest.mark.parametrize("num_steps", list(range(1, 10))) def test_discrete_hmm_diag_normal(num_steps): state_dim = 3 event_size = 2 @@ -257,40 +284,57 @@ def test_discrete_hmm_diag_normal(num_steps): def model(data): x = pyro.sample("x_init", dist.Categorical(logits=init_logits)) for t in range(num_steps): - x = pyro.sample("x_{}".format(t), - dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :])) - pyro.sample("obs_{}".format(t), - dist.Normal(Vindex(loc)[..., t, x, :], - Vindex(scale)[..., t, x, :]).to_event(1), - obs=data[..., t, :]) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(logits=Vindex(trans_logits)[..., t, x, :]), + ) + pyro.sample( + "obs_{}".format(t), + dist.Normal( + Vindex(loc)[..., t, x, :], Vindex(scale)[..., t, x, :] + ).to_event(1), + obs=data[..., t, :], + ) expected_loss = TraceEnum_ELBO().loss(model, empty_guide, data) actual_loss = -float(actual.sum()) assert_close(actual_loss, expected_loss) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 3]) -@pytest.mark.parametrize('init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape', [ - ((), (), (), (), ()), - ((), (6,), (), (), ()), - ((), (), (6,), (), ()), - ((), (), (), (6,), ()), - ((), (), (), (), (6,)), - ((), (6,), (6,), (6,), (6,)), - ((5,), (6,), (), (), ()), - ((), (5, 1), (6,), (), ()), - ((), (), (5, 1), (6,), ()), - ((), (), (), (5, 1), (6,)), - ((), (6,), (5, 1), (), ()), - ((), (), (6,), (5, 1), ()), - ((), (), (), (6,), (5, 1)), - ((5,), (), (), (), (6,)), - ((5,), (5, 6), (5, 6), (5, 6), (5, 6)), -], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 3]) +@pytest.mark.parametrize( + "init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape", + [ + ((), (), (), (), ()), + ((), (6,), (), (), ()), + ((), (), (6,), (), ()), + ((), (), (), (6,), ()), + ((), (), (), (), (6,)), + ((), (6,), (6,), (6,), (6,)), + ((5,), (6,), (), (), ()), + ((), (5, 1), (6,), (), ()), + ((), (), (5, 1), (6,), ()), + ((), (), (), (5, 1), (6,)), + ((), (6,), (5, 1), (), ()), + ((), (), (6,), (5, 1), ()), + ((), (), (), (6,), (5, 1)), + ((5,), (), (), (), (6,)), + ((5,), (5, 6), (5, 6), (5, 6), (5, 6)), + ], + ids=str, +) @pytest.mark.parametrize("diag", [False, True], ids=["full", "diag"]) -def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape, - obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim): +def test_gaussian_hmm_shape( + diag, + init_shape, + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + hidden_dim, + obs_dim, +): init_dist = random_mvn(init_shape, hidden_dim) trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim)) trans_dist = random_mvn(trans_mvn_shape, hidden_dim) @@ -299,14 +343,17 @@ def test_gaussian_hmm_shape(diag, 
init_shape, trans_mat_shape, trans_mvn_shape, if diag: scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1) obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1) - d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, - duration=6) - - shape = broadcast_shape(init_shape + (6,), - trans_mat_shape, - trans_mvn_shape, - obs_mat_shape, - obs_mvn_shape) + d = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=6 + ) + + shape = broadcast_shape( + init_shape + (6,), + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + ) expected_batch_shape, time_shape = shape[:-1], shape[-1:] expected_event_shape = time_shape + (obs_dim,) assert d.batch_shape == expected_batch_shape @@ -364,19 +411,22 @@ def test_gaussian_hmm_high_obs_dim(): loc = torch.randn((duration, obs_dim)) scale = torch.randn((duration, obs_dim)).exp() obs_dist = dist.Normal(loc, scale).to_event(1) - d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, - duration=duration) + d = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) x = d.rsample(sample_shape) assert x.shape == sample_shape + (duration, obs_dim) -@pytest.mark.parametrize('sample_shape', [(), (5,)], ids=str) -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 2]) -@pytest.mark.parametrize('num_steps', [1, 2, 3, 4]) +@pytest.mark.parametrize("sample_shape", [(), (5,)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 2]) +@pytest.mark.parametrize("num_steps", [1, 2, 3, 4]) @pytest.mark.parametrize("diag", [False, True], ids=["full", "diag"]) -def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim): +def test_gaussian_hmm_distribution( + diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim +): init_dist = random_mvn(batch_shape, hidden_dim) trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim)) trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim) @@ -385,10 +435,13 @@ def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, h if diag: scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1) obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1) - d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=num_steps) + d = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=num_steps + ) if diag: - obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc, - scale_tril=obs_dist.base_dist.scale.diag_embed()) + obs_mvn = dist.MultivariateNormal( + obs_dist.base_dist.loc, scale_tril=obs_dist.base_dist.scale.diag_embed() + ) else: obs_mvn = obs_dist data = obs_dist.sample(sample_shape) @@ -411,21 +464,32 @@ def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, h like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2) like = mvn_to_gaussian(like_dist) - unrolled_trans = reduce(operator.add, [ - trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) - for t in range(T) - ]) - unrolled_obs = reduce(operator.add, [ - obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) - for t in range(T) - ]) - unrolled_like = reduce(operator.add, [ - like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * 
obs_dim) - for t in range(T) - ]) + unrolled_trans = reduce( + operator.add, + [ + trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) + for t in range(T) + ], + ) + unrolled_obs = reduce( + operator.add, + [ + obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) + for t in range(T) + ], + ) + unrolled_like = reduce( + operator.add, + [ + like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim) + for t in range(T) + ], + ) # Permute obs from HOHOHO to HHHOOO. - perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + - [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)]) + perm = torch.cat( + [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + + [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)] + ) unrolled_obs = unrolled_obs.event_permute(perm) unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,)) @@ -438,8 +502,10 @@ def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, h assert_close(actual_log_prob, expected_log_prob) d_posterior, log_normalizer = d.conjugate_update(like_dist) - assert_close(d.log_prob(data) + like_dist.log_prob(data), - d_posterior.log_prob(data) + log_normalizer) + assert_close( + d.log_prob(data) + like_dist.log_prob(data), + d_posterior.log_prob(data) + log_normalizer, + ) if batch_shape or sample_shape: return @@ -456,34 +522,42 @@ def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, h delta = samples - actual_mean actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0) actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt() - actual_corr = actual_cov / (actual_std.unsqueeze(-1) * actual_std.unsqueeze(-2)) + actual_corr = actual_cov / ( + actual_std.unsqueeze(-1) * actual_std.unsqueeze(-2) + ) expected_cov = torch.linalg.cholesky(g.precision).cholesky_inverse() expected_mean = expected_cov.matmul(g.info_vec.unsqueeze(-1)).squeeze(-1) expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt() - expected_corr = expected_cov / (expected_std.unsqueeze(-1) * expected_std.unsqueeze(-2)) + expected_corr = expected_cov / ( + expected_std.unsqueeze(-1) * expected_std.unsqueeze(-2) + ) assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02) assert_close(actual_std, expected_std, atol=0.05, rtol=0.02) assert_close(actual_corr, expected_corr, atol=0.02) -@pytest.mark.parametrize('obs_dim', [1, 2, 3]) -@pytest.mark.parametrize('hidden_dim', [1, 2, 3]) -@pytest.mark.parametrize('init_shape,trans_shape,obs_shape', [ - ((), (7,), ()), - ((), (), (7,)), - ((), (7,), (1,)), - ((), (1,), (7,)), - ((), (7,), (11, 7)), - ((), (11, 7), (7,)), - ((), (11, 7), (11, 7)), - ((11,), (7,), (7,)), - ((11,), (7,), (11, 7)), - ((11,), (11, 7), (7,)), - ((11,), (11, 7), (11, 7)), - ((4, 1, 1), (3, 1, 7), (2, 7)), -], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2, 3]) +@pytest.mark.parametrize("hidden_dim", [1, 2, 3]) +@pytest.mark.parametrize( + "init_shape,trans_shape,obs_shape", + [ + ((), (7,), ()), + ((), (), (7,)), + ((), (7,), (1,)), + ((), (1,), (7,)), + ((), (7,), (11, 7)), + ((), (11, 7), (7,)), + ((), (11, 7), (11, 7)), + ((11,), (7,), (7,)), + ((11,), (7,), (11, 7)), + ((11,), (11, 7), (7,)), + ((11,), (11, 7), (11, 7)), + ((4, 1, 1), (3, 1, 7), (2, 7)), + ], + ids=str, +) def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim): init_dist = random_mvn(init_shape, hidden_dim) trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim) @@ 
-504,12 +578,14 @@ def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_ check_expand(d, data) -@pytest.mark.parametrize('sample_shape', [(), (5,)], ids=str) -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 2]) -@pytest.mark.parametrize('num_steps', [1, 2, 3, 4]) -def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim): +@pytest.mark.parametrize("sample_shape", [(), (5,)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 2]) +@pytest.mark.parametrize("num_steps", [1, 2, 3, 4]) +def test_gaussian_mrf_log_prob( + sample_shape, batch_shape, num_steps, hidden_dim, obs_dim +): init_dist = random_mvn(batch_shape, hidden_dim) trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim) obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim) @@ -531,17 +607,25 @@ def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, trans = mvn_to_gaussian(trans_dist) obs = mvn_to_gaussian(obs_dist) - unrolled_trans = reduce(operator.add, [ - trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) - for t in range(T) - ]) - unrolled_obs = reduce(operator.add, [ - obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) - for t in range(T) - ]) + unrolled_trans = reduce( + operator.add, + [ + trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) + for t in range(T) + ], + ) + unrolled_obs = reduce( + operator.add, + [ + obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) + for t in range(T) + ], + ) # Permute obs from HOHOHO to HHHOOO. - perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + - [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)]) + perm = torch.cat( + [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + + [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)] + ) unrolled_obs = unrolled_obs.event_permute(perm) unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,)) @@ -555,12 +639,14 @@ def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, assert_close(actual_log_prob, expected_log_prob) -@pytest.mark.parametrize('sample_shape', [(), (5,)], ids=str) -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 2]) -@pytest.mark.parametrize('num_steps', [1, 2, 3, 4]) -def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim): +@pytest.mark.parametrize("sample_shape", [(), (5,)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 2]) +@pytest.mark.parametrize("num_steps", [1, 2, 3, 4]) +def test_gaussian_mrf_log_prob_block_diag( + sample_shape, batch_shape, num_steps, hidden_dim, obs_dim +): # Construct a block-diagonal obs dist, so observations are independent of hidden state. 
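# A minimal sketch of the block-diagonal idea the following hunk tests, assuming
# illustrative sizes and a randomly built SPD covariance (names here are only for
# illustration): with the hidden/obs cross blocks of the precision zeroed, the
# joint Gaussian factorizes, so the obs marginal is the MVN built from the
# lower-right precision block.
import torch
import pyro.distributions as dist

hidden_dim, obs_dim = 2, 3
d = hidden_dim + obs_dim
a = torch.randn(d, d)
joint = dist.MultivariateNormal(
    torch.zeros(d), covariance_matrix=a @ a.T + d * torch.eye(d)
)
precision = joint.precision_matrix.clone()
precision[..., :hidden_dim, hidden_dim:] = 0  # drop hidden-obs coupling
precision[..., hidden_dim:, :hidden_dim] = 0
blocked = dist.MultivariateNormal(joint.loc, precision_matrix=precision)
marginal_obs = dist.MultivariateNormal(
    joint.loc[..., hidden_dim:],
    precision_matrix=precision[..., hidden_dim:, hidden_dim:],
)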
obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim) precision = obs_dist.precision_matrix @@ -569,7 +655,8 @@ def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, obs_dist = dist.MultivariateNormal(obs_dist.loc, precision_matrix=precision) marginal_obs_dist = dist.MultivariateNormal( obs_dist.loc[..., hidden_dim:], - precision_matrix=precision[..., hidden_dim:, hidden_dim:]) + precision_matrix=precision[..., hidden_dim:, hidden_dim:], + ) init_dist = random_mvn(batch_shape, hidden_dim) trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim) @@ -581,41 +668,57 @@ def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, assert_close(actual_log_prob, expected_log_prob) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 3]) -@pytest.mark.parametrize('scale_shape,init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape', [ - ((5,), (), (6,), (), (), ()), - ((), (), (6,), (), (), ()), - ((), (), (), (6,), (), ()), - ((), (), (), (), (6,), ()), - ((), (), (), (), (), (6,)), - ((), (), (6,), (6,), (6,), (6,)), - ((), (5,), (6,), (), (), ()), - ((), (), (5, 1), (6,), (), ()), - ((), (), (), (5, 1), (6,), ()), - ((), (), (), (), (5, 1), (6,)), - ((), (), (6,), (5, 1), (), ()), - ((), (), (), (6,), (5, 1), ()), - ((), (), (), (), (6,), (5, 1)), - ((), (5,), (), (), (), (6,)), - ((5,), (5,), (5, 6), (5, 6), (5, 6), (5, 6)), -], ids=str) -def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape, trans_mvn_shape, - obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim): +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 3]) +@pytest.mark.parametrize( + "scale_shape,init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape", + [ + ((5,), (), (6,), (), (), ()), + ((), (), (6,), (), (), ()), + ((), (), (), (6,), (), ()), + ((), (), (), (), (6,), ()), + ((), (), (), (), (), (6,)), + ((), (), (6,), (6,), (6,), (6,)), + ((), (5,), (6,), (), (), ()), + ((), (), (5, 1), (6,), (), ()), + ((), (), (), (5, 1), (6,), ()), + ((), (), (), (), (5, 1), (6,)), + ((), (), (6,), (5, 1), (), ()), + ((), (), (), (6,), (5, 1), ()), + ((), (), (), (), (6,), (5, 1)), + ((), (5,), (), (), (), (6,)), + ((5,), (5,), (5, 6), (5, 6), (5, 6), (5, 6)), + ], + ids=str, +) +def test_gamma_gaussian_hmm_shape( + scale_shape, + init_shape, + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + hidden_dim, + obs_dim, +): init_dist = random_mvn(init_shape, hidden_dim) trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim)) trans_dist = random_mvn(trans_mvn_shape, hidden_dim) obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim)) obs_dist = random_mvn(obs_mvn_shape, obs_dim) scale_dist = random_gamma(scale_shape) - d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist, obs_mat, obs_dist) - - shape = broadcast_shape(scale_shape + (1,), - init_shape + (1,), - trans_mat_shape, - trans_mvn_shape, - obs_mat_shape, - obs_mvn_shape) + d = dist.GammaGaussianHMM( + scale_dist, init_dist, trans_mat, trans_dist, obs_mat, obs_dist + ) + + shape = broadcast_shape( + scale_shape + (1,), + init_shape + (1,), + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + ) expected_batch_shape, time_shape = shape[:-1], shape[-1:] expected_event_shape = time_shape + (obs_dim,) assert d.batch_shape == expected_batch_shape @@ -637,19 +740,23 @@ def test_gamma_gaussian_hmm_shape(scale_shape, 
init_shape, trans_mat_shape, tran assert final.event_shape == (hidden_dim,) -@pytest.mark.parametrize('sample_shape', [(), (5,)], ids=str) -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 2]) -@pytest.mark.parametrize('num_steps', [1, 2, 3, 4]) -def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim): +@pytest.mark.parametrize("sample_shape", [(), (5,)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 2]) +@pytest.mark.parametrize("num_steps", [1, 2, 3, 4]) +def test_gamma_gaussian_hmm_log_prob( + sample_shape, batch_shape, num_steps, hidden_dim, obs_dim +): init_dist = random_mvn(batch_shape, hidden_dim) trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim)) trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim) obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim)) obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim) scale_dist = random_gamma(batch_shape) - d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist, obs_mat, obs_dist) + d = dist.GammaGaussianHMM( + scale_dist, init_dist, trans_mat, trans_dist, obs_mat, obs_dist + ) obs_mvn = obs_dist data = obs_dist.sample(sample_shape) assert data.shape == sample_shape + d.shape() @@ -668,17 +775,25 @@ def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps, hidde trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist) obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_mvn) - unrolled_trans = reduce(operator.add, [ - trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) - for t in range(T) - ]) - unrolled_obs = reduce(operator.add, [ - obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) - for t in range(T) - ]) + unrolled_trans = reduce( + operator.add, + [ + trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim) + for t in range(T) + ], + ) + unrolled_obs = reduce( + operator.add, + [ + obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim()) + for t in range(T) + ], + ) # Permute obs from HOHOHO to HHHOOO. 
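# A tiny worked example of the HOHOHO -> HHHOOO index permutation built below,
# assuming each time step contributes hidden_dim hidden coordinates followed by
# obs_dim observed coordinates, so the per-step stride stands in for obs.dim():
import torch

hidden_dim, obs_dim, T = 2, 1, 3
stride = hidden_dim + obs_dim
perm = torch.cat(
    [torch.arange(hidden_dim) + t * stride for t in range(T)]
    + [torch.arange(obs_dim) + hidden_dim + t * stride for t in range(T)]
)
# Interleaved order h0 h1 o0 | h0 h1 o0 | h0 h1 o0 regroups to all hidden, then all obs:
# perm == tensor([0, 1, 3, 4, 6, 7, 2, 5, 8])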
- perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + - [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)]) + perm = torch.cat( + [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] + + [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)] + ) unrolled_obs = unrolled_obs.event_permute(perm) unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,)) @@ -699,41 +814,53 @@ def random_stable(stability, skew_scale_loc_shape): return dist.Stable(stability, skew, scale, loc) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 3]) -@pytest.mark.parametrize('init_shape,trans_mat_shape,trans_dist_shape,obs_mat_shape,obs_dist_shape', [ - ((), (), (), (), ()), - ((), (4,), (), (), ()), - ((), (), (4,), (), ()), - ((), (), (), (4,), ()), - ((), (), (), (), (4,)), - ((), (4,), (4,), (4,), (4,)), - ((5,), (4,), (), (), ()), - ((), (5, 1), (4,), (), ()), - ((), (), (5, 1), (4,), ()), - ((), (), (), (5, 1), (4,)), - ((), (4,), (5, 1), (), ()), - ((), (), (4,), (5, 1), ()), - ((), (), (), (4,), (5, 1)), - ((5,), (), (), (), (4,)), - ((5,), (5, 4), (5, 4), (5, 4), (5, 4)), -], ids=str) -def test_stable_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape, - obs_mat_shape, obs_dist_shape, hidden_dim, obs_dim): +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 3]) +@pytest.mark.parametrize( + "init_shape,trans_mat_shape,trans_dist_shape,obs_mat_shape,obs_dist_shape", + [ + ((), (), (), (), ()), + ((), (4,), (), (), ()), + ((), (), (4,), (), ()), + ((), (), (), (4,), ()), + ((), (), (), (), (4,)), + ((), (4,), (4,), (4,), (4,)), + ((5,), (4,), (), (), ()), + ((), (5, 1), (4,), (), ()), + ((), (), (5, 1), (4,), ()), + ((), (), (), (5, 1), (4,)), + ((), (4,), (5, 1), (), ()), + ((), (), (4,), (5, 1), ()), + ((), (), (), (4,), (5, 1)), + ((5,), (), (), (), (4,)), + ((5,), (5, 4), (5, 4), (5, 4), (5, 4)), + ], + ids=str, +) +def test_stable_hmm_shape( + init_shape, + trans_mat_shape, + trans_dist_shape, + obs_mat_shape, + obs_dist_shape, + hidden_dim, + obs_dim, +): stability = dist.Uniform(0, 2).sample() init_dist = random_stable(stability, init_shape + (hidden_dim,)).to_event(1) trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim)) trans_dist = random_stable(stability, trans_dist_shape + (hidden_dim,)).to_event(1) obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim)) obs_dist = random_stable(stability, obs_dist_shape + (obs_dim,)).to_event(1) - d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, - duration=4) - - shape = broadcast_shape(init_shape + (4,), - trans_mat_shape, - trans_dist_shape, - obs_mat_shape, - obs_dist_shape) + d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=4) + + shape = broadcast_shape( + init_shape + (4,), + trans_mat_shape, + trans_dist_shape, + obs_mat_shape, + obs_dist_shape, + ) expected_batch_shape, time_shape = shape[:-1], shape[-1:] expected_event_shape = time_shape + (obs_dim,) assert d.batch_shape == expected_batch_shape @@ -755,26 +882,37 @@ def random_studentt(shape): return dist.StudentT(df, loc, scale) -@pytest.mark.parametrize('obs_dim', [1, 2]) -@pytest.mark.parametrize('hidden_dim', [1, 3]) -@pytest.mark.parametrize('init_shape,trans_mat_shape,trans_dist_shape,obs_mat_shape,obs_dist_shape', [ - ((), (4,), (), (), ()), - ((), (), (4,), (), ()), - ((), (), (), (4,), ()), - ((), (), (), (), (4,)), - ((), (4,), (4,), (4,), (4,)), - ((5,), (4,), (), (), 
()), - ((), (5, 1), (4,), (), ()), - ((), (), (5, 1), (4,), ()), - ((), (), (), (5, 1), (4,)), - ((), (4,), (5, 1), (), ()), - ((), (), (4,), (5, 1), ()), - ((), (), (), (4,), (5, 1)), - ((5,), (), (), (), (4,)), - ((5,), (5, 4), (5, 4), (5, 4), (5, 4)), -], ids=str) -def test_studentt_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape, - obs_mat_shape, obs_dist_shape, hidden_dim, obs_dim): +@pytest.mark.parametrize("obs_dim", [1, 2]) +@pytest.mark.parametrize("hidden_dim", [1, 3]) +@pytest.mark.parametrize( + "init_shape,trans_mat_shape,trans_dist_shape,obs_mat_shape,obs_dist_shape", + [ + ((), (4,), (), (), ()), + ((), (), (4,), (), ()), + ((), (), (), (4,), ()), + ((), (), (), (), (4,)), + ((), (4,), (4,), (4,), (4,)), + ((5,), (4,), (), (), ()), + ((), (5, 1), (4,), (), ()), + ((), (), (5, 1), (4,), ()), + ((), (), (), (5, 1), (4,)), + ((), (4,), (5, 1), (), ()), + ((), (), (4,), (5, 1), ()), + ((), (), (), (4,), (5, 1)), + ((5,), (), (), (), (4,)), + ((5,), (5, 4), (5, 4), (5, 4), (5, 4)), + ], + ids=str, +) +def test_studentt_hmm_shape( + init_shape, + trans_mat_shape, + trans_dist_shape, + obs_mat_shape, + obs_dist_shape, + hidden_dim, + obs_dim, +): init_dist = random_studentt(init_shape + (hidden_dim,)).to_event(1) trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim)) trans_dist = random_studentt(trans_dist_shape + (hidden_dim,)).to_event(1) @@ -782,11 +920,13 @@ def test_studentt_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape, obs_dist = random_studentt(obs_dist_shape + (obs_dim,)).to_event(1) d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist) - shape = broadcast_shape(init_shape + (1,), - trans_mat_shape, - trans_dist_shape, - obs_mat_shape, - obs_dist_shape) + shape = broadcast_shape( + init_shape + (1,), + trans_mat_shape, + trans_dist_shape, + obs_mat_shape, + obs_dist_shape, + ) expected_batch_shape, time_shape = shape[:-1], shape[-1:] expected_event_shape = time_shape + (obs_dim,) assert d.batch_shape == expected_batch_shape @@ -801,46 +941,73 @@ def test_studentt_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape, assert x.shape == (6, 5) + d.event_shape -@pytest.mark.parametrize('obs_dim', [1, 3]) -@pytest.mark.parametrize('hidden_dim', [1, 2]) -@pytest.mark.parametrize('init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape', [ - ((), (), (), (), ()), - ((), (6,), (), (), ()), - ((), (), (6,), (), ()), - ((), (), (), (6,), ()), - ((), (), (), (), (6,)), - ((), (6,), (6,), (6,), (6,)), - ((5,), (6,), (), (), ()), - ((), (5, 1), (6,), (), ()), - ((), (), (5, 1), (6,), ()), - ((), (), (), (5, 1), (6,)), - ((), (6,), (5, 1), (), ()), - ((), (), (6,), (5, 1), ()), - ((), (), (), (6,), (5, 1)), - ((5,), (), (), (), (6,)), - ((5,), (5, 6), (5, 6), (5, 6), (5, 6)), -], ids=str) -def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape, - obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim): +@pytest.mark.parametrize("obs_dim", [1, 3]) +@pytest.mark.parametrize("hidden_dim", [1, 2]) +@pytest.mark.parametrize( + "init_shape,trans_mat_shape,trans_mvn_shape,obs_mat_shape,obs_mvn_shape", + [ + ((), (), (), (), ()), + ((), (6,), (), (), ()), + ((), (), (6,), (), ()), + ((), (), (), (6,), ()), + ((), (), (), (), (6,)), + ((), (6,), (6,), (6,), (6,)), + ((5,), (6,), (), (), ()), + ((), (5, 1), (6,), (), ()), + ((), (), (5, 1), (6,), ()), + ((), (), (), (5, 1), (6,)), + ((), (6,), (5, 1), (), ()), + ((), (), (6,), (5, 1), ()), + ((), (), (), (6,), (5, 1)), + ((5,), (), (), (), (6,)), + ((5,), (5, 6), 
(5, 6), (5, 6), (5, 6)), + ], + ids=str, +) +def test_independent_hmm_shape( + init_shape, + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + hidden_dim, + obs_dim, +): base_init_shape = init_shape + (obs_dim,) - base_trans_mat_shape = trans_mat_shape[:-1] + (obs_dim, trans_mat_shape[-1] if trans_mat_shape else 6) - base_trans_mvn_shape = trans_mvn_shape[:-1] + (obs_dim, trans_mvn_shape[-1] if trans_mvn_shape else 6) - base_obs_mat_shape = obs_mat_shape[:-1] + (obs_dim, obs_mat_shape[-1] if obs_mat_shape else 6) - base_obs_mvn_shape = obs_mvn_shape[:-1] + (obs_dim, obs_mvn_shape[-1] if obs_mvn_shape else 6) + base_trans_mat_shape = trans_mat_shape[:-1] + ( + obs_dim, + trans_mat_shape[-1] if trans_mat_shape else 6, + ) + base_trans_mvn_shape = trans_mvn_shape[:-1] + ( + obs_dim, + trans_mvn_shape[-1] if trans_mvn_shape else 6, + ) + base_obs_mat_shape = obs_mat_shape[:-1] + ( + obs_dim, + obs_mat_shape[-1] if obs_mat_shape else 6, + ) + base_obs_mvn_shape = obs_mvn_shape[:-1] + ( + obs_dim, + obs_mvn_shape[-1] if obs_mvn_shape else 6, + ) init_dist = random_mvn(base_init_shape, hidden_dim) trans_mat = torch.randn(base_trans_mat_shape + (hidden_dim, hidden_dim)) trans_dist = random_mvn(base_trans_mvn_shape, hidden_dim) obs_mat = torch.randn(base_obs_mat_shape + (hidden_dim, 1)) obs_dist = random_mvn(base_obs_mvn_shape, 1) - d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=6) + d = dist.GaussianHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=6 + ) d = dist.IndependentHMM(d) - shape = broadcast_shape(init_shape + (6,), - trans_mat_shape, - trans_mvn_shape, - obs_mat_shape, - obs_mvn_shape) + shape = broadcast_shape( + init_shape + (6,), + trans_mat_shape, + trans_mvn_shape, + obs_mat_shape, + obs_mvn_shape, + ) expected_batch_shape, time_shape = shape[:-1], shape[-1:] expected_event_shape = time_shape + (obs_dim,) assert d.batch_shape == expected_batch_shape diff --git a/tests/distributions/test_ig.py b/tests/distributions/test_ig.py index 5091e02ad7..c1b2d8de95 100644 --- a/tests/distributions/test_ig.py +++ b/tests/distributions/test_ig.py @@ -10,8 +10,8 @@ from tests.common import assert_equal -@pytest.mark.parametrize('concentration', [3.3, 4.0]) -@pytest.mark.parametrize('rate', [2.5, 3.0]) +@pytest.mark.parametrize("concentration", [3.3, 4.0]) +@pytest.mark.parametrize("rate", [2.5, 3.0]) def test_sample(concentration, rate, n_samples=int(1e6)): samples = InverseGamma(concentration, rate).sample((n_samples,)) mean, std = samples.mean().item(), samples.std().item() @@ -21,11 +21,13 @@ def test_sample(concentration, rate, n_samples=int(1e6)): assert_equal(std, expected_std, prec=0.03) -@pytest.mark.parametrize('concentration', [2.5, 4.0]) -@pytest.mark.parametrize('rate', [2.5, 3.0]) -@pytest.mark.parametrize('value', [0.5, 1.7]) +@pytest.mark.parametrize("concentration", [2.5, 4.0]) +@pytest.mark.parametrize("rate", [2.5, 3.0]) +@pytest.mark.parametrize("value", [0.5, 1.7]) def test_log_prob(concentration, rate, value): value = torch.tensor(value) log_prob = InverseGamma(concentration, rate).log_prob(value) - expected_log_prob = Gamma(concentration, rate).log_prob(1.0 / value) - 2.0 * value.log() + expected_log_prob = ( + Gamma(concentration, rate).log_prob(1.0 / value) - 2.0 * value.log() + ) assert_equal(log_prob, expected_log_prob, prec=1e-6) diff --git a/tests/distributions/test_improper_uniform.py b/tests/distributions/test_improper_uniform.py index 744c3377e1..64eba1bf97 100644 --- 
a/tests/distributions/test_improper_uniform.py +++ b/tests/distributions/test_improper_uniform.py @@ -9,11 +9,15 @@ from tests.common import assert_equal -@pytest.mark.parametrize("constraint", [ - constraints.real, - constraints.positive, - constraints.unit_interval, -], ids=str) +@pytest.mark.parametrize( + "constraint", + [ + constraints.real, + constraints.positive, + constraints.unit_interval, + ], + ids=str, +) @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) @pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)], ids=str) def test_improper_uniform(constraint, batch_shape, event_shape): diff --git a/tests/distributions/test_independent.py b/tests/distributions/test_independent.py index 76d4ca58f3..af863a1e64 100644 --- a/tests/distributions/test_independent.py +++ b/tests/distributions/test_independent.py @@ -10,13 +10,18 @@ from tests.common import assert_equal -@pytest.mark.parametrize('sample_shape', [(), (6,), (4, 2)]) -@pytest.mark.parametrize('batch_shape', [(), (7,), (5, 3), (5, 3, 2)]) -@pytest.mark.parametrize('reinterpreted_batch_ndims', [0, 1, 2, 3]) -@pytest.mark.parametrize('base_dist', - [dist.Normal(1., 2.), dist.Exponential(2.), - dist.MultivariateNormal(torch.zeros(2), torch.eye(2))], - ids=['normal', 'exponential', 'mvn']) +@pytest.mark.parametrize("sample_shape", [(), (6,), (4, 2)]) +@pytest.mark.parametrize("batch_shape", [(), (7,), (5, 3), (5, 3, 2)]) +@pytest.mark.parametrize("reinterpreted_batch_ndims", [0, 1, 2, 3]) +@pytest.mark.parametrize( + "base_dist", + [ + dist.Normal(1.0, 2.0), + dist.Exponential(2.0), + dist.MultivariateNormal(torch.zeros(2), torch.eye(2)), + ], + ids=["normal", "exponential", "mvn"], +) def test_independent(base_dist, sample_shape, batch_shape, reinterpreted_batch_ndims): if batch_shape: base_dist = base_dist.expand_by(batch_shape) @@ -25,8 +30,14 @@ def test_independent(base_dist, sample_shape, batch_shape, reinterpreted_batch_n d = dist.Independent(base_dist, reinterpreted_batch_ndims) else: d = dist.Independent(base_dist, reinterpreted_batch_ndims) - assert d.batch_shape == batch_shape[:len(batch_shape) - reinterpreted_batch_ndims] - assert d.event_shape == batch_shape[len(batch_shape) - reinterpreted_batch_ndims:] + base_dist.event_shape + assert ( + d.batch_shape == batch_shape[: len(batch_shape) - reinterpreted_batch_ndims] + ) + assert ( + d.event_shape + == batch_shape[len(batch_shape) - reinterpreted_batch_ndims :] + + base_dist.event_shape + ) assert d.sample().shape == batch_shape + base_dist.event_shape assert d.mean.shape == batch_shape + base_dist.event_shape @@ -35,16 +46,25 @@ def test_independent(base_dist, sample_shape, batch_shape, reinterpreted_batch_n assert x.shape == sample_shape + d.batch_shape + d.event_shape log_prob = d.log_prob(x) - assert log_prob.shape == sample_shape + batch_shape[:len(batch_shape) - reinterpreted_batch_ndims] + assert ( + log_prob.shape + == sample_shape + + batch_shape[: len(batch_shape) - reinterpreted_batch_ndims] + ) assert not torch_isnan(log_prob) log_prob_0 = base_dist.log_prob(x) assert_equal(log_prob, _sum_rightmost(log_prob_0, reinterpreted_batch_ndims)) -@pytest.mark.parametrize('base_dist', - [dist.Normal(1., 2.), dist.Exponential(2.), - dist.MultivariateNormal(torch.zeros(2), torch.eye(2))], - ids=['normal', 'exponential', 'mvn']) +@pytest.mark.parametrize( + "base_dist", + [ + dist.Normal(1.0, 2.0), + dist.Exponential(2.0), + dist.MultivariateNormal(torch.zeros(2), torch.eye(2)), + ], + ids=["normal", "exponential", "mvn"], +) def 
test_to_event(base_dist): base_dist = base_dist.expand([2, 3]) d = base_dist @@ -85,9 +105,9 @@ def test_to_event(base_dist): assert d is base_dist -@pytest.mark.parametrize('event_shape', [(), (2,), (2, 3)]) -@pytest.mark.parametrize('batch_shape', [(), (3,), (5, 3)]) -@pytest.mark.parametrize('sample_shape', [(), (2,), (4, 2)]) +@pytest.mark.parametrize("event_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("batch_shape", [(), (3,), (5, 3)]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (4, 2)]) def test_expand(sample_shape, batch_shape, event_shape): ones_shape = torch.Size((1,) * len(batch_shape)) zero = torch.zeros(ones_shape + event_shape) @@ -98,20 +118,34 @@ def test_expand(sample_shape, batch_shape, event_shape): assert d0.variance.shape == ones_shape + event_shape assert d0.sample(sample_shape).shape == sample_shape + ones_shape + event_shape - assert d0.expand(sample_shape + batch_shape).batch_shape == sample_shape + batch_shape - assert d0.expand(sample_shape + batch_shape).sample().shape == sample_shape + batch_shape + event_shape - assert d0.expand(sample_shape + batch_shape).mean.shape == sample_shape + batch_shape + event_shape - assert d0.expand(sample_shape + batch_shape).variance.shape == sample_shape + batch_shape + event_shape - - base_dist = dist.MultivariateNormal(torch.zeros(2).expand(*(event_shape + (2,))), - torch.eye(2).expand(*(event_shape + (2, 2)))) + assert ( + d0.expand(sample_shape + batch_shape).batch_shape == sample_shape + batch_shape + ) + assert ( + d0.expand(sample_shape + batch_shape).sample().shape + == sample_shape + batch_shape + event_shape + ) + assert ( + d0.expand(sample_shape + batch_shape).mean.shape + == sample_shape + batch_shape + event_shape + ) + assert ( + d0.expand(sample_shape + batch_shape).variance.shape + == sample_shape + batch_shape + event_shape + ) + + base_dist = dist.MultivariateNormal( + torch.zeros(2).expand(*(event_shape + (2,))), + torch.eye(2).expand(*(event_shape + (2, 2))), + ) if len(event_shape) > len(base_dist.batch_shape): with pytest.raises(ValueError): base_dist.to_event(len(event_shape)).expand(batch_shape) else: expanded = base_dist.to_event(len(event_shape)).expand(batch_shape) - expanded_batch_ndims = getattr(expanded, 'reinterpreted_batch_ndims', 0) + expanded_batch_ndims = getattr(expanded, "reinterpreted_batch_ndims", 0) assert expanded.batch_shape == batch_shape - assert expanded.event_shape == (base_dist.batch_shape[len(base_dist.batch_shape) - - expanded_batch_ndims:] + - base_dist.event_shape) + assert expanded.event_shape == ( + base_dist.batch_shape[len(base_dist.batch_shape) - expanded_batch_ndims :] + + base_dist.event_shape + ) diff --git a/tests/distributions/test_kl.py b/tests/distributions/test_kl.py index 9159052c6f..7bcfe729f4 100644 --- a/tests/distributions/test_kl.py +++ b/tests/distributions/test_kl.py @@ -10,7 +10,7 @@ from tests.common import assert_close -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) def test_kl_delta_normal_shape(batch_shape): v = torch.randn(batch_shape) loc = torch.randn(batch_shape) @@ -20,8 +20,8 @@ def test_kl_delta_normal_shape(batch_shape): assert kl_divergence(p, q).shape == batch_shape -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) -@pytest.mark.parametrize('size', [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("size", [1, 2, 3]) def test_kl_delta_mvn_shape(batch_shape, 
size): v = torch.randn(batch_shape + (size,)) p = dist.Delta(v, event_dim=1) @@ -33,20 +33,21 @@ def test_kl_delta_mvn_shape(batch_shape, size): assert kl_divergence(p, q).shape == batch_shape -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) -@pytest.mark.parametrize('event_shape', [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("event_shape", [(), (4,), (2, 3)], ids=str) def test_kl_independent_normal(batch_shape, event_shape): shape = batch_shape + event_shape p = dist.Normal(torch.randn(shape), torch.randn(shape).exp()) q = dist.Normal(torch.randn(shape), torch.randn(shape).exp()) - actual = kl_divergence(dist.Independent(p, len(event_shape)), - dist.Independent(q, len(event_shape))) + actual = kl_divergence( + dist.Independent(p, len(event_shape)), dist.Independent(q, len(event_shape)) + ) expected = sum_rightmost(kl_divergence(p, q), len(event_shape)) assert_close(actual, expected) -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) -@pytest.mark.parametrize('size', [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("size", [1, 2, 3]) def test_kl_independent_delta_mvn_shape(batch_shape, size): v = torch.randn(batch_shape + (size,)) p = dist.Independent(dist.Delta(v), 1) @@ -58,8 +59,8 @@ def test_kl_independent_delta_mvn_shape(batch_shape, size): assert kl_divergence(p, q).shape == batch_shape -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) -@pytest.mark.parametrize('size', [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("size", [1, 2, 3]) def test_kl_independent_normal_mvn(batch_shape, size): loc = torch.randn(batch_shape + (size,)) scale = torch.randn(batch_shape + (size,)).exp() @@ -76,9 +77,11 @@ def test_kl_independent_normal_mvn(batch_shape, size): assert_close(actual, expected) -@pytest.mark.parametrize('shape', [(5,), (4, 5), (2, 3, 5)], ids=str) -@pytest.mark.parametrize('event_dim', [0, 1]) -@pytest.mark.parametrize('transform', [transforms.ExpTransform(), transforms.StickBreakingTransform()]) +@pytest.mark.parametrize("shape", [(5,), (4, 5), (2, 3, 5)], ids=str) +@pytest.mark.parametrize("event_dim", [0, 1]) +@pytest.mark.parametrize( + "transform", [transforms.ExpTransform(), transforms.StickBreakingTransform()] +) def test_kl_transformed_transformed(shape, event_dim, transform): p_base = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(event_dim) q_base = dist.Normal(torch.ones(shape) * 2, torch.ones(shape)).to_event(event_dim) diff --git a/tests/distributions/test_lkj.py b/tests/distributions/test_lkj.py index a5afdb5869..92216fb47c 100644 --- a/tests/distributions/test_lkj.py +++ b/tests/distributions/test_lkj.py @@ -29,8 +29,12 @@ def test_constraint(value_shape): def _autograd_log_det(ys, x): # computes log_abs_det_jacobian of y w.r.t. 
x - return torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0] - for y in ys]).det().abs().log() + return ( + torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0] for y in ys]) + .det() + .abs() + .log() + ) @pytest.mark.parametrize("y_shape", [(1,), (3, 1), (6,), (1, 6), (2, 6)]) @@ -89,7 +93,7 @@ def test_corr_cholesky_transform(x_shape, mapping): @pytest.mark.parametrize("dim", [2, 3, 4, 10]) def test_log_prob_conc1(dim): - dist = LKJCholesky(dim, torch.tensor([1.])) + dist = LKJCholesky(dim, torch.tensor([1.0])) a_sample = dist.sample(torch.Size([100])) lp = dist.log_prob(a_sample) @@ -97,17 +101,30 @@ def test_log_prob_conc1(dim): if dim == 2: assert_equal(lp, lp.new_full(lp.size(), -math.log(2))) else: - ladj = a_sample.diagonal(dim1=-2, dim2=-1).log().mul( - torch.linspace(start=dim-1, end=0, steps=dim, device=a_sample.device, dtype=a_sample.dtype) - ).sum(-1) + ladj = ( + a_sample.diagonal(dim1=-2, dim2=-1) + .log() + .mul( + torch.linspace( + start=dim - 1, + end=0, + steps=dim, + device=a_sample.device, + dtype=a_sample.dtype, + ) + ) + .sum(-1) + ) lps_less_ladj = lp - ladj assert (lps_less_ladj - lps_less_ladj.min()).abs().sum() < 1e-4 -@pytest.mark.parametrize("concentration", [.1, .5, 1., 2., 5.]) +@pytest.mark.parametrize("concentration", [0.1, 0.5, 1.0, 2.0, 5.0]) def test_log_prob_d2(concentration): dist = LKJCholesky(2, torch.tensor([concentration])) - test_dist = TransformedDistribution(Beta(concentration, concentration), AffineTransform(loc=-1., scale=2.0)) + test_dist = TransformedDistribution( + Beta(concentration, concentration), AffineTransform(loc=-1.0, scale=2.0) + ) samples = dist.sample(torch.Size([100])) lp = dist.log_prob(samples) diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py index 27cfdc4910..770acda06d 100644 --- a/tests/distributions/test_mask.py +++ b/tests/distributions/test_mask.py @@ -12,20 +12,21 @@ def checker_mask(shape): - mask = tensor(0.) + mask = tensor(0.0) for size in shape: mask = mask.unsqueeze(-1) + torch.arange(float(size)) return mask.fmod(2).bool() -@pytest.mark.parametrize('batch_dim,mask_dim', - [(b, m) for b in range(3) for m in range(1 + b)]) -@pytest.mark.parametrize('event_dim', [0, 1, 2]) +@pytest.mark.parametrize( + "batch_dim,mask_dim", [(b, m) for b in range(3) for m in range(1 + b)] +) +@pytest.mark.parametrize("event_dim", [0, 1, 2]) def test_mask(batch_dim, event_dim, mask_dim): # Construct base distribution. - shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim]) + shape = torch.Size([2, 3, 4, 5, 6][: batch_dim + event_dim]) batch_shape = shape[:batch_dim] - mask_shape = batch_shape[batch_dim - mask_dim:] + mask_shape = batch_shape[batch_dim - mask_dim :] base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim) # Construct masked distribution. @@ -42,14 +43,24 @@ def test_mask(batch_dim, event_dim, mask_dim): # Check values. 
assert_equal(dist.mean, base_dist.mean) assert_equal(dist.variance, base_dist.variance) - assert_equal(dist.log_prob(sample), - scale_and_mask(base_dist.log_prob(sample), mask=mask)) - assert_equal(dist.score_parts(sample), - base_dist.score_parts(sample).scale_and_mask(mask=mask), prec=0) + assert_equal( + dist.log_prob(sample), scale_and_mask(base_dist.log_prob(sample), mask=mask) + ) + assert_equal( + dist.score_parts(sample), + base_dist.score_parts(sample).scale_and_mask(mask=mask), + prec=0, + ) if not dist.event_shape: assert_equal(dist.enumerate_support(), base_dist.enumerate_support()) - assert_equal(dist.enumerate_support(expand=True), base_dist.enumerate_support(expand=True)) - assert_equal(dist.enumerate_support(expand=False), base_dist.enumerate_support(expand=False)) + assert_equal( + dist.enumerate_support(expand=True), + base_dist.enumerate_support(expand=True), + ) + assert_equal( + dist.enumerate_support(expand=False), + base_dist.enumerate_support(expand=False), + ) @pytest.mark.parametrize("mask", [False, True, torch.tensor(False), torch.tensor(True)]) @@ -77,7 +88,7 @@ def test_mask_type(mask): @pytest.mark.parametrize("mask_shape", [(), (3,), (2, 1), (2, 3)]) def test_broadcast(event_shape, dist_shape, mask_shape): mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool() - base_dist = Normal(torch.zeros(dist_shape + event_shape), 1.) + base_dist = Normal(torch.zeros(dist_shape + event_shape), 1.0) base_dist = base_dist.to_event(len(event_shape)) assert base_dist.batch_shape == dist_shape assert base_dist.event_shape == event_shape @@ -93,27 +104,32 @@ def test_kl_divergence(): p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp()) q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp()) expected = kl_divergence(p.to_event(2), q.to_event(2)) - actual = (kl_divergence(p.mask(mask).to_event(2), - q.mask(mask).to_event(2)) + - kl_divergence(p.mask(~mask).to_event(2), - q.mask(~mask).to_event(2))) + actual = kl_divergence( + p.mask(mask).to_event(2), q.mask(mask).to_event(2) + ) + kl_divergence(p.mask(~mask).to_event(2), q.mask(~mask).to_event(2)) assert_equal(actual, expected) -@pytest.mark.parametrize("p_mask", [False, True, torch.tensor(False), torch.tensor(True)]) -@pytest.mark.parametrize("q_mask", [False, True, torch.tensor(False), torch.tensor(True)]) +@pytest.mark.parametrize( + "p_mask", [False, True, torch.tensor(False), torch.tensor(True)] +) +@pytest.mark.parametrize( + "q_mask", [False, True, torch.tensor(False), torch.tensor(True)] +) def test_kl_divergence_type(p_mask, q_mask): p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp()) q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp()) - mask = ((torch.tensor(p_mask) if isinstance(p_mask, bool) else p_mask) & - (torch.tensor(q_mask) if isinstance(q_mask, bool) else q_mask)).expand(2, 2) + mask = ( + (torch.tensor(p_mask) if isinstance(p_mask, bool) else p_mask) + & (torch.tensor(q_mask) if isinstance(q_mask, bool) else q_mask) + ).expand(2, 2) expected = kl_divergence(p, q) expected[~mask] = 0 actual = kl_divergence(p.mask(p_mask), q.mask(q_mask)) if p_mask is False or q_mask is False: - assert isinstance(actual, float) and actual == 0. 
+ assert isinstance(actual, float) and actual == 0.0 else: assert_equal(actual, expected) diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index 441025acc8..438cbf055b 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -10,14 +10,18 @@ from tests.common import assert_equal -@pytest.mark.parametrize('sample_shape', [(), (6,), (4, 2)]) -@pytest.mark.parametrize('batch_shape', [(), (7,), (5, 3)]) -@pytest.mark.parametrize('component1', - [dist.Normal(1., 2.), dist.Exponential(2.)], - ids=['normal', 'exponential']) -@pytest.mark.parametrize('component0', - [dist.Normal(1., 2.), dist.Exponential(2.)], - ids=['normal', 'exponential']) +@pytest.mark.parametrize("sample_shape", [(), (6,), (4, 2)]) +@pytest.mark.parametrize("batch_shape", [(), (7,), (5, 3)]) +@pytest.mark.parametrize( + "component1", + [dist.Normal(1.0, 2.0), dist.Exponential(2.0)], + ids=["normal", "exponential"], +) +@pytest.mark.parametrize( + "component0", + [dist.Normal(1.0, 2.0), dist.Exponential(2.0)], + ids=["normal", "exponential"], +) def test_masked_mixture_univariate(component0, component1, sample_shape, batch_shape): if batch_shape: component0 = component0.expand_by(batch_shape) @@ -43,12 +47,16 @@ def test_masked_mixture_univariate(component0, component1, sample_shape, batch_s assert_equal(log_prob[~mask], log_prob_0[~mask]) -@pytest.mark.parametrize('sample_shape', [(), (6,), (4, 2)]) -@pytest.mark.parametrize('batch_shape', [(), (7,), (5, 3)]) +@pytest.mark.parametrize("sample_shape", [(), (6,), (4, 2)]) +@pytest.mark.parametrize("batch_shape", [(), (7,), (5, 3)]) def test_masked_mixture_multivariate(sample_shape, batch_shape): event_shape = torch.Size((8,)) - component0 = dist.MultivariateNormal(torch.zeros(event_shape), torch.eye(event_shape[0])) - component1 = dist.Uniform(torch.zeros(event_shape), torch.ones(event_shape)).to_event(1) + component0 = dist.MultivariateNormal( + torch.zeros(event_shape), torch.eye(event_shape[0]) + ) + component1 = dist.Uniform( + torch.zeros(event_shape), torch.ones(event_shape) + ).to_event(1) if batch_shape: component0 = component0.expand_by(batch_shape) component1 = component1.expand_by(batch_shape) @@ -73,13 +81,13 @@ def test_masked_mixture_multivariate(sample_shape, batch_shape): assert_equal(log_prob[~mask], log_prob_0[~mask]) -@pytest.mark.parametrize('value_shape', [(), (5, 1, 1, 1), (6, 1, 1, 1, 1)]) -@pytest.mark.parametrize('component1_shape', [(), (4, 1, 1), (6, 1, 1, 1, 1)]) -@pytest.mark.parametrize('component0_shape', [(), (3, 1), (6, 1, 1, 1, 1)]) -@pytest.mark.parametrize('mask_shape', [(), (2,), (6, 1, 1, 1, 1)]) +@pytest.mark.parametrize("value_shape", [(), (5, 1, 1, 1), (6, 1, 1, 1, 1)]) +@pytest.mark.parametrize("component1_shape", [(), (4, 1, 1), (6, 1, 1, 1, 1)]) +@pytest.mark.parametrize("component0_shape", [(), (3, 1), (6, 1, 1, 1, 1)]) +@pytest.mark.parametrize("mask_shape", [(), (2,), (6, 1, 1, 1, 1)]) def test_broadcast(mask_shape, component0_shape, component1_shape, value_shape): mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool() - component0 = dist.Normal(torch.zeros(component0_shape), 1.) 
+ component0 = dist.Normal(torch.zeros(component0_shape), 1.0) component1 = dist.Exponential(torch.ones(component1_shape)) value = torch.ones(value_shape) @@ -91,9 +99,9 @@ def test_broadcast(mask_shape, component0_shape, component1_shape, value_shape): assert d.log_prob(value).shape == log_prob_shape -@pytest.mark.parametrize('event_shape', [(), (2,), (2, 3)]) -@pytest.mark.parametrize('batch_shape', [(), (3,), (5, 3)]) -@pytest.mark.parametrize('sample_shape', [(), (2,), (4, 2)]) +@pytest.mark.parametrize("event_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("batch_shape", [(), (3,), (5, 3)]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (4, 2)]) def test_expand(sample_shape, batch_shape, event_shape): ones_shape = torch.Size((1,) * len(batch_shape)) mask = torch.empty(ones_shape).bernoulli_(0.5).bool() @@ -107,7 +115,18 @@ def test_expand(sample_shape, batch_shape, event_shape): assert d.variance.shape == ones_shape + event_shape assert d.sample(sample_shape).shape == sample_shape + ones_shape + event_shape - assert d.expand(sample_shape + batch_shape).batch_shape == sample_shape + batch_shape - assert d.expand(sample_shape + batch_shape).sample().shape == sample_shape + batch_shape + event_shape - assert d.expand(sample_shape + batch_shape).mean.shape == sample_shape + batch_shape + event_shape - assert d.expand(sample_shape + batch_shape).variance.shape == sample_shape + batch_shape + event_shape + assert ( + d.expand(sample_shape + batch_shape).batch_shape == sample_shape + batch_shape + ) + assert ( + d.expand(sample_shape + batch_shape).sample().shape + == sample_shape + batch_shape + event_shape + ) + assert ( + d.expand(sample_shape + batch_shape).mean.shape + == sample_shape + batch_shape + event_shape + ) + assert ( + d.expand(sample_shape + batch_shape).variance.shape + == sample_shape + batch_shape + event_shape + ) diff --git a/tests/distributions/test_mvn.py b/tests/distributions/test_mvn.py index 1f51aa29e7..4a68ce262e 100644 --- a/tests/distributions/test_mvn.py +++ b/tests/distributions/test_mvn.py @@ -19,15 +19,30 @@ def random_mvn(loc_shape, cov_shape, dim): return MultivariateNormal(loc, cov) -@pytest.mark.parametrize('loc_shape', [ - (), (2,), (3, 2), -]) -@pytest.mark.parametrize('cov_shape', [ - (), (2,), (3, 2), -]) -@pytest.mark.parametrize('dim', [ - 1, 3, 5, -]) +@pytest.mark.parametrize( + "loc_shape", + [ + (), + (2,), + (3, 2), + ], +) +@pytest.mark.parametrize( + "cov_shape", + [ + (), + (2,), + (3, 2), + ], +) +@pytest.mark.parametrize( + "dim", + [ + 1, + 3, + 5, + ], +) def test_shape(loc_shape, cov_shape, dim): mvn = random_mvn(loc_shape, cov_shape, dim) assert mvn.loc.shape == mvn.batch_shape + mvn.event_shape diff --git a/tests/distributions/test_mvt.py b/tests/distributions/test_mvt.py index feedf4b125..e216944fea 100644 --- a/tests/distributions/test_mvt.py +++ b/tests/distributions/test_mvt.py @@ -24,18 +24,38 @@ def random_mvt(df_shape, loc_shape, cov_shape, dim): return MultivariateStudentT(df, loc, scale_tril) -@pytest.mark.parametrize('df_shape', [ - (), (2,), (3, 2), -]) -@pytest.mark.parametrize('loc_shape', [ - (), (2,), (3, 2), -]) -@pytest.mark.parametrize('cov_shape', [ - (), (2,), (3, 2), -]) -@pytest.mark.parametrize('dim', [ - 1, 3, 5, -]) +@pytest.mark.parametrize( + "df_shape", + [ + (), + (2,), + (3, 2), + ], +) +@pytest.mark.parametrize( + "loc_shape", + [ + (), + (2,), + (3, 2), + ], +) +@pytest.mark.parametrize( + "cov_shape", + [ + (), + (2,), + (3, 2), + ], +) +@pytest.mark.parametrize( + "dim", + [ + 1, + 3, + 5, 
+ ], +) def test_shape(df_shape, loc_shape, cov_shape, dim): mvt = random_mvt(df_shape, loc_shape, cov_shape, dim) assert mvt.df.shape == mvt.batch_shape @@ -50,11 +70,15 @@ def test_shape(df_shape, loc_shape, cov_shape, dim): (mvt.precision_matrix.sum() + mvt.log_prob(torch.zeros(dim)).sum()).backward() -@pytest.mark.parametrize("batch_shape", [ - (), - (3, 2), - (4,), -], ids=str) +@pytest.mark.parametrize( + "batch_shape", + [ + (), + (3, 2), + (4,), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2]) def test_log_prob(batch_shape, dim): loc = torch.randn(batch_shape + (dim,)) @@ -65,7 +89,9 @@ def test_log_prob(batch_shape, dim): actual_log_prob = MultivariateStudentT(df, loc, scale_tril).log_prob(x) if dim == 1: - expected_log_prob = StudentT(df.unsqueeze(-1), loc, scale_tril[..., 0]).log_prob(x).sum(-1) + expected_log_prob = ( + StudentT(df.unsqueeze(-1), loc, scale_tril[..., 0]).log_prob(x).sum(-1) + ) assert_equal(actual_log_prob, expected_log_prob) # test the fact MVT(df, loc, scale)(x) = int MVN(loc, scale / m)(x) Gamma(df/2,df/2)(m) dm @@ -101,7 +127,9 @@ def test_rsample(dim, df, num_samples=200 * 1000): @pytest.mark.parametrize("dim", [1, 2]) def test_log_prob_normalization(dim, df=6.1, grid_size=2000, domain_width=5.0): - scale_tril = (0.2 * torch.randn(dim) - 1.5).exp().diag() + 0.1 * torch.randn(dim, dim) + scale_tril = (0.2 * torch.randn(dim) - 1.5).exp().diag() + 0.1 * torch.randn( + dim, dim + ) scale_tril = 0.1 * scale_tril.tril(0) volume_factor = domain_width @@ -111,7 +139,9 @@ def test_log_prob_normalization(dim, df=6.1, grid_size=2000, domain_width=5.0): prec = 0.05 sample_shape = (grid_size * grid_size, dim) - z = torch.distributions.Uniform(-0.5 * domain_width, 0.5 * domain_width).sample(sample_shape) + z = torch.distributions.Uniform(-0.5 * domain_width, 0.5 * domain_width).sample( + sample_shape + ) d = MultivariateStudentT(torch.tensor(df), torch.zeros(dim), scale_tril) normalizer = d.log_prob(z).exp().mean().item() * volume_factor @@ -119,11 +149,15 @@ def test_log_prob_normalization(dim, df=6.1, grid_size=2000, domain_width=5.0): assert_equal(normalizer, 1.0, prec=prec) -@pytest.mark.parametrize("batch_shape", [ - (), - (3, 2), - (4,), -], ids=str) +@pytest.mark.parametrize( + "batch_shape", + [ + (), + (3, 2), + (4,), + ], + ids=str, +) def test_mean_var(batch_shape): dim = 2 loc = torch.randn(batch_shape + (dim,)) @@ -138,9 +172,15 @@ def test_mean_var(batch_shape): assert_equal(d.mean, expected_mean, prec=0.1) assert_equal(d.variance, expected_variance, prec=0.2) - assert_equal(MultivariateStudentT(0.5, loc, scale_tril).mean, - torch.full(batch_shape + (dim,), float('nan'))) - assert_equal(MultivariateStudentT(0.5, loc, scale_tril).variance, - torch.full(batch_shape + (dim,), float('nan'))) - assert_equal(MultivariateStudentT(1.5, loc, scale_tril).variance, - torch.full(batch_shape + (dim,), float('inf'))) + assert_equal( + MultivariateStudentT(0.5, loc, scale_tril).mean, + torch.full(batch_shape + (dim,), float("nan")), + ) + assert_equal( + MultivariateStudentT(0.5, loc, scale_tril).variance, + torch.full(batch_shape + (dim,), float("nan")), + ) + assert_equal( + MultivariateStudentT(1.5, loc, scale_tril).variance, + torch.full(batch_shape + (dim,), float("inf")), + ) diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py index ead4a0314a..353b2374d5 100644 --- a/tests/distributions/test_omt_mvn.py +++ b/tests/distributions/test_omt_mvn.py @@ -15,20 +15,24 @@ def analytic_grad(L11=1.0, L22=1.0, L21=1.0, omega1=1.0, 
omega2=1.0): dp = L11 * omega1 + L21 * omega2 - fact_1 = - omega2 * dp - fact_2 = np.exp(- 0.5 * (L22 * omega2) ** 2) - fact_3 = np.exp(- 0.5 * dp ** 2) + fact_1 = -omega2 * dp + fact_2 = np.exp(-0.5 * (L22 * omega2) ** 2) + fact_3 = np.exp(-0.5 * dp ** 2) return fact_1 * fact_2 * fact_3 -@pytest.mark.parametrize('L21', [0.4, 1.1]) -@pytest.mark.parametrize('L11', [0.6]) -@pytest.mark.parametrize('omega1', [0.5]) -@pytest.mark.parametrize('sample_shape', [torch.Size([1000, 2000]), torch.Size([200000])]) -@pytest.mark.parametrize('k', [1]) -@pytest.mark.parametrize('mvn_dist', ['OMTMultivariateNormal', 'AVFMultivariateNormal']) -def test_mean_gradient(mvn_dist, k, sample_shape, L21, omega1, L11, L22=0.8, L33=0.9, omega2=0.75): - if mvn_dist == 'OMTMultivariateNormal' and k > 1: +@pytest.mark.parametrize("L21", [0.4, 1.1]) +@pytest.mark.parametrize("L11", [0.6]) +@pytest.mark.parametrize("omega1", [0.5]) +@pytest.mark.parametrize( + "sample_shape", [torch.Size([1000, 2000]), torch.Size([200000])] +) +@pytest.mark.parametrize("k", [1]) +@pytest.mark.parametrize("mvn_dist", ["OMTMultivariateNormal", "AVFMultivariateNormal"]) +def test_mean_gradient( + mvn_dist, k, sample_shape, L21, omega1, L11, L22=0.8, L33=0.9, omega2=0.75 +): + if mvn_dist == "OMTMultivariateNormal" and k > 1: return omega = torch.tensor([omega1, omega2, 0.0]) @@ -37,41 +41,47 @@ def test_mean_gradient(mvn_dist, k, sample_shape, L21, omega1, L11, L22=0.8, L33 off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec], requires_grad=True) L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag - if mvn_dist == 'OMTMultivariateNormal': + if mvn_dist == "OMTMultivariateNormal": dist = OMTMultivariateNormal(loc, L) - elif mvn_dist == 'AVFMultivariateNormal': + elif mvn_dist == "AVFMultivariateNormal": CV = (1.1 * torch.rand(2, k, 3)).requires_grad_(True) dist = AVFMultivariateNormal(loc, L, CV) z = dist.rsample(sample_shape) - torch.cos((omega*z).sum(-1)).mean().backward() + torch.cos((omega * z).sum(-1)).mean().backward() computed_grad = off_diag.grad.cpu().data.numpy()[1, 0] analytic = analytic_grad(L11=L11, L22=L22, L21=L21, omega1=omega1, omega2=omega2) - assert(off_diag.grad.size() == off_diag.size()) - assert(loc.grad.size() == loc.size()) - assert(torch.triu(off_diag.grad, 1).sum() == 0.0) - assert_equal(analytic, computed_grad, prec=0.005, - msg='bad cholesky grad for %s (expected %.5f, got %.5f)' % - (mvn_dist, analytic, computed_grad)) + assert off_diag.grad.size() == off_diag.size() + assert loc.grad.size() == loc.size() + assert torch.triu(off_diag.grad, 1).sum() == 0.0 + assert_equal( + analytic, + computed_grad, + prec=0.005, + msg="bad cholesky grad for %s (expected %.5f, got %.5f)" + % (mvn_dist, analytic, computed_grad), + ) @pytest.mark.skip(reason="Slow; tests to be run when refactoring") -@pytest.mark.parametrize('L21', [0.4, 1.1]) -@pytest.mark.parametrize('L11', [0.6, 0.95]) -@pytest.mark.parametrize('omega1', [0.5, 0.9]) -@pytest.mark.parametrize('k', [3]) -@pytest.mark.parametrize('mvn_dist', ['OMTMultivariateNormal', 'AVFMultivariateNormal']) -def test_mean_single_gradient(mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, omega2=0.75, n_samples=20000): +@pytest.mark.parametrize("L21", [0.4, 1.1]) +@pytest.mark.parametrize("L11", [0.6, 0.95]) +@pytest.mark.parametrize("omega1", [0.5, 0.9]) +@pytest.mark.parametrize("k", [3]) +@pytest.mark.parametrize("mvn_dist", ["OMTMultivariateNormal", "AVFMultivariateNormal"]) +def test_mean_single_gradient( + mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, 
omega2=0.75, n_samples=20000 +): omega = torch.tensor([omega1, omega2, 0.0]) loc = torch.zeros(3, requires_grad=True) zero_vec = [0.0, 0.0, 0.0] off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec], requires_grad=True) L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag - if mvn_dist == 'OMTMultivariateNormal': + if mvn_dist == "OMTMultivariateNormal": dist = OMTMultivariateNormal(loc, L) - elif mvn_dist == 'AVFMultivariateNormal': + elif mvn_dist == "AVFMultivariateNormal": CV = (0.2 * torch.rand(2, k, 3)).requires_grad_(True) dist = AVFMultivariateNormal(loc, L, CV) @@ -79,10 +89,10 @@ def test_mean_single_gradient(mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, o for _ in range(n_samples): z = dist.rsample() - torch.cos((omega*z).sum(-1)).mean().backward() - assert(off_diag.grad.size() == off_diag.size()) - assert(loc.grad.size() == loc.size()) - assert(torch.triu(off_diag.grad, 1).sum() == 0.0) + torch.cos((omega * z).sum(-1)).mean().backward() + assert off_diag.grad.size() == off_diag.size() + assert loc.grad.size() == loc.size() + assert torch.triu(off_diag.grad, 1).sum() == 0.0 computed_grad = off_diag.grad.cpu()[1, 0].item() computed_grads.append(computed_grad) @@ -91,11 +101,16 @@ def test_mean_single_gradient(mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, o computed_grad = np.mean(computed_grads) analytic = analytic_grad(L11=L11, L22=L22, L21=L21, omega1=omega1, omega2=omega2) - assert_equal(analytic, computed_grad, prec=0.01, - msg='bad cholesky grad for %s (expected %.5f, got %.5f)' % (mvn_dist, analytic, computed_grad)) + assert_equal( + analytic, + computed_grad, + prec=0.01, + msg="bad cholesky grad for %s (expected %.5f, got %.5f)" + % (mvn_dist, analytic, computed_grad), + ) -@pytest.mark.parametrize('mvn_dist', [OMTMultivariateNormal, AVFMultivariateNormal]) +@pytest.mark.parametrize("mvn_dist", [OMTMultivariateNormal, AVFMultivariateNormal]) def test_log_prob(mvn_dist): loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0]) D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0]) diff --git a/tests/distributions/test_one_hot_categorical.py b/tests/distributions/test_one_hot_categorical.py index b72d7dda37..f206faba5a 100644 --- a/tests/distributions/test_one_hot_categorical.py +++ b/tests/distributions/test_one_hot_categorical.py @@ -30,20 +30,24 @@ def setUp(self): # Discrete Distribution self.d_ps = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]]) self.d_test_data = torch.tensor([[0.0], [5.0]]) - self.d_v_test_data = [['a'], ['f']] + self.d_v_test_data = [["a"], ["f"]] self.n_samples = 50000 - self.support_one_hot_non_vec = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - self.support_one_hot = torch.tensor([[[1, 0, 0], [1, 0, 0]], - [[0, 1, 0], [0, 1, 0]], - [[0, 0, 1], [0, 0, 1]]]) + self.support_one_hot_non_vec = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + ) + self.support_one_hot = torch.tensor( + [[[1, 0, 0], [1, 0, 0]], [[0, 1, 0], [0, 1, 0]], [[0, 0, 1], [0, 0, 1]]] + ) self.support_non_vec = torch.LongTensor([[0], [1], [2]]) self.support = torch.LongTensor([[[0], [0]], [[1], [1]], [[2], [2]]]) self.discrete_support_non_vec = torch.tensor([[0.0], [1.0], [2.0]]) - self.discrete_support = torch.tensor([[[0.0], [3.0]], [[1.0], [4.0]], [[2.0], [5.0]]]) - self.discrete_arr_support_non_vec = [['a'], ['b'], ['c']] - self.discrete_arr_support = [[['a'], ['d']], [['b'], ['e']], [['c'], ['f']]] + self.discrete_support = torch.tensor( + [[[0.0], [3.0]], [[1.0], [4.0]], [[2.0], [5.0]]] + ) + self.discrete_arr_support_non_vec = [["a"], 
["b"], ["c"]] + self.discrete_arr_support = [[["a"], ["d"]], [["b"], ["e"]], [["c"], ["f"]]] def test_support_non_vectorized(self): s = dist.OneHotCategorical(self.d_ps[0].squeeze(0)).enumerate_support() @@ -57,7 +61,7 @@ def test_support(self): def wrap_nested(x, dim): if dim == 0: return x - return wrap_nested([x], dim-1) + return wrap_nested([x], dim - 1) def assert_correct_dimensions(sample, probs): @@ -77,7 +81,7 @@ def probs(request): def modify_params_using_dims(probs, dim): - return torch.tensor(wrap_nested(probs, dim-1)) + return torch.tensor(wrap_nested(probs, dim - 1)) def test_support_dims(dim, probs): @@ -101,8 +105,10 @@ def test_sample_dims(dim, probs): def test_batch_log_dims(dim, probs): - batch_pdf_shape = (3,) + (1,) * (dim-1) - expected_log_prob_sum = np.array(wrap_nested(list(np.log(probs)), dim-1)).reshape(*batch_pdf_shape) + batch_pdf_shape = (3,) + (1,) * (dim - 1) + expected_log_prob_sum = np.array(wrap_nested(list(np.log(probs)), dim - 1)).reshape( + *batch_pdf_shape + ) probs = modify_params_using_dims(probs, dim) support = dist.OneHotCategorical(probs).enumerate_support() log_prob = dist.OneHotCategorical(probs).log_prob(support) diff --git a/tests/distributions/test_one_one_matching.py b/tests/distributions/test_one_one_matching.py index ba182cb82c..056ddeb3c0 100644 --- a/tests/distributions/test_one_one_matching.py +++ b/tests/distributions/test_one_one_matching.py @@ -49,26 +49,27 @@ def test_log_prob_full(num_nodes, dtype, bp_iters): d = dist.OneOneMatching(logits, bp_iters=bp_iters) values = d.enumerate_support() log_total = d.log_prob(values).logsumexp(0).item() - logging.info(f"log_total = {log_total:0.3g}, " + - f"log_Z = {d.log_partition_function:0.3g}") - assert_close(log_total, 0., atol=2.0) + logging.info( + f"log_total = {log_total:0.3g}, " + f"log_Z = {d.log_partition_function:0.3g}" + ) + assert_close(log_total, 0.0, atol=2.0) @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str) @pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"]) def test_log_prob_hard(dtype, bp_iters): - logits = [[0., 0.], [0., -math.inf]] + logits = [[0.0, 0.0], [0.0, -math.inf]] logits = torch.tensor(logits, dtype=dtype) d = dist.OneOneMatching(logits, bp_iters=bp_iters) values = d.enumerate_support() log_total = d.log_prob(values).logsumexp(0).item() - logging.info(f"log_total = {log_total:0.3g}, " + - f"log_Z = {d.log_partition_function:0.3g}") - assert_close(log_total, 0., atol=0.5) + logging.info( + f"log_total = {log_total:0.3g}, " + f"log_Z = {d.log_partition_function:0.3g}" + ) + assert_close(log_total, 0.0, atol=0.5) def assert_grads_ok(logits, bp_iters=None): - def fn(logits): d = dist.OneOneMatching(logits, bp_iters=bp_iters) return d.log_partition_function @@ -81,8 +82,9 @@ def assert_grads_agree(logits): d2 = dist.OneOneMatching(logits, bp_iters=BP_ITERS) expected = torch.autograd.grad(d1.log_partition_function, [logits])[0] actual = torch.autograd.grad(d2.log_partition_function, [logits])[0] - assert torch.allclose(actual, expected, atol=0.2, rtol=1e-3), \ - f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}" + assert torch.allclose( + actual, expected, atol=0.2, rtol=1e-3 + ), f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}" @pytest.mark.parametrize("num_nodes", [2, 3, 4, 5]) diff --git a/tests/distributions/test_one_two_matching.py b/tests/distributions/test_one_two_matching.py index f905b7f702..827b9ae37e 100644 --- a/tests/distributions/test_one_two_matching.py +++ 
b/tests/distributions/test_one_two_matching.py @@ -27,7 +27,7 @@ def random_phylo_logits(num_leaves, dtype): # Convert to a one-two-matching problem. ids = torch.arange(len(times)) root = times.min(0).indices.item() - sources = torch.cat([ids[:root], ids[root+1:]]) + sources = torch.cat([ids[:root], ids[root + 1 :]]) destins = ids[num_leaves:] dt = times[sources][:, None] - times[destins] dt = dt * 10 / dt.detach().std() @@ -73,22 +73,24 @@ def test_log_prob_full(num_destins, dtype, bp_iters): d = dist.OneTwoMatching(logits, bp_iters=bp_iters) values = d.enumerate_support() log_total = d.log_prob(values).logsumexp(0).item() - logging.info(f"log_total = {log_total:0.3g}, " + - f"log_Z = {d.log_partition_function:0.3g}") - assert_close(log_total, 0., atol=1.0) + logging.info( + f"log_total = {log_total:0.3g}, " + f"log_Z = {d.log_partition_function:0.3g}" + ) + assert_close(log_total, 0.0, atol=1.0) @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str) @pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"]) def test_log_prob_hard(dtype, bp_iters): - logits = [[0., 0.], [0., 0.], [0., 0.], [0., -math.inf]] + logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, -math.inf]] logits = torch.tensor(logits, dtype=dtype) d = dist.OneTwoMatching(logits, bp_iters=bp_iters) values = d.enumerate_support() log_total = d.log_prob(values).logsumexp(0).item() - logging.info(f"log_total = {log_total:0.3g}, " + - f"log_Z = {d.log_partition_function:0.3g}") - assert_close(log_total, 0., atol=0.5) + logging.info( + f"log_total = {log_total:0.3g}, " + f"log_Z = {d.log_partition_function:0.3g}" + ) + assert_close(log_total, 0.0, atol=0.5) @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str) @@ -99,9 +101,10 @@ def test_log_prob_phylo(num_leaves, dtype, bp_iters): d = dist.OneTwoMatching(logits, bp_iters=bp_iters) values = d.enumerate_support() log_total = d.log_prob(values).logsumexp(0).item() - logging.info(f"log_total = {log_total:0.3g}, " + - f"log_Z = {d.log_partition_function:0.3g}") - assert_close(log_total, 0., atol=1.0) + logging.info( + f"log_total = {log_total:0.3g}, " + f"log_Z = {d.log_partition_function:0.3g}" + ) + assert_close(log_total, 0.0, atol=1.0) @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str) @@ -117,7 +120,6 @@ def test_log_prob_phylo_smoke(num_leaves, dtype): def assert_grads_ok(logits, bp_iters=None): - def fn(logits): d = dist.OneTwoMatching(logits, bp_iters=bp_iters) return d.log_partition_function @@ -130,8 +132,9 @@ def assert_grads_agree(logits): d2 = dist.OneTwoMatching(logits, bp_iters=BP_ITERS) expected = torch.autograd.grad(d1.log_partition_function, [logits])[0] actual = torch.autograd.grad(d2.log_partition_function, [logits])[0] - assert torch.allclose(actual, expected, atol=0.2, rtol=1e-3), \ - f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}" + assert torch.allclose( + actual, expected, atol=0.2, rtol=1e-3 + ), f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}" @pytest.mark.parametrize("num_destins", [2, 3, 4, 5]) diff --git a/tests/distributions/test_ordered_logistic.py b/tests/distributions/test_ordered_logistic.py index 960fc81140..173c0004f9 100644 --- a/tests/distributions/test_ordered_logistic.py +++ b/tests/distributions/test_ordered_logistic.py @@ -40,7 +40,7 @@ def test_broadcast(): for cp in ( torch.arange(5), torch.arange(5).view(1, -1), - torch.stack(4*[torch.arange(5)]), + torch.stack(4 * [torch.arange(5)]), torch.sort(torch.randn(3, 4, 5), dim=-1).values, 
torch.sort(torch.randn(predictor.shape + (100,)), dim=-1).values, ): diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py index 66f881bbc5..9d84c377b4 100644 --- a/tests/distributions/test_pickle.py +++ b/tests/distributions/test_pickle.py @@ -19,12 +19,15 @@ dist.OMTMultivariateNormal, ] XFAIL = { - dist.Gumbel: xfail_param(dist.Gumbel, reason='cannot pickle weakref'), + dist.Gumbel: xfail_param(dist.Gumbel, reason="cannot pickle weakref"), } -DISTRIBUTIONS = [d for d in dist.__dict__.values() - if isinstance(d, type) - if issubclass(d, TorchDistributionMixin) - if d not in BLACKLIST] +DISTRIBUTIONS = [ + d + for d in dist.__dict__.values() + if isinstance(d, type) + if issubclass(d, TorchDistributionMixin) + if d not in BLACKLIST +] DISTRIBUTIONS.sort(key=lambda d: d.__name__) DISTRIBUTIONS = [XFAIL.get(d, d) for d in DISTRIBUTIONS] @@ -34,29 +37,40 @@ dist.Bernoulli: [0.5], dist.Binomial: [2, 0.5], dist.Categorical: [torch.ones(2)], - dist.Delta: [torch.tensor(0.)], + dist.Delta: [torch.tensor(0.0)], dist.Dirichlet: [torch.ones(2)], dist.GaussianScaleMixture: [torch.ones(2), torch.ones(3), torch.ones(3)], dist.Geometric: [0.5], dist.Independent: [dist.Normal(torch.zeros(2), torch.ones(2)), 1], dist.LowRankMultivariateNormal: [torch.zeros(2), torch.ones(2, 2), torch.ones(2)], - dist.MaskedMixture: [torch.tensor([1, 0]).bool(), dist.Normal(0, 1), dist.Normal(0, 2)], + dist.MaskedMixture: [ + torch.tensor([1, 0]).bool(), + dist.Normal(0, 1), + dist.Normal(0, 2), + ], dist.MixtureOfDiagNormals: [torch.ones(2, 3), torch.ones(2, 3), torch.ones(2)], - dist.MixtureOfDiagNormalsSharedCovariance: [torch.ones(2, 3), torch.ones(3), torch.ones(2)], + dist.MixtureOfDiagNormalsSharedCovariance: [ + torch.ones(2, 3), + torch.ones(3), + torch.ones(2), + ], dist.Multinomial: [2, torch.ones(2)], dist.MultivariateNormal: [torch.ones(2), torch.eye(2)], dist.OneHotCategorical: [torch.ones(2)], dist.RelaxedBernoulli: [1.0, 0.5], dist.RelaxedBernoulliStraightThrough: [1.0, 0.5], - dist.RelaxedOneHotCategorical: [1., torch.ones(2)], - dist.RelaxedOneHotCategoricalStraightThrough: [1., torch.ones(2)], - dist.TransformedDistribution: [dist.Normal(0, 1), torch.distributions.ExpTransform()], + dist.RelaxedOneHotCategorical: [1.0, torch.ones(2)], + dist.RelaxedOneHotCategoricalStraightThrough: [1.0, torch.ones(2)], + dist.TransformedDistribution: [ + dist.Normal(0, 1), + torch.distributions.ExpTransform(), + ], dist.Uniform: [0, 1], - dist.VonMises3D: [torch.tensor([1., 0., 0.])], + dist.VonMises3D: [torch.tensor([1.0, 0.0, 0.0])], } -@pytest.mark.parametrize('Dist', DISTRIBUTIONS) +@pytest.mark.parametrize("Dist", DISTRIBUTIONS) def test_pickle(Dist): if Dist in ARGS: args = ARGS[Dist] @@ -74,7 +88,7 @@ def test_pickle(Dist): try: dist = Dist(*args) except Exception: - pytest.skip(msg='cannot construct distribution') + pytest.skip(msg="cannot construct distribution") buffer = io.BytesIO() # Note that pickling torch.Size() requires protocol >= 2 diff --git a/tests/distributions/test_polya_gamma.py b/tests/distributions/test_polya_gamma.py index 6abcf77de6..7a19b76153 100644 --- a/tests/distributions/test_polya_gamma.py +++ b/tests/distributions/test_polya_gamma.py @@ -13,8 +13,12 @@ def test_polya_gamma(batch_shape, num_points=20000): d = TruncatedPolyaGamma(prototype=torch.ones(1)).expand(batch_shape) # test density approximately normalized - x = torch.linspace(1.0e-6, d.truncation_point, num_points).expand(batch_shape + (num_points,)) - prob = (d.truncation_point / num_points) * 
torch.logsumexp(d.log_prob(x), dim=-1).exp() + x = torch.linspace(1.0e-6, d.truncation_point, num_points).expand( + batch_shape + (num_points,) + ) + prob = (d.truncation_point / num_points) * torch.logsumexp( + d.log_prob(x), dim=-1 + ).exp() assert_close(prob, torch.tensor(1.0).expand(batch_shape), rtol=1.0e-4) # test mean of approximate sampler diff --git a/tests/distributions/test_rejector.py b/tests/distributions/test_rejector.py index 6a1bc6f529..13c9bf6aed 100644 --- a/tests/distributions/test_rejector.py +++ b/tests/distributions/test_rejector.py @@ -18,8 +18,8 @@ SIZES = list(map(torch.Size, [[], [1], [2], [3], [1, 1], [1, 2], [2, 3, 4]])) -@pytest.mark.parametrize('sample_shape', SIZES) -@pytest.mark.parametrize('batch_shape', filter(bool, SIZES)) +@pytest.mark.parametrize("sample_shape", SIZES) +@pytest.mark.parametrize("batch_shape", filter(bool, SIZES)) def test_rejection_standard_gamma_sample_shape(sample_shape, batch_shape): alphas = torch.ones(batch_shape) dist = RejectionStandardGamma(alphas) @@ -27,8 +27,8 @@ def test_rejection_standard_gamma_sample_shape(sample_shape, batch_shape): assert x.shape == sample_shape + batch_shape -@pytest.mark.parametrize('sample_shape', SIZES) -@pytest.mark.parametrize('batch_shape', filter(bool, SIZES)) +@pytest.mark.parametrize("sample_shape", SIZES) +@pytest.mark.parametrize("batch_shape", filter(bool, SIZES)) def test_rejection_exponential_sample_shape(sample_shape, batch_shape): rates = torch.ones(batch_shape) factors = torch.ones(batch_shape) * 0.5 @@ -46,8 +46,8 @@ def compute_elbo_grad(model, guide, variables): return grad(surrogate_elbo.sum(), variables, create_graph=True) -@pytest.mark.parametrize('rate', [0.5, 1.0, 2.0]) -@pytest.mark.parametrize('factor', [0.25, 0.5, 1.0]) +@pytest.mark.parametrize("rate", [0.5, 1.0, 2.0]) +@pytest.mark.parametrize("factor", [0.25, 0.5, 1.0]) def test_rejector(rate, factor): num_samples = 100000 rates = torch.tensor(rate).expand(num_samples, 1) @@ -57,13 +57,13 @@ def test_rejector(rate, factor): dist2 = RejectionExponential(rates, factors) # implemented using Rejector x1 = dist1.rsample() x2 = dist2.rsample() - assert_equal(x1.mean(), x2.mean(), prec=0.02, msg='bug in .rsample()') - assert_equal(x1.std(), x2.std(), prec=0.02, msg='bug in .rsample()') - assert_equal(dist1.log_prob(x1), dist2.log_prob(x1), msg='bug in .log_prob()') + assert_equal(x1.mean(), x2.mean(), prec=0.02, msg="bug in .rsample()") + assert_equal(x1.std(), x2.std(), prec=0.02, msg="bug in .rsample()") + assert_equal(dist1.log_prob(x1), dist2.log_prob(x1), msg="bug in .log_prob()") -@pytest.mark.parametrize('rate', [0.5, 1.0, 2.0]) -@pytest.mark.parametrize('factor', [0.25, 0.5, 1.0]) +@pytest.mark.parametrize("rate", [0.5, 1.0, 2.0]) +@pytest.mark.parametrize("factor", [0.25, 0.5, 1.0]) def test_exponential_elbo(rate, factor): num_samples = 100000 rates = torch.full((num_samples, 1), rate).requires_grad_() @@ -76,13 +76,13 @@ def test_exponential_elbo(rate, factor): for guide in [guide1, guide2]: grads.append(compute_elbo_grad(model, guide, [rates])[0]) expected, actual = grads - assert_equal(actual.mean(), expected.mean(), prec=0.05, msg='bad grad for rate') + assert_equal(actual.mean(), expected.mean(), prec=0.05, msg="bad grad for rate") actual = compute_elbo_grad(model, guide2, [factors])[0] - assert_equal(actual.mean().item(), 0.0, prec=0.05, msg='bad grad for factor') + assert_equal(actual.mean().item(), 0.0, prec=0.05, msg="bad grad for factor") -@pytest.mark.parametrize('alpha', [1.0, 2.0, 5.0]) 
+@pytest.mark.parametrize("alpha", [1.0, 2.0, 5.0]) def test_standard_gamma_elbo(alpha): num_samples = 100000 alphas = torch.full((num_samples, 1), alpha).requires_grad_() @@ -96,11 +96,11 @@ def test_standard_gamma_elbo(alpha): for guide in [guide1, guide2]: grads.append(compute_elbo_grad(model, guide, [alphas])[0].data) expected, actual = grads - assert_equal(actual.mean(), expected.mean(), prec=0.01, msg='bad grad for alpha') + assert_equal(actual.mean(), expected.mean(), prec=0.01, msg="bad grad for alpha") -@pytest.mark.parametrize('alpha', [1.0, 2.0, 5.0]) -@pytest.mark.parametrize('beta', [0.2, 0.5, 1.0, 2.0, 5.0]) +@pytest.mark.parametrize("alpha", [1.0, 2.0, 5.0]) +@pytest.mark.parametrize("beta", [0.2, 0.5, 1.0, 2.0, 5.0]) def test_gamma_elbo(alpha, beta): num_samples = 100000 alphas = torch.full((num_samples, 1), alpha).requires_grad_() @@ -117,12 +117,19 @@ def test_gamma_elbo(alpha, beta): expected = [g.mean() for g in expected] actual = [g.mean() for g in actual] scale = [(1 + abs(g)) for g in expected] - assert_equal(actual[0] / scale[0], expected[0] / scale[0], prec=0.01, msg='bad grad for alpha') - assert_equal(actual[1] / scale[1], expected[1] / scale[1], prec=0.01, msg='bad grad for beta') - - -@pytest.mark.parametrize('alpha', [0.2, 0.5, 1.0, 2.0, 5.0]) -@pytest.mark.parametrize('beta', [0.2, 0.5, 1.0, 2.0, 5.0]) + assert_equal( + actual[0] / scale[0], + expected[0] / scale[0], + prec=0.01, + msg="bad grad for alpha", + ) + assert_equal( + actual[1] / scale[1], expected[1] / scale[1], prec=0.01, msg="bad grad for beta" + ) + + +@pytest.mark.parametrize("alpha", [0.2, 0.5, 1.0, 2.0, 5.0]) +@pytest.mark.parametrize("beta", [0.2, 0.5, 1.0, 2.0, 5.0]) def test_shape_augmented_gamma_elbo(alpha, beta): num_samples = 100000 alphas = torch.full((num_samples, 1), alpha).requires_grad_() @@ -139,12 +146,19 @@ def test_shape_augmented_gamma_elbo(alpha, beta): expected = [g.mean() for g in expected] actual = [g.mean() for g in actual] scale = [(1 + abs(g)) for g in expected] - assert_equal(actual[0] / scale[0], expected[0] / scale[0], prec=0.05, msg='bad grad for alpha') - assert_equal(actual[1] / scale[1], expected[1] / scale[1], prec=0.05, msg='bad grad for beta') - - -@pytest.mark.parametrize('alpha', [0.5, 1.0, 4.0]) -@pytest.mark.parametrize('beta', [0.5, 1.0, 4.0]) + assert_equal( + actual[0] / scale[0], + expected[0] / scale[0], + prec=0.05, + msg="bad grad for alpha", + ) + assert_equal( + actual[1] / scale[1], expected[1] / scale[1], prec=0.05, msg="bad grad for beta" + ) + + +@pytest.mark.parametrize("alpha", [0.5, 1.0, 4.0]) +@pytest.mark.parametrize("beta", [0.5, 1.0, 4.0]) def test_shape_augmented_beta(alpha, beta): num_samples = 10000 alphas = torch.full((num_samples, 1), alpha).requires_grad_() @@ -157,5 +171,7 @@ def test_shape_augmented_beta(alpha, beta): mean_beta_grad = betas.grad.mean().item() expected_alpha_grad = beta / (alpha + beta) ** 2 expected_beta_grad = -alpha / (alpha + beta) ** 2 - assert_equal(mean_alpha_grad, expected_alpha_grad, prec=0.02, msg='bad grad for alpha') - assert_equal(mean_beta_grad, expected_beta_grad, prec=0.02, msg='bad grad for beta') + assert_equal( + mean_alpha_grad, expected_alpha_grad, prec=0.02, msg="bad grad for alpha" + ) + assert_equal(mean_beta_grad, expected_beta_grad, prec=0.02, msg="bad grad for beta") diff --git a/tests/distributions/test_relaxed_straight_through.py b/tests/distributions/test_relaxed_straight_through.py index 54b57a55dc..ab1ae486e2 100644 --- a/tests/distributions/test_relaxed_straight_through.py +++ 
b/tests/distributions/test_relaxed_straight_through.py @@ -19,20 +19,17 @@ from tests.common import assert_equal ONEHOT_PROBS = [ - [0.25, 0.75], - [0.25, 0.5, 0.25], - [[0.25, 0.75], [0.75, 0.25]], - [[[0.25, 0.75]], [[0.75, 0.25]]], - [0.1] * 10, + [0.25, 0.75], + [0.25, 0.5, 0.25], + [[0.25, 0.75], [0.75, 0.25]], + [[[0.25, 0.75]], [[0.75, 0.25]]], + [0.1] * 10, ] -BERN_PROBS = [ - [0.25, 0.75], - [[0.25, 0.75], [0.75, 0.25]] -] +BERN_PROBS = [[0.25, 0.75], [[0.25, 0.75], [0.75, 0.25]]] -@pytest.mark.parametrize('probs', ONEHOT_PROBS) +@pytest.mark.parametrize("probs", ONEHOT_PROBS) def test_onehot_shapes(probs): temperature = torch.tensor(0.5) probs = torch.tensor(probs, requires_grad=True) @@ -43,7 +40,7 @@ def test_onehot_shapes(probs): assert grad_probs.shape == probs.shape -@pytest.mark.parametrize('temp', [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("temp", [0.3, 0.5, 1.0]) def test_onehot_entropy_grad(temp): num_samples = 2000000 q = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True) @@ -57,33 +54,45 @@ def test_onehot_entropy_grad(temp): z = dist_q.rsample(sample_shape=(num_samples,)) actual = grad(dist_q.log_prob(z).sum(), [q])[0] / num_samples - assert_equal(expected, actual, prec=0.08, - msg='bad grad for RelaxedOneHotCategoricalStraightThrough (expected {}, got {})'. - format(expected, actual)) + assert_equal( + expected, + actual, + prec=0.08, + msg="bad grad for RelaxedOneHotCategoricalStraightThrough (expected {}, got {})".format( + expected, actual + ), + ) def test_onehot_svi_usage(): - def model(): p = torch.tensor([0.25] * 4) - pyro.sample('z', OneHotCategorical(probs=p)) + pyro.sample("z", OneHotCategorical(probs=p)) def guide(): - q = pyro.param('q', torch.tensor([0.1, 0.2, 0.3, 0.4]), constraint=constraints.simplex) + q = pyro.param( + "q", torch.tensor([0.1, 0.2, 0.3, 0.4]), constraint=constraints.simplex + ) temp = torch.tensor(0.10) - pyro.sample('z', RelaxedOneHotCategoricalStraightThrough(temperature=temp, probs=q)) + pyro.sample( + "z", RelaxedOneHotCategoricalStraightThrough(temperature=temp, probs=q) + ) - adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)}) + adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)}) svi = SVI(model, guide, adam, loss=Trace_ELBO()) for k in range(6000): svi.step() - assert_equal(pyro.param('q'), torch.tensor([0.25] * 4), prec=0.01, - msg='test svi usage of RelaxedOneHotCategoricalStraightThrough failed') + assert_equal( + pyro.param("q"), + torch.tensor([0.25] * 4), + prec=0.01, + msg="test svi usage of RelaxedOneHotCategoricalStraightThrough failed", + ) -@pytest.mark.parametrize('probs', BERN_PROBS) +@pytest.mark.parametrize("probs", BERN_PROBS) def test_bernoulli_shapes(probs): temperature = torch.tensor(0.5) probs = torch.tensor(probs, requires_grad=True) @@ -94,7 +103,7 @@ def test_bernoulli_shapes(probs): assert grad_probs.shape == probs.shape -@pytest.mark.parametrize('temp', [0.5, 1.0]) +@pytest.mark.parametrize("temp", [0.5, 1.0]) def test_bernoulli_entropy_grad(temp): num_samples = 1500000 q = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True) @@ -108,6 +117,11 @@ def test_bernoulli_entropy_grad(temp): z = dist_q.rsample(sample_shape=(num_samples,)) actual = grad(dist_q.log_prob(z).sum(), [q])[0] / num_samples - assert_equal(expected, actual, prec=0.04, - msg='bad grad for RelaxedBernoulliStraightThrough (expected {}, got {})'. 
- format(expected, actual)) + assert_equal( + expected, + actual, + prec=0.04, + msg="bad grad for RelaxedBernoulliStraightThrough (expected {}, got {})".format( + expected, actual + ), + ) diff --git a/tests/distributions/test_reshape.py b/tests/distributions/test_reshape.py index dddcbb38f0..d0dbdce6dc 100644 --- a/tests/distributions/test_reshape.py +++ b/tests/distributions/test_reshape.py @@ -20,10 +20,10 @@ def test_sample_shape_order(): assert actual.batch_shape == expected.batch_shape -@pytest.mark.parametrize('batch_dim', [0, 1, 2]) -@pytest.mark.parametrize('event_dim', [0, 1, 2]) +@pytest.mark.parametrize("batch_dim", [0, 1, 2]) +@pytest.mark.parametrize("event_dim", [0, 1, 2]) def test_idempotent(batch_dim, event_dim): - shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim] + shape = torch.Size((1, 2, 3, 4))[: batch_dim + event_dim] batch_shape = shape[:batch_dim] event_shape = shape[batch_dim:] @@ -38,12 +38,13 @@ def test_idempotent(batch_dim, event_dim): assert dist.event_shape == dist0.event_shape -@pytest.mark.parametrize('sample_dim,extra_event_dims', - [(s, e) for s in range(4) for e in range(4 + s)]) +@pytest.mark.parametrize( + "sample_dim,extra_event_dims", [(s, e) for s in range(4) for e in range(4 + s)] +) def test_reshape(sample_dim, extra_event_dims): batch_dim = 3 batch_shape, event_shape = torch.Size((5, 4, 3)), torch.Size() - sample_shape = torch.Size((8, 7, 6))[3 - sample_dim:] + sample_shape = torch.Size((8, 7, 6))[3 - sample_dim :] shape = sample_shape + batch_shape + event_shape # Construct a base dist of desired starting shape. @@ -57,7 +58,10 @@ def test_reshape(sample_dim, extra_event_dims): assert sample.shape == shape assert dist.mean.shape == shape assert dist.variance.shape == shape - assert dist.log_prob(sample).shape == shape[:sample_dim + batch_dim - extra_event_dims] + assert ( + dist.log_prob(sample).shape + == shape[: sample_dim + batch_dim - extra_event_dims] + ) # Check enumerate support. if dist.event_shape: @@ -70,15 +74,19 @@ def test_reshape(sample_dim, extra_event_dims): else: assert dist.enumerate_support().shape == (2,) + shape assert dist.enumerate_support(expand=True).shape == (2,) + shape - assert dist.enumerate_support(expand=False).shape == (2,) + (1,) * len(sample_shape + batch_shape) + event_shape + assert ( + dist.enumerate_support(expand=False).shape + == (2,) + (1,) * len(sample_shape + batch_shape) + event_shape + ) -@pytest.mark.parametrize('sample_dim,extra_event_dims', - [(s, e) for s in range(3) for e in range(3 + s)]) +@pytest.mark.parametrize( + "sample_dim,extra_event_dims", [(s, e) for s in range(3) for e in range(3 + s)] +) def test_reshape_reshape(sample_dim, extra_event_dims): batch_dim = 2 batch_shape, event_shape = torch.Size((6, 5)), torch.Size((4, 3)) - sample_shape = torch.Size((8, 7))[2 - sample_dim:] + sample_shape = torch.Size((8, 7))[2 - sample_dim :] shape = sample_shape + batch_shape + event_shape # Construct a base dist of desired starting shape. @@ -93,7 +101,10 @@ def test_reshape_reshape(sample_dim, extra_event_dims): assert sample.shape == shape assert dist.mean.shape == shape assert dist.variance.shape == shape - assert dist.log_prob(sample).shape == shape[:sample_dim + batch_dim - extra_event_dims] + assert ( + dist.log_prob(sample).shape + == shape[: sample_dim + batch_dim - extra_event_dims] + ) # Check enumerate support. 
if dist.event_shape: @@ -106,17 +117,20 @@ def test_reshape_reshape(sample_dim, extra_event_dims): else: assert dist.enumerate_support().shape == (2,) + shape assert dist.enumerate_support(expand=True).shape == (2,) + shape - assert dist.enumerate_support(expand=False).shape == (2,) + (1,) * len(sample_shape + batch_shape) + event_shape + assert ( + dist.enumerate_support(expand=False).shape + == (2,) + (1,) * len(sample_shape + batch_shape) + event_shape + ) -@pytest.mark.parametrize('sample_dim', [0, 1, 2]) -@pytest.mark.parametrize('batch_dim', [0, 1, 2]) -@pytest.mark.parametrize('event_dim', [0, 1, 2]) +@pytest.mark.parametrize("sample_dim", [0, 1, 2]) +@pytest.mark.parametrize("batch_dim", [0, 1, 2]) +@pytest.mark.parametrize("event_dim", [0, 1, 2]) def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim): shape = torch.Size(range(sample_dim + batch_dim + event_dim)) sample_shape = shape[:sample_dim] - batch_shape = shape[sample_dim:sample_dim+batch_dim] - event_shape = shape[sample_dim + batch_dim:] + batch_shape = shape[sample_dim : sample_dim + batch_dim] + event_shape = shape[sample_dim + batch_dim :] # Construct a base dist of desired starting shape. dist0 = Bernoulli(0.5).expand_by(batch_shape + event_shape).to_event(event_dim) @@ -126,8 +140,8 @@ def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim): # Check .to_event(...) for valid values. for extra_event_dims in range(1 + sample_dim + batch_dim): dist = dist0.expand_by(sample_shape).to_event(extra_event_dims) - assert dist.batch_shape == shape[:sample_dim + batch_dim - extra_event_dims] - assert dist.event_shape == shape[sample_dim + batch_dim - extra_event_dims:] + assert dist.batch_shape == shape[: sample_dim + batch_dim - extra_event_dims] + assert dist.event_shape == shape[sample_dim + batch_dim - extra_event_dims :] # Check .to_event(...) for invalid values. 
for extra_event_dims in range(1 + sample_dim + batch_dim, 20): @@ -138,4 +152,4 @@ def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim): def test_independent_entropy(): dist_univ = Bernoulli(0.5) dist_multi = Bernoulli(torch.Tensor([0.5, 0.5])).to_event(1) - assert_equal(dist_multi.entropy(), 2*dist_univ.entropy()) + assert_equal(dist_multi.entropy(), 2 * dist_univ.entropy()) diff --git a/tests/distributions/test_sine_bivariate_von_mises.py b/tests/distributions/test_sine_bivariate_von_mises.py index a5bc1210fe..7b25020833 100644 --- a/tests/distributions/test_sine_bivariate_von_mises.py +++ b/tests/distributions/test_sine_bivariate_von_mises.py @@ -17,22 +17,25 @@ def _unnorm_log_prob(value, loc1, loc2, conc1, conc2, corr): phi_val = value[..., 0] psi_val = value[..., 1] - return (conc1 * torch.cos(phi_val - loc1) + conc2 * torch.cos(psi_val - loc2) + - corr * torch.sin(phi_val - loc1) * torch.sin(psi_val - loc2)) + return ( + conc1 * torch.cos(phi_val - loc1) + + conc2 * torch.cos(psi_val - loc2) + + corr * torch.sin(phi_val - loc1) * torch.sin(psi_val - loc2) + ) -@pytest.mark.parametrize('n', [0, 1, 10, 20]) +@pytest.mark.parametrize("n", [0, 1, 10, 20]) def test_log_binomial(n): comp = SineBivariateVonMises._lbinoms(tensor(n)) act = tensor([binom(2 * i, i) for i in range(n)]).log() assert_equal(act, comp) -@pytest.mark.parametrize('batch_dim', [tuple(), (1,), (10,), (2, 1), (2, 1, 2)]) +@pytest.mark.parametrize("batch_dim", [tuple(), (1,), (10,), (2, 1), (2, 1, 2)]) def test_bvm_unnorm_log_prob(batch_dim): - vm = VonMises(tensor(0.), tensor(1.)) - hn = HalfNormal(tensor(1.)) - b = Beta(tensor(2.), tensor(2.)) + vm = VonMises(tensor(0.0), tensor(1.0)) + hn = HalfNormal(tensor(1.0)) + b = Beta(tensor(2.0), tensor(2.0)) while True: phi_psi = vm.sample((*batch_dim, 2)) @@ -42,15 +45,17 @@ def test_bvm_unnorm_log_prob(batch_dim): if torch.all(torch.prod(conc, dim=0) > corr ** 2): break bmv = SineBivariateVonMises(locs[0], locs[1], conc[0], conc[1], corr) - assert_equal(_unnorm_log_prob(phi_psi, locs[0], locs[1], conc[0], conc[1], corr), - bmv.log_prob(phi_psi) + bmv.norm_const) + assert_equal( + _unnorm_log_prob(phi_psi, locs[0], locs[1], conc[0], conc[1], corr), + bmv.log_prob(phi_psi) + bmv.norm_const, + ) def test_bvm_multidim(): - vm = VonMises(tensor(0.), tensor(1.)) - hn = HalfNormal(tensor(1.)) - b = Beta(tensor(2.), tensor(2.)) - g = Geometric(torch.tensor([.4, .2, .5])) + vm = VonMises(tensor(0.0), tensor(1.0)) + hn = HalfNormal(tensor(1.0)) + b = Beta(tensor(2.0), tensor(2.0)) + g = Geometric(torch.tensor([0.4, 0.2, 0.5])) for _ in range(25): while True: batch_dim = tuple(int(i) for i in g.sample() if i > 0) @@ -63,13 +68,15 @@ def test_bvm_multidim(): bmv = SineBivariateVonMises(locs[0], locs[1], conc[0], conc[1], corr) assert_equal(bmv.batch_shape, torch.Size(batch_dim)) - assert_equal(bmv.sample(sample_dim).shape, torch.Size((*sample_dim, *batch_dim, 2))) + assert_equal( + bmv.sample(sample_dim).shape, torch.Size((*sample_dim, *batch_dim, 2)) + ) def test_mle_bvm(): - vm = VonMises(tensor(0.), tensor(1.)) - hn = HalfNormal(tensor(.8)) - b = Beta(tensor(2.), tensor(5.)) + vm = VonMises(tensor(0.0), tensor(1.0)) + hn = HalfNormal(tensor(0.8)) + b = Beta(tensor(2.0), tensor(5.0)) while True: locs = vm.sample((2,)) conc = hn.sample((2,)) @@ -78,13 +85,17 @@ def test_mle_bvm(): break def mle_model(data): - phi_loc = pyro.param('phi_loc', tensor(0.), constraints.real) - psi_loc = pyro.param('psi_loc', tensor(0.), constraints.real) - phi_conc = pyro.param('phi_conc', 
tensor(1.), constraints.positive) - psi_conc = pyro.param('psi_conc', tensor(1.), constraints.positive) - corr = pyro.param('corr', tensor(.5), constraints.real) + phi_loc = pyro.param("phi_loc", tensor(0.0), constraints.real) + psi_loc = pyro.param("psi_loc", tensor(0.0), constraints.real) + phi_conc = pyro.param("phi_conc", tensor(1.0), constraints.positive) + psi_conc = pyro.param("psi_conc", tensor(1.0), constraints.positive) + corr = pyro.param("corr", tensor(0.5), constraints.real) with pyro.plate("data", data.size(-2)): - pyro.sample('obs', SineBivariateVonMises(phi_loc, psi_loc, phi_conc, psi_conc, corr), obs=data) + pyro.sample( + "obs", + SineBivariateVonMises(phi_loc, psi_loc, phi_conc, psi_conc, corr), + obs=data, + ) def guide(data): pass @@ -93,7 +104,7 @@ def guide(data): data = bmv.sample((10_000,)) pyro.clear_param_store() - adam = pyro.optim.Adam({"lr": .01}) + adam = pyro.optim.Adam({"lr": 0.01}) svi = SVI(mle_model, guide, adam, loss=Trace_ELBO()) losses = [] @@ -101,13 +112,21 @@ def guide(data): for step in range(steps): losses.append(svi.step(data)) - expected = {'phi_loc': locs[0], 'psi_loc': locs[1], 'phi_conc': conc[0], 'psi_conc': conc[1], 'corr': corr} + expected = { + "phi_loc": locs[0], + "psi_loc": locs[1], + "phi_conc": conc[0], + "psi_conc": conc[1], + "corr": corr, + } actuals = {k: v for k, v in pyro.get_param_store().items()} for k in expected.keys(): if k in actuals: actual = actuals[k] else: - actual = actuals['corr_weight'] * actuals['phi_conc'] * actuals['psi_conc'] # k == 'corr' + actual = ( + actuals["corr_weight"] * actuals["phi_conc"] * actuals["psi_conc"] + ) # k == 'corr' assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2) diff --git a/tests/distributions/test_sine_skewed.py b/tests/distributions/test_sine_skewed.py index f91e73fe35..72755c2093 100644 --- a/tests/distributions/test_sine_skewed.py +++ b/tests/distributions/test_sine_skewed.py @@ -12,7 +12,7 @@ from pyro.optim import Adam from tests.common import assert_equal -BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))] +BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0.0, 1.0))] def _skewness(event_shape): @@ -20,7 +20,7 @@ def _skewness(event_shape): done = False while not done: for i in range(event_shape.numel()): - max_ = 1. 
- skewness.abs().sum(-1) + max_ = 1.0 - skewness.abs().sum(-1) if torch.any(max_ < 1e-15): break skewness[i] = Uniform(-max_, max_).sample() @@ -33,13 +33,28 @@ def _skewness(event_shape): return skewness -@pytest.mark.parametrize('expand_shape', - [(1,), (2,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) -@pytest.mark.parametrize('dist', BASE_DISTS) +@pytest.mark.parametrize( + "expand_shape", + [ + (1,), + (2,), + (4,), + (1, 1), + (1, 2), + (10, 10), + (1, 3, 1), + (10, 1, 5), + (1, 1, 1), + (3, 2, 3), + ], +) +@pytest.mark.parametrize("dist", BASE_DISTS) def test_ss_multidim_log_prob(expand_shape, dist): - base_dist = dist[0](*(torch.tensor(param).expand(expand_shape) for param in dist[1])).to_event(1) + base_dist = dist[0]( + *(torch.tensor(param).expand(expand_shape) for param in dist[1]) + ).to_event(1) - loc = base_dist.sample((10,)) + Normal(0., 1e-3).sample() + loc = base_dist.sample((10,)) + Normal(0.0, 1e-3).sample() base_prob = base_dist.log_prob(loc) skewness = _skewness(base_dist.event_shape) @@ -49,10 +64,12 @@ def test_ss_multidim_log_prob(expand_shape, dist): assert_equal(ss.sample().shape, torch.Size(expand_shape)) -@pytest.mark.parametrize('dist', BASE_DISTS) -@pytest.mark.parametrize('dim', [1, 2]) +@pytest.mark.parametrize("dist", BASE_DISTS) +@pytest.mark.parametrize("dim", [1, 2]) def test_ss_mle(dim, dist): - base_dist = dist[0](*(torch.tensor(param).expand((dim,)) for param in dist[1])).to_event(1) + base_dist = dist[0]( + *(torch.tensor(param).expand((dim,)) for param in dist[1]) + ).to_event(1) skewness_tar = _skewness(base_dist.event_shape) data = SineSkewed(base_dist, skewness_tar).sample((1000,)) @@ -60,17 +77,23 @@ def test_ss_mle(dim, dist): def model(data, batch_shape): skews = [] for i in range(dim): - skews.append(pyro.param(f'skew{i}', .5 * torch.ones(batch_shape), constraint=constraints.interval(-1, 1))) + skews.append( + pyro.param( + f"skew{i}", + 0.5 * torch.ones(batch_shape), + constraint=constraints.interval(-1, 1), + ) + ) skewness = torch.stack(skews, dim=-1) with pyro.plate("data", data.size(-len(data.size()))): - pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) + pyro.sample("obs", SineSkewed(base_dist, skewness), obs=data) def guide(data, batch_shape): pass pyro.clear_param_store() - adam = Adam({"lr": .1}) + adam = Adam({"lr": 0.1}) svi = SVI(model, guide, adam, loss=Trace_ELBO()) losses = [] @@ -78,5 +101,7 @@ def guide(data, batch_shape): for step in range(steps): losses.append(svi.step(data, base_dist.batch_shape)) - act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k], dim=-1) + act_skewness = torch.stack( + [v for k, v in pyro.get_param_store().items() if "skew" in k], dim=-1 + ) assert_equal(act_skewness, skewness_tar, 1e-1) diff --git a/tests/distributions/test_spanning_tree.py b/tests/distributions/test_spanning_tree.py index 05984d0aed..8c80f8c307 100644 --- a/tests/distributions/test_spanning_tree.py +++ b/tests/distributions/test_spanning_tree.py @@ -18,16 +18,21 @@ ) from tests.common import assert_equal, xfail_if_not_implemented -pytestmark = pytest.mark.skipif("CUDA_TEST" in os.environ, reason="spanning_tree unsupported on CUDA.") +pytestmark = pytest.mark.skipif( + "CUDA_TEST" in os.environ, reason="spanning_tree unsupported on CUDA." 
+) @pytest.mark.filterwarnings("always") -@pytest.mark.parametrize('num_vertices,expected_grid', [ - (2, [[0], [1]]), - (3, [[0, 0, 1], [1, 2, 2]]), - (4, [[0, 0, 1, 0, 1, 2], [1, 2, 2, 3, 3, 3]]), -]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) +@pytest.mark.parametrize( + "num_vertices,expected_grid", + [ + (2, [[0], [1]]), + (3, [[0, 0, 1], [1, 2, 2]]), + (4, [[0, 0, 1, 0, 1, 2], [1, 2, 2, 3, 3, 3]]), + ], +) +@pytest.mark.parametrize("backend", ["python", "cpp"]) def test_make_complete_graph(num_vertices, expected_grid, backend): V = num_vertices K = V * (V - 1) // 2 @@ -38,8 +43,8 @@ def test_make_complete_graph(num_vertices, expected_grid, backend): @pytest.mark.filterwarnings("always") -@pytest.mark.parametrize('num_edges', [1, 3, 10, 30, 100]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) +@pytest.mark.parametrize("num_edges", [1, 3, 10, 30, 100]) +@pytest.mark.parametrize("backend", ["python", "cpp"]) def test_sample_tree_mcmc_smoke(num_edges, backend): pyro.set_rng_seed(num_edges) E = num_edges @@ -52,8 +57,8 @@ def test_sample_tree_mcmc_smoke(num_edges, backend): @pytest.mark.filterwarnings("always") -@pytest.mark.parametrize('num_edges', [1, 3, 10, 30, 100]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) +@pytest.mark.parametrize("num_edges", [1, 3, 10, 30, 100]) +@pytest.mark.parametrize("backend", ["python", "cpp"]) def test_sample_tree_approx_smoke(num_edges, backend): pyro.set_rng_seed(num_edges) E = num_edges @@ -65,8 +70,8 @@ def test_sample_tree_approx_smoke(num_edges, backend): @pytest.mark.filterwarnings("always") -@pytest.mark.parametrize('num_edges', [1, 3, 10, 30, 100]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) +@pytest.mark.parametrize("num_edges", [1, 3, 10, 30, 100]) +@pytest.mark.parametrize("backend", ["python", "cpp"]) def test_find_best_tree_smoke(num_edges, backend): pyro.set_rng_seed(num_edges) E = num_edges @@ -77,7 +82,7 @@ def test_find_best_tree_smoke(num_edges, backend): find_best_tree(edge_logits, backend=backend) -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5, 6]) def test_enumerate_support(num_edges): pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges @@ -92,7 +97,7 @@ def test_enumerate_support(num_edges): assert support.size(0) == NUM_SPANNING_TREES[V] -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5, 6]) def test_partition_function(num_edges): pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges @@ -110,7 +115,7 @@ def test_partition_function(num_edges): assert (actual - expected).abs() < 1e-6, (actual, expected) -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5, 6]) def test_log_prob(num_edges): pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges @@ -126,7 +131,7 @@ def test_log_prob(num_edges): assert abs(log_total) < 1e-6, log_total -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5, 6]) def test_edge_mean_function(num_edges): pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges @@ -149,8 +154,8 @@ def test_edge_mean_function(num_edges): assert (actual[v1, v2] - expected).abs().max() < 1e-5, (actual, expected) -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5, 6]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5, 6]) +@pytest.mark.parametrize("backend", ["python", "cpp"]) def 
test_mode(num_edges, backend): pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges @@ -169,12 +174,12 @@ def test_mode(num_edges, backend): @pytest.mark.filterwarnings("always") -@pytest.mark.parametrize('pattern', ["uniform", "random", "sparse"]) -@pytest.mark.parametrize('num_edges', [1, 2, 3, 4, 5]) -@pytest.mark.parametrize('backend', ["python", "cpp"]) -@pytest.mark.parametrize('method', ["mcmc", "approx"]) +@pytest.mark.parametrize("pattern", ["uniform", "random", "sparse"]) +@pytest.mark.parametrize("num_edges", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("backend", ["python", "cpp"]) +@pytest.mark.parametrize("method", ["mcmc", "approx"]) def test_sample_tree_gof(method, backend, num_edges, pattern): - goftests = pytest.importorskip('goftests') + goftests = pytest.importorskip("goftests") pyro.set_rng_seed(2 ** 32 - num_edges) E = num_edges V = 1 + E @@ -191,7 +196,7 @@ def test_sample_tree_gof(method, backend, num_edges, pattern): for v2 in range(V): for v1 in range(v2): if v1 + 1 < v2: - edge_logits[v1 + v2 * (v2 - 1) // 2] = -float('inf') + edge_logits[v1 + v2 * (v2 - 1) // 2] = -float("inf") num_samples = 10 * NUM_SPANNING_TREES[V] # Generate many samples. @@ -212,13 +217,14 @@ def test_sample_tree_gof(method, backend, num_edges, pattern): # Check accuracy using a Pearson's chi-squared test. keys = [k for k, _ in counts.most_common(100)] - truncated = (len(keys) < len(counts)) + truncated = len(keys) < len(counts) counts = torch.tensor([counts[k] for k in keys]) tensors = torch.stack([tensors[k] for k in keys]) probs = SpanningTree(edge_logits).log_prob(tensors).exp() gof = goftests.multinomial_goodness_of_fit( - probs.numpy(), counts.numpy(), num_samples, plot=True, truncated=truncated) - logging.info('gof = {}'.format(gof)) + probs.numpy(), counts.numpy(), num_samples, plot=True, truncated=truncated + ) + logging.info("gof = {}".format(gof)) if method == "approx": assert gof >= 0.0001 else: diff --git a/tests/distributions/test_stable.py b/tests/distributions/test_stable.py index 7acd40187c..56cb10eebd 100644 --- a/tests/distributions/test_stable.py +++ b/tests/distributions/test_stable.py @@ -60,11 +60,21 @@ def cdf(x): @pytest.mark.parametrize("beta", [-1.0, -0.5, 0.0, 0.5, 1.0]) -@pytest.mark.parametrize("alpha", [ - 0.1, 0.4, 0.8, 0.99, - 0.999999, 1.000001, # scipy sampler is buggy very close to 1 - 1.01, 1.3, 1.7, 2.0, -]) +@pytest.mark.parametrize( + "alpha", + [ + 0.1, + 0.4, + 0.8, + 0.99, + 0.999999, + 1.000001, # scipy sampler is buggy very close to 1 + 1.01, + 1.3, + 1.7, + 2.0, + ], +) def test_sample_2(alpha, beta): num_samples = 10000 @@ -96,7 +106,9 @@ def test_normal(loc, scale): @pytest.mark.parametrize("skew0", [-0.9, -0.5, 0.0, 0.5, 0.9]) @pytest.mark.parametrize("skew1", [-0.9, -0.5, 0.0, 0.5, 0.9]) -@pytest.mark.parametrize("scale0,scale1", [(0.1, 0.9), (0.2, 0.8), (0.4, 0.6), (0.5, 0.5)]) +@pytest.mark.parametrize( + "scale0,scale1", [(0.1, 0.9), (0.2, 0.8), (0.4, 0.6), (0.5, 0.5)] +) @pytest.mark.parametrize("stability", [0.5, 0.99, 1.01, 1.5, 1.9]) def test_additive(stability, skew0, skew1, scale0, scale1): num_samples = 10000 @@ -105,8 +117,9 @@ def test_additive(stability, skew0, skew1, scale0, scale1): expected = d0.sample([num_samples]) + d1.sample([num_samples]) scale = (scale0 ** stability + scale1 ** stability) ** (1 / stability) - skew = ((skew0 * scale0 ** stability + skew1 * scale1 ** stability) / - (scale0 ** stability + scale1 ** stability)) + skew = (skew0 * scale0 ** stability + skew1 * scale1 ** stability) / ( + scale0 ** stability + 
scale1 ** stability + ) d = dist.Stable(stability, skew, scale, coords="S") actual = d.sample([num_samples]) diff --git a/tests/distributions/test_tensor_type.py b/tests/distributions/test_tensor_type.py index 5f185d72fa..3955746812 100644 --- a/tests/distributions/test_tensor_type.py +++ b/tests/distributions/test_tensor_type.py @@ -52,7 +52,8 @@ def test_double_type(test_data, alpha, beta): log_px_np = sp.beta.logpdf( test_data.detach().cpu().numpy(), alpha.detach().cpu().numpy(), - beta.detach().cpu().numpy()) + beta.detach().cpu().numpy(), + ) assert_equal(log_px_val, log_px_np, prec=1e-4) @@ -63,11 +64,14 @@ def test_float_type(float_test_data, float_alpha, float_beta, test_data, alpha, log_px_np = sp.beta.logpdf( test_data.detach().cpu().numpy(), alpha.detach().cpu().numpy(), - beta.detach().cpu().numpy()) + beta.detach().cpu().numpy(), + ) assert_equal(log_px_val, log_px_np, prec=1e-4) -@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/43138#issuecomment-677804776") +@pytest.mark.xfail( + reason="https://github.com/pytorch/pytorch/issues/43138#issuecomment-677804776" +) def test_conflicting_types(test_data, float_alpha, beta): with pytest.raises((TypeError, RuntimeError)): dist.Beta(float_alpha, beta).log_prob(test_data) diff --git a/tests/distributions/test_torch_patch.py b/tests/distributions/test_torch_patch.py index 7675ac8fa4..c6c65ed93d 100644 --- a/tests/distributions/test_torch_patch.py +++ b/tests/distributions/test_torch_patch.py @@ -16,7 +16,7 @@ def test_dirichlet_grad_cuda(): @requires_cuda def test_linspace(): - x = torch.linspace(-1., 1., 100, device="cuda") + x = torch.linspace(-1.0, 1.0, 100, device="cuda") assert x.device.type == "cuda" diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 46c46028f3..8f249d4935 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -20,6 +20,7 @@ class Flatten(dist.TransformModule): """ Used to handle transforms with `event_dim > 1` until we have a Reshape transform in PyTorch """ + domain = constraints.real_vector codomain = constraints.real_vector @@ -28,7 +29,7 @@ def __init__(self, transform, input_shape): assert transform.domain.event_dim == len(input_shape) output_shape = transform.forward_shape(input_shape) assert len(output_shape) >= transform.codomain.event_dim - output_shape = output_shape[len(output_shape) - transform.codomain.event_dim:] + output_shape = output_shape[len(output_shape) - transform.codomain.event_dim :] self.transform = transform self.input_shape = input_shape @@ -37,13 +38,13 @@ def __init__(self, transform, input_shape): def _call(self, x): x = x.reshape(x.shape[:-1] + self.input_shape) y = self.transform._call(x) - y = y.reshape(y.shape[:y.dim() - len(self.output_shape)] + (-1,)) + y = y.reshape(y.shape[: y.dim() - len(self.output_shape)] + (-1,)) return y def _inverse(self, y): y = y.reshape(y.shape[:-1] + self.output_shape) x = self.transform._inverse(y) - x = x.reshape(x.shape[:x.dim() - len(self.input_shape)] + (-1,)) + x = x.reshape(x.shape[: x.dim() - len(self.input_shape)] + (-1,)) return x def log_abs_det_jacobian(self, x, y): @@ -80,11 +81,14 @@ def nonzero(x): for k in range(input_dim): epsilon_vector = torch.zeros(1, input_dim) epsilon_vector[0, j] = self.epsilon - delta = (transform(x + 0.5 * epsilon_vector) - transform(x - 0.5 * epsilon_vector)) / self.epsilon + delta = ( + transform(x + 0.5 * epsilon_vector) + - transform(x - 0.5 * epsilon_vector) + ) / self.epsilon jacobian[j, k] 
= float(delta[0, k].data.sum()) # Apply permutation for autoregressive flows with a network - if hasattr(transform, 'arn') and 'get_permutation' in dir(transform.arn): + if hasattr(transform, "arn") and "get_permutation" in dir(transform.arn): permutation = transform.arn.get_permutation() permuted_jacobian = jacobian.clone() for j in range(input_dim): @@ -93,7 +97,7 @@ def nonzero(x): jacobian = permuted_jacobian # For autoregressive flow, Jacobian is sum of diagonal, otherwise need full determinate - if hasattr(transform, 'autoregressive') and transform.autoregressive: + if hasattr(transform, "autoregressive") and transform.autoregressive: numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) else: numeric_ldt = torch.log(torch.abs(jacobian.det())) @@ -102,7 +106,7 @@ def nonzero(x): assert ldt_discrepancy < self.epsilon # Test that lower triangular with unit diagonal for autoregressive flows - if hasattr(transform, 'autoregressive'): + if hasattr(transform, "autoregressive"): diag_sum = torch.sum(torch.diag(nonzero(jacobian))) lower_sum = torch.sum(torch.tril(nonzero(jacobian), diagonal=-1)) assert diag_sum == float(input_dim) @@ -133,9 +137,11 @@ def _test_shape(self, base_shape, transform): sample = dist.TransformedDistribution(base_dist, [transform]).sample() assert sample.shape == base_shape - batch_shape = base_shape[:len(base_shape) - transform.domain.event_dim] - input_event_shape = base_shape[len(base_shape) - transform.domain.event_dim:] - output_event_shape = base_shape[len(base_shape) - transform.codomain.event_dim:] + batch_shape = base_shape[: len(base_shape) - transform.domain.event_dim] + input_event_shape = base_shape[len(base_shape) - transform.domain.event_dim :] + output_event_shape = base_shape[ + len(base_shape) - transform.codomain.event_dim : + ] output_shape = batch_shape + output_event_shape assert transform.forward_shape(input_event_shape) == output_event_shape assert transform.forward_shape(base_shape) == output_shape @@ -163,11 +169,21 @@ def _test_autodiff(self, input_dim, transform, inverse=False): loss.backward() optimizer.step() - def _test(self, transform_factory, shape=True, jacobian=True, inverse=True, autodiff=True, event_dim=1): + def _test( + self, + transform_factory, + shape=True, + jacobian=True, + inverse=True, + autodiff=True, + event_dim=1, + ): for event_shape in [(2,), (5,)]: if event_dim > 1: event_shape = tuple([event_shape[0] + i for i in range(event_dim)]) - transform = transform_factory(event_shape[0] if len(event_shape) == 1 else event_shape) + transform = transform_factory( + event_shape[0] if len(event_shape) == 1 else event_shape + ) if inverse: self._test_inverse(event_shape, transform) @@ -181,9 +197,13 @@ def _test(self, transform_factory, shape=True, jacobian=True, inverse=True, auto self._test_jacobian(reduce(operator.mul, event_shape, 1), transform) if autodiff: # If the function doesn't have an explicit inverse, then use the forward op for autodiff - self._test_autodiff(reduce(operator.mul, event_shape, 1), transform, inverse=not inverse) + self._test_autodiff( + reduce(operator.mul, event_shape, 1), transform, inverse=not inverse + ) - def _test_conditional(self, conditional_transform_factory, context_dim=3, event_dim=1, **kwargs): + def _test_conditional( + self, conditional_transform_factory, context_dim=3, event_dim=1, **kwargs + ): def transform_factory(input_dim, context_dim=context_dim): z = torch.rand(1, context_dim) cond_transform = conditional_transform_factory(input_dim, context_dim) @@ -193,6 +213,7 @@ def 
transform_factory(input_dim, context_dim=context_dim): transform.parameters = lambda: cond_transform.parameters() return transform + self._test(transform_factory, event_dim=event_dim, **kwargs) def test_affine_autoregressive(self): @@ -208,38 +229,53 @@ def test_batchnorm(self): # (see the docs about the differing behaviour of BatchNorm in train and eval modes) def transform_factory(input_dim): transform = T.batchnorm(input_dim) - transform._inverse(torch.normal(torch.arange(0., input_dim), torch.arange(1., 1. + input_dim) / input_dim)) + transform._inverse( + torch.normal( + torch.arange(0.0, input_dim), + torch.arange(1.0, 1.0 + input_dim) / input_dim, + ) + ) transform.eval() return transform self._test(transform_factory) def test_block_autoregressive_jacobians(self): - for activation in ['ELU', 'LeakyReLU', 'sigmoid', 'tanh']: - self._test(partial(T.block_autoregressive, activation=activation), inverse=False) + for activation in ["ELU", "LeakyReLU", "sigmoid", "tanh"]: + self._test( + partial(T.block_autoregressive, activation=activation), inverse=False + ) - for residual in [None, 'normal', 'gated']: - self._test(partial(T.block_autoregressive, residual=residual), inverse=False) + for residual in [None, "normal", "gated"]: + self._test( + partial(T.block_autoregressive, residual=residual), inverse=False + ) def test_conditional_affine_autoregressive(self): self._test_conditional(T.conditional_affine_autoregressive) def test_conditional_affine_coupling(self): for dim in [-1, -2]: - self._test_conditional(partial(T.conditional_affine_coupling, dim=dim), event_dim=-dim) + self._test_conditional( + partial(T.conditional_affine_coupling, dim=dim), event_dim=-dim + ) def test_conditional_generalized_channel_permute(self, context_dim=3): for shape in [(3, 16, 16), (1, 3, 32, 32), (2, 5, 3, 64, 64)]: # NOTE: Without changing the interface to generalized_channel_permute I can't reuse general # test for `event_dim > 1` transforms z = torch.rand(context_dim) - transform = T.conditional_generalized_channel_permute(context_dim=3, channels=shape[-3]).condition(z) + transform = T.conditional_generalized_channel_permute( + context_dim=3, channels=shape[-3] + ).condition(z) self._test_shape(shape, transform) self._test_inverse(shape, transform) for width_dim in [2, 4, 6]: - input_dim = (width_dim**2) * 3 - self._test_jacobian(input_dim, Flatten(transform, (3, width_dim, width_dim))) + input_dim = (width_dim ** 2) * 3 + self._test_jacobian( + input_dim, Flatten(transform, (3, width_dim, width_dim)) + ) def test_conditional_householder(self): self._test_conditional(T.conditional_householder) @@ -258,7 +294,7 @@ def test_conditional_radial(self): self._test_conditional(T.conditional_radial, inverse=False) def test_conditional_spline(self): - for order in ['linear', 'quadratic']: + for order in ["linear", "quadratic"]: self._test_conditional(partial(T.conditional_spline, order=order)) def test_conditional_spline_autoregressive(self): @@ -267,7 +303,10 @@ def test_conditional_spline_autoregressive(self): def test_discrete_cosine(self): # NOTE: Need following since helper function unimplemented for smooth in [0.0, 0.5, 1.0, 2.0]: - self._test(lambda input_dim: T.DiscreteCosineTransform(smooth=smooth), autodiff=False) + self._test( + lambda input_dim: T.DiscreteCosineTransform(smooth=smooth), + autodiff=False, + ) def test_haar_transform(self): # NOTE: Need following since helper function unimplemented @@ -287,8 +326,10 @@ def test_generalized_channel_permute(self): self._test_inverse(shape, transform) for 
width_dim in [2, 4, 6]: - input_dim = (width_dim**2) * 3 - self._test_jacobian(input_dim, Flatten(transform, (3, width_dim, width_dim))) + input_dim = (width_dim ** 2) * 3 + self._test_jacobian( + input_dim, Flatten(transform, (3, width_dim, width_dim)) + ) def test_householder(self): self._test(partial(T.householder, count_transforms=2)) @@ -301,7 +342,9 @@ def test_lower_cholesky_affine(self): # NOTE: Need following since helper function unimplemented def transform_factory(input_dim): loc = torch.randn(input_dim) - scale_tril = torch.randn(input_dim).exp().diag() + 0.03 * torch.randn(input_dim, input_dim) + scale_tril = torch.randn(input_dim).exp().diag() + 0.03 * torch.randn( + input_dim, input_dim + ) scale_tril = scale_tril.tril(0) return T.LowerCholeskyAffine(loc, scale_tril) @@ -311,8 +354,10 @@ def test_matrix_exponential(self): self._test(T.matrix_exponential) def test_neural_autoregressive(self): - for activation in ['ELU', 'LeakyReLU', 'sigmoid', 'tanh']: - self._test(partial(T.neural_autoregressive, activation=activation), inverse=False) + for activation in ["ELU", "LeakyReLU", "sigmoid", "tanh"]: + self._test( + partial(T.neural_autoregressive, activation=activation), inverse=False + ) def test_ordered_transform(self): # NOTE: Need following since transform takes no input parameters @@ -332,7 +377,7 @@ def test_radial(self): self._test(T.radial, inverse=False) def test_spline(self): - for order in ['linear', 'quadratic']: + for order in ["linear", "quadratic"]: self._test(partial(T.spline, order=order)) def test_spline_coupling(self): @@ -351,12 +396,16 @@ def test_softplus(self): self._test(lambda _: T.SoftplusTransform(), autodiff=False) -@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)]) -@pytest.mark.parametrize('dim', [2, 3, 5]) -@pytest.mark.parametrize('transform', [ - T.CholeskyTransform(), - T.CorrMatrixCholeskyTransform(), -], ids=lambda t: type(t).__name__) +@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 5)]) +@pytest.mark.parametrize("dim", [2, 3, 5]) +@pytest.mark.parametrize( + "transform", + [ + T.CholeskyTransform(), + T.CorrMatrixCholeskyTransform(), + ], + ids=lambda t: type(t).__name__, +) def test_cholesky_transform(batch_shape, dim, transform): arange = torch.arange(dim) domain = transform.domain @@ -370,7 +419,11 @@ def test_cholesky_transform(batch_shape, dim, transform): def vec_to_mat(x_vec): x_mat = x_vec.new_zeros(batch_shape + (dim, dim)) x_mat[..., tril_mask] = x_vec - x_mat = x_mat + x_mat.transpose(-2, -1) - x_mat.diagonal(dim1=-2, dim2=-1).diag_embed() + x_mat = ( + x_mat + + x_mat.transpose(-2, -1) + - x_mat.diagonal(dim1=-2, dim2=-1).diag_embed() + ) if domain == dist.constraints.corr_matrix: x_mat = x_mat + x_mat.new_ones(x_mat.shape[-1]).diag_embed() return x_mat @@ -392,12 +445,16 @@ def transform_to_vec(x_vec): assert_close(transform.inv(y), x_mat) -@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)]) -@pytest.mark.parametrize('dim', [2, 3, 5]) -@pytest.mark.parametrize('transform', [ - T.LowerCholeskyTransform(), - T.SoftplusLowerCholeskyTransform(), -], ids=lambda t: type(t).__name__) +@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 5)]) +@pytest.mark.parametrize("dim", [2, 3, 5]) +@pytest.mark.parametrize( + "transform", + [ + T.LowerCholeskyTransform(), + T.SoftplusLowerCholeskyTransform(), + ], + ids=lambda t: type(t).__name__, +) def test_lower_cholesky_transform(transform, batch_shape, dim): shape = batch_shape + (dim, dim) x = torch.randn(shape) diff --git a/tests/distributions/test_unit.py 
b/tests/distributions/test_unit.py index c236904db0..f2b45c5ac9 100644 --- a/tests/distributions/test_unit.py +++ b/tests/distributions/test_unit.py @@ -8,7 +8,7 @@ from tests.common import assert_equal -@pytest.mark.parametrize('batch_shape', [(), (4,), (3, 2)]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)]) def test_shapes(batch_shape): log_factor = torch.randn(batch_shape) @@ -18,8 +18,8 @@ def test_shapes(batch_shape): assert (d.log_prob(x) == log_factor).all() -@pytest.mark.parametrize('sample_shape', [(), (4,), (3, 2)]) -@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)]) +@pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)]) +@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 5)]) def test_expand(sample_shape, batch_shape): log_factor = torch.randn(batch_shape) d1 = dist.Unit(log_factor) diff --git a/tests/distributions/test_util.py b/tests/distributions/test_util.py index 2c8670934d..879e0121b4 100644 --- a/tests/distributions/test_util.py +++ b/tests/distributions/test_util.py @@ -17,72 +17,87 @@ ) from tests.common import assert_equal -INF = float('inf') - - -@pytest.mark.parametrize('shapes', [ - ([],), - ([1],), - ([2],), - ([], []), - ([], [1]), - ([], [2]), - ([1], []), - ([2], []), - ([1], [2]), - ([2], [1]), - ([2], [2]), - ([2], [3, 1]), - ([2, 1], [3]), - ([2, 1], [1, 3]), - ([1, 2, 4, 1, 3], [6, 7, 1, 1, 5, 1]), - ([], [3, 1], [2], [4, 3, 1], [5, 4, 1, 1]), -]) +INF = float("inf") + + +@pytest.mark.parametrize( + "shapes", + [ + ([],), + ([1],), + ([2],), + ([], []), + ([], [1]), + ([], [2]), + ([1], []), + ([2], []), + ([1], [2]), + ([2], [1]), + ([2], [2]), + ([2], [3, 1]), + ([2, 1], [3]), + ([2, 1], [1, 3]), + ([1, 2, 4, 1, 3], [6, 7, 1, 1, 5, 1]), + ([], [3, 1], [2], [4, 3, 1], [5, 4, 1, 1]), + ], +) def test_broadcast_shape(shapes): assert broadcast_shape(*shapes) == np.broadcast(*map(np.empty, shapes)).shape -@pytest.mark.parametrize('shapes', [ - ([3], [4]), - ([2, 1], [1, 3, 1]), -]) +@pytest.mark.parametrize( + "shapes", + [ + ([3], [4]), + ([2, 1], [1, 3, 1]), + ], +) def test_broadcast_shape_error(shapes): with pytest.raises((ValueError, RuntimeError)): broadcast_shape(*shapes) -@pytest.mark.parametrize('shapes', [ - ([],), - ([1],), - ([2],), - ([], []), - ([], [1]), - ([], [2]), - ([1], []), - ([2], []), - ([1], [1]), - ([2], [2]), - ([2], [2]), - ([2], [3, 2]), - ([2, 3], [3]), - ([2, 3], [2, 3]), - ([4], [1, 2, 3, 4], [2, 3, 4], [3, 4]), -]) +@pytest.mark.parametrize( + "shapes", + [ + ([],), + ([1],), + ([2],), + ([], []), + ([], [1]), + ([], [2]), + ([1], []), + ([2], []), + ([1], [1]), + ([2], [2]), + ([2], [2]), + ([2], [3, 2]), + ([2, 3], [3]), + ([2, 3], [2, 3]), + ([4], [1, 2, 3, 4], [2, 3, 4], [3, 4]), + ], +) def test_broadcast_shape_strict(shapes): - assert broadcast_shape(*shapes, strict=True) == np.broadcast(*map(np.empty, shapes)).shape - - -@pytest.mark.parametrize('shapes', [ - ([1], [2]), - ([2], [1]), - ([3], [4]), - ([2], [3, 1]), - ([2, 1], [3]), - ([2, 1], [1, 3]), - ([2, 1], [1, 3, 1]), - ([1, 2, 4, 1, 3], [6, 7, 1, 1, 5, 1]), - ([], [3, 1], [2], [4, 3, 1], [5, 4, 1, 1]), -]) + assert ( + broadcast_shape(*shapes, strict=True) + == np.broadcast(*map(np.empty, shapes)).shape + ) + + +@pytest.mark.parametrize( + "shapes", + [ + ([1], [2]), + ([2], [1]), + ([3], [4]), + ([2], [3, 1]), + ([2, 1], [3]), + ([2, 1], [1, 3]), + ([2, 1], [1, 3, 1]), + ([1, 2, 4, 1, 3], [6, 7, 1, 1, 5, 1]), + ([], [3, 1], [2], [4, 3, 1], [5, 4, 1, 1]), + ], +) def test_broadcast_shape_strict_error(shapes): with 
pytest.raises(ValueError): broadcast_shape(*shapes, strict=True) @@ -109,7 +124,6 @@ def test_sum_leftmost(): def test_weakmethod(): - class Foo: def __init__(self, state): self.state = state @@ -130,8 +144,8 @@ def _method(self, *args, **kwargs): @pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str) def test_detach_normal(shape): - loc = torch.tensor(0., requires_grad=True) - scale = torch.tensor(1., requires_grad=True) + loc = torch.tensor(0.0, requires_grad=True) + scale = torch.tensor(1.0, requires_grad=True) d1 = dist.Normal(loc, scale) if shape is not None: d1 = d1.expand(shape) @@ -163,12 +177,13 @@ def test_detach_beta(shape): @pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str) def test_detach_transformed(shape): - loc = torch.tensor(0., requires_grad=True) - scale = torch.tensor(1., requires_grad=True) - a = torch.tensor(2., requires_grad=True) - b = torch.tensor(3., requires_grad=True) - d1 = dist.TransformedDistribution(dist.Normal(loc, scale), - dist.transforms.AffineTransform(a, b)) + loc = torch.tensor(0.0, requires_grad=True) + scale = torch.tensor(1.0, requires_grad=True) + a = torch.tensor(2.0, requires_grad=True) + b = torch.tensor(3.0, requires_grad=True) + d1 = dist.TransformedDistribution( + dist.Normal(loc, scale), dist.transforms.AffineTransform(a, b) + ) if shape is not None: d1 = d1.expand(shape) @@ -190,8 +205,8 @@ def test_detach_transformed(shape): @pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str) def test_detach_jit(shape): - loc = torch.tensor(0., requires_grad=True) - scale = torch.tensor(1., requires_grad=True) + loc = torch.tensor(0.0, requires_grad=True) + scale = torch.tensor(1.0, requires_grad=True) data = torch.randn(5, 1, 1) def fn(loc, scale, data): diff --git a/tests/distributions/test_von_mises.py b/tests/distributions/test_von_mises.py index fd502684b8..e668237d72 100644 --- a/tests/distributions/test_von_mises.py +++ b/tests/distributions/test_von_mises.py @@ -21,12 +21,46 @@ def _eval_poly(y, coef): return result -_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2] -_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2, - -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2] -_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3] -_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1, - 0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2] +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] _COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] _COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] @@ -50,7 +84,7 @@ def _log_modified_bessel_fn(x, order=0): y = 3.75 / x large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log() - mask = (x < 3.75) + mask = x < 3.75 result = large if mask.any(): result[mask] = small[mask] @@ -77,8 +111,9 @@ def _fit_params_from_samples(samples, n_iter): def bfgs_closure(): bfgs.zero_grad() - obj = 
(_log_modified_bessel_fn(kappa, order=1) - - _log_modified_bessel_fn(kappa, order=0)) + obj = _log_modified_bessel_fn(kappa, order=1) - _log_modified_bessel_fn( + kappa, order=0 + ) obj = (obj - samples_r.log()).abs() obj.backward() return obj @@ -88,10 +123,23 @@ def bfgs_closure(): return mu, kappa.detach() -@pytest.mark.parametrize('loc', [-math.pi/2.0, 0.0, math.pi/2.0]) -@pytest.mark.parametrize('concentration', [skipif_param(0.01, condition='CUDA_TEST' in os.environ, - reason='low precision.'), - 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]) +@pytest.mark.parametrize("loc", [-math.pi / 2.0, 0.0, math.pi / 2.0]) +@pytest.mark.parametrize( + "concentration", + [ + skipif_param( + 0.01, condition="CUDA_TEST" in os.environ, reason="low precision." + ), + 0.03, + 0.1, + 0.3, + 1.0, + 3.0, + 10.0, + 30.0, + 100.0, + ], +) def test_sample(loc, concentration, n_samples=int(1e6), n_iter=50): prob = VonMises(loc, concentration) samples = prob.sample((n_samples,)) @@ -100,16 +148,18 @@ def test_sample(loc, concentration, n_samples=int(1e6), n_iter=50): assert abs(concentration - kappa) < concentration * 0.1 -@pytest.mark.parametrize('concentration', [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]) +@pytest.mark.parametrize( + "concentration", [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0] +) def test_log_prob_normalized(concentration): - grid = torch.arange(0., 2 * math.pi, 1e-4) + grid = torch.arange(0.0, 2 * math.pi, 1e-4) prob = VonMises(0.0, concentration).log_prob(grid).exp() norm = prob.mean().item() * 2 * math.pi assert abs(norm - 1) < 1e-3, norm -@pytest.mark.parametrize('loc', [-math.pi/2.0, 0.0, math.pi/2.0]) -@pytest.mark.parametrize('concentration', [0.03, 0.1, 0.3, 1., 3., 10., 30.]) +@pytest.mark.parametrize("loc", [-math.pi / 2.0, 0.0, math.pi / 2.0]) +@pytest.mark.parametrize("concentration", [0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0]) def test_von_mises_gof(loc, concentration): d = VonMises(loc, concentration) samples = d.sample(torch.Size([100000])) @@ -118,7 +168,7 @@ def test_von_mises_gof(loc, concentration): assert gof > TEST_FAILURE_RATE -@pytest.mark.parametrize('scale', [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) +@pytest.mark.parametrize("scale", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_von_mises_3d(scale): concentration = torch.randn(3) concentration = concentration * (scale / concentration.norm(2)) @@ -134,7 +184,7 @@ def test_von_mises_3d(scale): assert torch.abs(ratio - 1) < 0.01, ratio -@pytest.mark.parametrize('scale', [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) +@pytest.mark.parametrize("scale", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0]) def test_von_mises_3d_gof(scale): concentration = torch.randn(3) concentration = concentration * (scale / concentration.norm(2)) diff --git a/tests/distributions/test_zero_inflated.py b/tests/distributions/test_zero_inflated.py index 34be79a41e..f7872add6b 100644 --- a/tests/distributions/test_zero_inflated.py +++ b/tests/distributions/test_zero_inflated.py @@ -80,10 +80,14 @@ def test_zip_mean_variance(gate, rate): def test_zinb_0_gate(total_count, probs): # if gate is 0 ZINB is NegativeBinomial zinb1 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), gate=torch.zeros(1), probs=torch.tensor(probs) + total_count=torch.tensor(total_count), + gate=torch.zeros(1), + probs=torch.tensor(probs), ) zinb2 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), gate_logits=torch.tensor(-99.9), probs=torch.tensor(probs) + total_count=torch.tensor(total_count), + gate_logits=torch.tensor(-99.9), + 
probs=torch.tensor(probs), ) neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs)) s = neg_bin.sample((20,)) @@ -99,10 +103,14 @@ def test_zinb_0_gate(total_count, probs): def test_zinb_1_gate(total_count, probs): # if gate is 1 ZINB is Delta(0) zinb1 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), gate=torch.ones(1), probs=torch.tensor(probs) + total_count=torch.tensor(total_count), + gate=torch.ones(1), + probs=torch.tensor(probs), ) zinb2 = ZeroInflatedNegativeBinomial( - total_count=torch.tensor(total_count), gate_logits=torch.tensor(math.inf), probs=torch.tensor(probs) + total_count=torch.tensor(total_count), + gate_logits=torch.tensor(math.inf), + probs=torch.tensor(probs), ) delta = Delta(torch.zeros(1)) s = torch.tensor([0.0, 1.0]) diff --git a/tests/doctest_fixtures.py b/tests/doctest_fixtures.py index 0d4e785d84..30d65650d2 100644 --- a/tests/doctest_fixtures.py +++ b/tests/doctest_fixtures.py @@ -21,15 +21,15 @@ @pytest.fixture(autouse=True) def add_imports(doctest_namespace): - doctest_namespace['dist'] = dist - doctest_namespace['gp'] = gp - doctest_namespace['named'] = named - doctest_namespace['np'] = numpy - doctest_namespace['param_with_module_name'] = param_with_module_name - doctest_namespace['poutine'] = poutine - doctest_namespace['pyro'] = pyro - doctest_namespace['torch'] = torch - doctest_namespace['EmpiricalMarginal'] = EmpiricalMarginal - doctest_namespace['HMC'] = HMC - doctest_namespace['MCMC'] = MCMC - doctest_namespace['NUTS'] = NUTS + doctest_namespace["dist"] = dist + doctest_namespace["gp"] = gp + doctest_namespace["named"] = named + doctest_namespace["np"] = numpy + doctest_namespace["param_with_module_name"] = param_with_module_name + doctest_namespace["poutine"] = poutine + doctest_namespace["pyro"] = pyro + doctest_namespace["torch"] = torch + doctest_namespace["EmpiricalMarginal"] = EmpiricalMarginal + doctest_namespace["HMC"] = HMC + doctest_namespace["MCMC"] = MCMC + doctest_namespace["NUTS"] = NUTS diff --git a/tests/infer/mcmc/test_adaptation.py b/tests/infer/mcmc/test_adaptation.py index 2fad237d90..36cdb11aa3 100644 --- a/tests/infer/mcmc/test_adaptation.py +++ b/tests/infer/mcmc/test_adaptation.py @@ -13,18 +13,21 @@ from tests.common import assert_close, assert_equal -@pytest.mark.parametrize("adapt_step_size, adapt_mass, warmup_steps, expected", [ - (False, False, 100, []), - (False, True, 50, [(0, 6), (7, 44), (45, 49)]), - (True, False, 150, [(0, 74), (75, 99), (100, 149)]), - (True, True, 200, [(0, 74), (75, 99), (100, 149), (150, 199)]), - (True, True, 280, [(0, 74), (75, 99), (100, 229), (230, 279)]), - (True, True, 18, [(0, 17)]), -]) +@pytest.mark.parametrize( + "adapt_step_size, adapt_mass, warmup_steps, expected", + [ + (False, False, 100, []), + (False, True, 50, [(0, 6), (7, 44), (45, 49)]), + (True, False, 150, [(0, 74), (75, 99), (100, 149)]), + (True, True, 200, [(0, 74), (75, 99), (100, 149), (150, 199)]), + (True, True, 280, [(0, 74), (75, 99), (100, 229), (230, 279)]), + (True, True, 18, [(0, 17)]), + ], +) def test_adaptation_schedule(adapt_step_size, adapt_mass, warmup_steps, expected): - adapter = WarmupAdapter(0.1, - adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass) + adapter = WarmupAdapter( + 0.1, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass + ) adapter.configure(warmup_steps, mass_matrix_shape={"z": (5, 5)}) expected_schedule = [adapt_window(i, j) for i, j in expected] assert_equal(adapter.adaptation_schedule, expected_schedule, prec=0) 
@@ -46,8 +49,12 @@ def test_arrowhead_mass_matrix(diagonal): cov = torch.mm(cov, cov.t()) if diagonal: cov = cov.diag().diag() - z_dist = torch.distributions.MultivariateNormal(torch.zeros(size), covariance_matrix=cov) - g_dist = torch.distributions.MultivariateNormal(torch.zeros(size), precision_matrix=cov) + z_dist = torch.distributions.MultivariateNormal( + torch.zeros(size), covariance_matrix=cov + ) + g_dist = torch.distributions.MultivariateNormal( + torch.zeros(size), precision_matrix=cov + ) z_samples = z_dist.sample((num_samples,)).reshape((num_samples,) + shape) g_samples = g_dist.sample((num_samples,)).reshape((num_samples,) + shape) @@ -57,6 +64,9 @@ def test_arrowhead_mass_matrix(diagonal): block_adapter.end_adaptation() arrowhead_adapter.end_adaptation() - assert_close(arrowhead_adapter.inverse_mass_matrix[('z',)], - block_adapter.inverse_mass_matrix[('z',)], - atol=0.3, rtol=0.3) + assert_close( + arrowhead_adapter.inverse_mass_matrix[("z",)], + block_adapter.inverse_mass_matrix[("z",)], + atol=0.3, + rtol=0.3, + ) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 2cf655a42f..6dbc5eb236 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -21,8 +21,7 @@ def mark_jit(*args, **kwargs): jit_markers = kwargs.pop("marks", []) jit_markers += [ - pytest.mark.skipif('CI' in os.environ, - reason='to reduce running time on CI') + pytest.mark.skipif("CI" in os.environ, reason="to reduce running time on CI") ] kwargs["marks"] = jit_markers return pytest.param(*args, **kwargs) @@ -33,7 +32,6 @@ def jit_idfn(param): class GaussianChain: - def __init__(self, dim, chain_len, num_obs): self.dim = dim self.chain_len = chain_len @@ -45,39 +43,45 @@ def model(self, data): loc = self.loc_0 lambda_prec = self.lambda_prec for i in range(1, self.chain_len + 1): - loc = pyro.sample('loc_{}'.format(i), - dist.Normal(loc=loc, scale=lambda_prec)) - pyro.sample('obs', dist.Normal(loc, lambda_prec), obs=data) + loc = pyro.sample( + "loc_{}".format(i), dist.Normal(loc=loc, scale=lambda_prec) + ) + pyro.sample("obs", dist.Normal(loc, lambda_prec), obs=data) @property def data(self): return torch.ones(self.num_obs, self.dim) def id_fn(self): - return 'dim={}_chain-len={}_num_obs={}'.format(self.dim, self.chain_len, self.num_obs) + return "dim={}_chain-len={}_num_obs={}".format( + self.dim, self.chain_len, self.num_obs + ) def rmse(t1, t2): return (t1 - t2).pow(2).mean().sqrt() -T = namedtuple('TestExample', [ - 'fixture', - 'num_samples', - 'warmup_steps', - 'hmc_params', - 'expected_means', - 'expected_precs', - 'mean_tol', - 'std_tol']) +T = namedtuple( + "TestExample", + [ + "fixture", + "num_samples", + "warmup_steps", + "hmc_params", + "expected_means", + "expected_precs", + "mean_tol", + "std_tol", + ], +) TEST_CASES = [ T( GaussianChain(dim=10, chain_len=3, num_obs=1), num_samples=800, warmup_steps=200, - hmc_params={'step_size': 0.5, - 'num_steps': 4}, + hmc_params={"step_size": 0.5, "num_steps": 4}, expected_means=[0.25, 0.50, 0.75], expected_precs=[1.33, 1, 1.33], mean_tol=0.08, @@ -87,8 +91,7 @@ def rmse(t1, t2): GaussianChain(dim=10, chain_len=4, num_obs=1), num_samples=1600, warmup_steps=300, - hmc_params={'step_size': 0.46, - 'num_steps': 5}, + hmc_params={"step_size": 0.46, "num_steps": 5}, expected_means=[0.20, 0.40, 0.60, 0.80], expected_precs=[1.25, 0.83, 0.83, 1.25], mean_tol=0.08, @@ -98,7 +101,7 @@ def rmse(t1, t2): GaussianChain(dim=5, chain_len=2, num_obs=100), num_samples=2000, warmup_steps=1000, - hmc_params={'num_steps': 15, 
'step_size': 0.7}, + hmc_params={"num_steps": 15, "step_size": 0.7}, expected_means=[0.5, 1.0], expected_precs=[2.0, 100], mean_tol=0.08, @@ -108,38 +111,42 @@ def rmse(t1, t2): GaussianChain(dim=5, chain_len=9, num_obs=1), num_samples=3000, warmup_steps=500, - hmc_params={'step_size': 0.2, - 'num_steps': 15}, + hmc_params={"step_size": 0.2, "num_steps": 15}, expected_means=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11], mean_tol=0.11, std_tol=0.11, - ) + ), ] -TEST_IDS = [t[0].id_fn() if type(t).__name__ == 'TestExample' - else t[0][0].id_fn() for t in TEST_CASES] +TEST_IDS = [ + t[0].id_fn() if type(t).__name__ == "TestExample" else t[0][0].id_fn() + for t in TEST_CASES +] @pytest.mark.parametrize( - 'fixture, num_samples, warmup_steps, hmc_params, expected_means, expected_precs, mean_tol, std_tol', + "fixture, num_samples, warmup_steps, hmc_params, expected_means, expected_precs, mean_tol, std_tol", TEST_CASES, - ids=TEST_IDS) -@pytest.mark.skip(reason='Slow test (https://github.com/pytorch/pytorch/issues/12190)') + ids=TEST_IDS, +) +@pytest.mark.skip(reason="Slow test (https://github.com/pytorch/pytorch/issues/12190)") @pytest.mark.disable_validation() -def test_hmc_conjugate_gaussian(fixture, - num_samples, - warmup_steps, - hmc_params, - expected_means, - expected_precs, - mean_tol, - std_tol): +def test_hmc_conjugate_gaussian( + fixture, + num_samples, + warmup_steps, + hmc_params, + expected_means, + expected_precs, + mean_tol, + std_tol, +): pyro.get_param_store().clear() hmc_kernel = HMC(fixture.model, **hmc_params) samples = MCMC(hmc_kernel, num_samples, warmup_steps).run(fixture.data) for i in range(1, fixture.chain_len + 1): - param_name = 'loc_' + str(i) + param_name = "loc_" + str(i) marginal = samples[param_name] latent_loc = marginal.mean(0) latent_std = marginal.var(0).sqrt() @@ -147,16 +154,16 @@ def test_hmc_conjugate_gaussian(fixture, expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1]) # Actual vs expected posterior means for the latents - logger.debug('Posterior mean (actual) - {}'.format(param_name)) + logger.debug("Posterior mean (actual) - {}".format(param_name)) logger.debug(latent_loc) - logger.debug('Posterior mean (expected) - {}'.format(param_name)) + logger.debug("Posterior mean (expected) - {}".format(param_name)) logger.debug(expected_mean) assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol) # Actual vs expected posterior precisions for the latents - logger.debug('Posterior std (actual) - {}'.format(param_name)) + logger.debug("Posterior std (actual) - {}".format(param_name)) logger.debug(latent_std) - logger.debug('Posterior std (expected) - {}'.format(param_name)) + logger.debug("Posterior std (expected) - {}".format(param_name)) logger.debug(expected_std) assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) @@ -169,27 +176,39 @@ def test_hmc_conjugate_gaussian(fixture, (None, 1, None, True, False, False), (None, 1, None, True, True, False), (None, 1, None, True, True, True), - ] + ], ) -def test_logistic_regression(step_size, trajectory_length, num_steps, - adapt_step_size, adapt_mass_matrix, full_mass): +def test_logistic_regression( + step_size, + trajectory_length, + num_steps, + adapt_step_size, + adapt_mass_matrix, + full_mass, +): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1., dim + 1.) 
+ true_coefs = torch.arange(1.0, dim + 1.0) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): - coefs_mean = pyro.param('coefs_mean', torch.zeros(dim)) - coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(dim))) - y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) + coefs_mean = pyro.param("coefs_mean", torch.zeros(dim)) + coefs = pyro.sample("beta", dist.Normal(coefs_mean, torch.ones(dim))) + y = pyro.sample("y", dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, - num_steps=num_steps, adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, full_mass=full_mass) + hmc_kernel = HMC( + model, + step_size=step_size, + trajectory_length=trajectory_length, + num_steps=num_steps, + adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, + full_mass=full_mass, + ) mcmc = MCMC(hmc_kernel, num_samples=500, warmup_steps=100, disable_progbar=True) mcmc.run(data) - samples = mcmc.get_samples()['beta'] + samples = mcmc.get_samples()["beta"] assert_equal(rmse(true_coefs, samples.mean(0)).item(), 0.0, prec=0.1) @@ -197,17 +216,19 @@ def model(data): def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) - p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) + p_latent = pyro.sample("p_latent", dist.Dirichlet(concentration)) pyro.sample("obs", dist.Categorical(p_latent), obs=data) return p_latent true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, trajectory_length=1, jit_compile=jit, ignore_jit_warnings=True) + hmc_kernel = HMC( + model, trajectory_length=1, jit_compile=jit, ignore_jit_warnings=True + ) mcmc = MCMC(hmc_kernel, num_samples=200, warmup_steps=100) mcmc.run(data) samples = mcmc.get_samples() - assert_equal(samples['p_latent'].mean(0), true_probs, prec=0.02) + assert_equal(samples["p_latent"].mean(0), true_probs, prec=0.02) @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @@ -215,26 +236,31 @@ def test_beta_bernoulli(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) - p_latent = pyro.sample('p_latent', dist.Beta(alpha, beta)) + p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta)) with pyro.plate("data", data.shape[0], dim=-2): - pyro.sample('obs', dist.Bernoulli(p_latent), obs=data) + pyro.sample("obs", dist.Bernoulli(p_latent), obs=data) return p_latent true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, trajectory_length=1, max_plate_nesting=2, - jit_compile=jit, ignore_jit_warnings=True) + hmc_kernel = HMC( + model, + trajectory_length=1, + max_plate_nesting=2, + jit_compile=jit, + ignore_jit_warnings=True, + ) mcmc = MCMC(hmc_kernel, num_samples=800, warmup_steps=500) mcmc.run(data) samples = mcmc.get_samples() - assert_equal(samples['p_latent'].mean(0), true_probs, prec=0.05) + assert_equal(samples["p_latent"].mean(0), true_probs, prec=0.05) def test_gamma_normal(): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) - p_latent = pyro.sample('p_latent', dist.Gamma(rate, concentration)) + p_latent = pyro.sample("p_latent", dist.Gamma(rate, concentration)) pyro.sample("obs", dist.Normal(3, p_latent), obs=data) return p_latent @@ -244,7 +270,7 
@@ def model(data): mcmc = MCMC(hmc_kernel, num_samples=200, warmup_steps=200) mcmc.run(data) samples = mcmc.get_samples() - assert_equal(samples['p_latent'].mean(0), true_std, prec=0.05) + assert_equal(samples["p_latent"].mean(0), true_std, prec=0.05) @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @@ -254,33 +280,42 @@ def model(data): y = pyro.sample("y", dist.Bernoulli(y_prob)) with pyro.plate("data", data.shape[0]): z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1)) - pyro.sample("obs", dist.Normal(2. * z, 1.), obs=data) + pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data) pyro.sample("nuisance", dist.Bernoulli(0.3)) N = 2000 y_prob = torch.tensor(0.3) y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * y + 0.1).sample() - data = dist.Normal(2. * z, 1.0).sample() - hmc_kernel = HMC(model, trajectory_length=1, max_plate_nesting=1, - jit_compile=jit, ignore_jit_warnings=True) + data = dist.Normal(2.0 * z, 1.0).sample() + hmc_kernel = HMC( + model, + trajectory_length=1, + max_plate_nesting=1, + jit_compile=jit, + ignore_jit_warnings=True, + ) mcmc = MCMC(hmc_kernel, num_samples=600, warmup_steps=200) mcmc.run(data) samples = mcmc.get_samples() - assert_equal(samples['y_prob'].mean(0), y_prob, prec=0.06) + assert_equal(samples["y_prob"].mean(0), y_prob, prec=0.06) @pytest.mark.parametrize("kernel", [HMC, NUTS]) @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.skipif("CUDA_TEST" in os.environ, reason="https://github.com/pytorch/pytorch/issues/22811") +@pytest.mark.skipif( + "CUDA_TEST" in os.environ, reason="https://github.com/pytorch/pytorch/issues/22811" +) def test_unnormalized_normal(kernel, jit): - true_mean, true_std = torch.tensor(5.), torch.tensor(1.) 
- init_params = {"z": torch.tensor(0.)} + true_mean, true_std = torch.tensor(5.0), torch.tensor(1.0) + init_params = {"z": torch.tensor(0.0)} def potential_energy(params): return 0.5 * torch.sum(((params["z"] - true_mean) / true_std) ** 2) - potential_fn = potential_energy if not jit else torch.jit.trace(potential_energy, init_params) + potential_fn = ( + potential_energy if not jit else torch.jit.trace(potential_energy, init_params) + ) hmc_kernel = kernel(model=None, potential_fn=potential_fn) samples = init_params @@ -302,20 +337,25 @@ def potential_energy(params): assert_close(torch.std(posterior), true_std, rtol=0.05) -@pytest.mark.parametrize('jit', [False, mark_jit(True)], ids=jit_idfn) -@pytest.mark.parametrize('op', [torch.inverse, torch.linalg.cholesky]) +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("op", [torch.inverse, torch.linalg.cholesky]) def test_singular_matrix_catch(jit, op): def potential_energy(z): - return op(z['cov']).sum() + return op(z["cov"]).sum() - init_params = {'cov': torch.eye(3)} - potential_fn = potential_energy if not jit else torch.jit.trace(potential_energy, init_params) - hmc_kernel = HMC(potential_fn=potential_fn, adapt_step_size=False, - num_steps=10, step_size=1e-20) + init_params = {"cov": torch.eye(3)} + potential_fn = ( + potential_energy if not jit else torch.jit.trace(potential_energy, init_params) + ) + hmc_kernel = HMC( + potential_fn=potential_fn, adapt_step_size=False, num_steps=10, step_size=1e-20 + ) hmc_kernel.initial_params = init_params hmc_kernel.setup(warmup_steps=0) # setup an invalid cache to trigger singular error for torch.inverse - hmc_kernel._cache({'cov': torch.ones(3, 3)}, torch.tensor(0.), {'cov': torch.zeros(3, 3)}) + hmc_kernel._cache( + {"cov": torch.ones(3, 3)}, torch.tensor(0.0), {"cov": torch.zeros(3, 3)} + ) samples = init_params for i in range(10): diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py index 13dc2e070b..16947867f5 100644 --- a/tests/infer/mcmc/test_mcmc_api.py +++ b/tests/infer/mcmc/test_mcmc_api.py @@ -24,6 +24,7 @@ class PriorKernel(MCMCKernel): Disregards the value of the current trace (or observed data) and samples a value from the model's prior. 
""" + def __init__(self, model): self.model = model self.data = None @@ -33,8 +34,9 @@ def __init__(self, model): def setup(self, warmup_steps, data): self.data = data - init_params, potential_fn, transforms, model_trace = initialize_model(self.model, - model_args=(data,)) + init_params, potential_fn, transforms, model_trace = initialize_model( + self.model, model_args=(data,) + ) if self._initial_params is None: self._initial_params = init_params if self.transforms is None: @@ -42,7 +44,7 @@ def setup(self, warmup_steps, data): self._prototype_trace = model_trace def diagnostics(self): - return {'dummy_key': 'dummy_value'} + return {"dummy_key": "dummy_value"} @property def initial_params(self): @@ -69,26 +71,61 @@ def sample(self, params): def normal_normal_model(data): x = torch.tensor([0.0]) - y = pyro.sample('y', dist.Normal(x, torch.ones(data.shape))) - pyro.sample('obs', dist.Normal(y, torch.tensor([1.0])), obs=data) + y = pyro.sample("y", dist.Normal(x, torch.ones(data.shape))) + pyro.sample("obs", dist.Normal(y, torch.tensor([1.0])), obs=data) return y -def run_default_mcmc(data, kernel, num_samples, warmup_steps=None, initial_params=None, - num_chains=1, hook_fn=None, mp_context=None, transforms=None, num_draws=None, - group_by_chain=False): - mcmc = MCMC(kernel=kernel, num_samples=num_samples, warmup_steps=warmup_steps, initial_params=initial_params, - num_chains=num_chains, hook_fn=hook_fn, mp_context=mp_context, transforms=transforms) +def run_default_mcmc( + data, + kernel, + num_samples, + warmup_steps=None, + initial_params=None, + num_chains=1, + hook_fn=None, + mp_context=None, + transforms=None, + num_draws=None, + group_by_chain=False, +): + mcmc = MCMC( + kernel=kernel, + num_samples=num_samples, + warmup_steps=warmup_steps, + initial_params=initial_params, + num_chains=num_chains, + hook_fn=hook_fn, + mp_context=mp_context, + transforms=transforms, + ) mcmc.run(data) return mcmc.get_samples(num_draws, group_by_chain=group_by_chain), mcmc.num_chains -def run_streaming_mcmc(data, kernel, num_samples, warmup_steps=None, initial_params=None, - num_chains=1, hook_fn=None, mp_context=None, transforms=None, num_draws=None, - group_by_chain=False): - mcmc = StreamingMCMC(kernel=kernel, num_samples=num_samples, warmup_steps=warmup_steps, - initial_params=initial_params, statistics=StatsOfDict(default=StackStats), - num_chains=num_chains, hook_fn=hook_fn, transforms=transforms) +def run_streaming_mcmc( + data, + kernel, + num_samples, + warmup_steps=None, + initial_params=None, + num_chains=1, + hook_fn=None, + mp_context=None, + transforms=None, + num_draws=None, + group_by_chain=False, +): + mcmc = StreamingMCMC( + kernel=kernel, + num_samples=num_samples, + warmup_steps=warmup_steps, + initial_params=initial_params, + statistics=StatsOfDict(default=StackStats), + num_chains=num_chains, + hook_fn=hook_fn, + transforms=transforms, + ) mcmc.run(data) statistics = mcmc.get_statistics(group_by_chain=group_by_chain) @@ -97,13 +134,13 @@ def run_streaming_mcmc(data, kernel, num_samples, warmup_steps=None, initial_par agg = {} for (_, name), stat in statistics.items(): if name in agg: - agg[name].append(stat['samples']) + agg[name].append(stat["samples"]) else: - agg[name] = [stat['samples']] + agg[name] = [stat["samples"]] for name, l in agg.items(): samples[name] = torch.stack(l) else: - samples = {name: stat['samples'] for name, stat in statistics.items()} + samples = {name: stat["samples"] for name, stat in statistics.items()} samples = select_samples(samples, num_draws, 
group_by_chain) @@ -114,19 +151,29 @@ def run_streaming_mcmc(data, kernel, num_samples, warmup_steps=None, initial_par @pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc]) -@pytest.mark.parametrize('num_draws', [None, 1800, 2200]) -@pytest.mark.parametrize('group_by_chain', [False, True]) -@pytest.mark.parametrize('num_chains', [1, 2]) +@pytest.mark.parametrize("num_draws", [None, 1800, 2200]) +@pytest.mark.parametrize("group_by_chain", [False, True]) +@pytest.mark.parametrize("num_chains", [1, 2]) @pytest.mark.filterwarnings("ignore:num_chains") def test_mcmc_interface(run_mcmc_cls, num_draws, group_by_chain, num_chains): num_samples = 2000 data = torch.tensor([1.0]) - initial_params, _, transforms, _ = initialize_model(normal_normal_model, model_args=(data,), - num_chains=num_chains) + initial_params, _, transforms, _ = initialize_model( + normal_normal_model, model_args=(data,), num_chains=num_chains + ) kernel = PriorKernel(normal_normal_model) - samples, mcmc_num_chains = run_mcmc_cls(data, kernel, num_samples=num_samples, warmup_steps=100, - initial_params=initial_params, num_chains=num_chains, mp_context='spawn', - transforms=transforms, num_draws=num_draws, group_by_chain=group_by_chain) + samples, mcmc_num_chains = run_mcmc_cls( + data, + kernel, + num_samples=num_samples, + warmup_steps=100, + initial_params=initial_params, + num_chains=num_chains, + mp_context="spawn", + transforms=transforms, + num_draws=num_draws, + group_by_chain=group_by_chain, + ) # test sample shape expected_samples = num_draws if num_draws is not None else num_samples if group_by_chain: @@ -136,39 +183,48 @@ def test_mcmc_interface(run_mcmc_cls, num_draws, group_by_chain, num_chains): expected_shape = (expected_samples, 1) else: expected_shape = (mcmc_num_chains * expected_samples, 1) - assert samples['y'].shape == expected_shape + assert samples["y"].shape == expected_shape # test sample stats if group_by_chain: samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()} - sample_mean = samples['y'].mean() - sample_std = samples['y'].std() + sample_mean = samples["y"].mean() + sample_std = samples["y"].std() assert_close(sample_mean, torch.tensor(0.0), atol=0.1) assert_close(sample_std, torch.tensor(1.0), atol=0.1) -@pytest.mark.parametrize("num_chains, cpu_count", [ - (1, 2), - (2, 1), - (2, 2), - (2, 3), -]) +@pytest.mark.parametrize( + "num_chains, cpu_count", + [ + (1, 2), + (2, 1), + (2, 2), + (2, 3), + ], +) @pytest.mark.parametrize("default_init_params", [True, False]) -def test_num_chains(num_chains, cpu_count, default_init_params, - monkeypatch): - monkeypatch.setattr(torch.multiprocessing, 'cpu_count', lambda: cpu_count) +def test_num_chains(num_chains, cpu_count, default_init_params, monkeypatch): + monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: cpu_count) data = torch.tensor([1.0]) - initial_params, _, transforms, _ = initialize_model(normal_normal_model, - model_args=(data,), - num_chains=num_chains) + initial_params, _, transforms, _ = initialize_model( + normal_normal_model, model_args=(data,), num_chains=num_chains + ) if default_init_params: initial_params = None kernel = PriorKernel(normal_normal_model) - available_cpu = max(1, cpu_count-1) + available_cpu = max(1, cpu_count - 1) mp_context = "spawn" with optional(pytest.warns(UserWarning), available_cpu < num_chains): - mcmc = MCMC(kernel, num_samples=10, warmup_steps=10, num_chains=num_chains, - initial_params=initial_params, transforms=transforms, mp_context=mp_context) + 
mcmc = MCMC( + kernel, + num_samples=10, + warmup_steps=10, + num_chains=num_chains, + initial_params=initial_params, + transforms=transforms, + mp_context=mp_context, + ) mcmc.run(data) assert mcmc.num_chains == num_chains if mcmc.num_chains == 1 or available_cpu < num_chains: @@ -186,24 +242,22 @@ def _hook(iters, kernel, samples, stage, i): iters.append((stage, i)) -@pytest.mark.parametrize("run_mcmc_cls", [ - run_default_mcmc, - run_streaming_mcmc -]) -@pytest.mark.parametrize("kernel, model", [ - (HMC, _empty_model), - (NUTS, _empty_model), -]) +@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc]) +@pytest.mark.parametrize( + "kernel, model", + [ + (HMC, _empty_model), + (NUTS, _empty_model), + ], +) @pytest.mark.parametrize("jit", [False, True]) -@pytest.mark.parametrize("num_chains", [ - 1, - 2 -]) +@pytest.mark.parametrize("num_chains", [1, 2]) @pytest.mark.filterwarnings("ignore:num_chains") def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chains): num_warmup, num_samples = 10, 10 - initial_params, potential_fn, transforms, _ = initialize_model(model, - num_chains=num_chains) + initial_params, potential_fn, transforms, _ = initialize_model( + model, num_chains=num_chains + ) iters = [] hook = partial(_hook, iters) @@ -211,35 +265,52 @@ def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chains): mp_context = "spawn" if "CUDA_TEST" in os.environ else None kern = kernel(potential_fn=potential_fn, transforms=transforms, jit_compile=jit) - samples, _ = run_mcmc_cls(data=None, kernel=kern, num_samples=num_samples, warmup_steps=num_warmup, - initial_params=initial_params, hook_fn=hook, num_chains=num_chains, mp_context=mp_context) + samples, _ = run_mcmc_cls( + data=None, + kernel=kern, + num_samples=num_samples, + warmup_steps=num_warmup, + initial_params=initial_params, + hook_fn=hook, + num_chains=num_chains, + mp_context=mp_context, + ) assert samples == {} if num_chains == 1: - expected = [("Warmup", i) for i in range(num_warmup)] + [("Sample", i) for i in range(num_samples)] + expected = [("Warmup", i) for i in range(num_warmup)] + [ + ("Sample", i) for i in range(num_samples) + ] assert iters == expected -@pytest.mark.parametrize("run_mcmc_cls", [ - run_default_mcmc, - run_streaming_mcmc -]) -@pytest.mark.parametrize("num_chains", [ - 1, - 2 -]) +@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc]) +@pytest.mark.parametrize("num_chains", [1, 2]) @pytest.mark.filterwarnings("ignore:num_chains") def test_mcmc_diagnostics(run_mcmc_cls, num_chains): data = torch.tensor([2.0]).repeat(3) - initial_params, _, transforms, _ = initialize_model(normal_normal_model, - model_args=(data,), - num_chains=num_chains) + initial_params, _, transforms, _ = initialize_model( + normal_normal_model, model_args=(data,), num_chains=num_chains + ) kernel = PriorKernel(normal_normal_model) if run_mcmc_cls == run_default_mcmc: - mcmc = MCMC(kernel, num_samples=10, warmup_steps=10, num_chains=num_chains, mp_context="spawn", - initial_params=initial_params, transforms=transforms) + mcmc = MCMC( + kernel, + num_samples=10, + warmup_steps=10, + num_chains=num_chains, + mp_context="spawn", + initial_params=initial_params, + transforms=transforms, + ) else: - mcmc = StreamingMCMC(kernel, num_samples=10, warmup_steps=10, num_chains=num_chains, - initial_params=initial_params, transforms=transforms) + mcmc = StreamingMCMC( + kernel, + num_samples=10, + warmup_steps=10, + num_chains=num_chains, + 
initial_params=initial_params, + transforms=transforms, + ) mcmc.run(data) if not torch.backends.mkl.is_available(): pytest.skip() @@ -247,34 +318,49 @@ def test_mcmc_diagnostics(run_mcmc_cls, num_chains): if run_mcmc_cls == run_default_mcmc: # TODO n_eff for streaming MCMC assert diagnostics["y"]["n_eff"].shape == data.shape assert diagnostics["y"]["r_hat"].shape == data.shape - assert diagnostics["dummy_key"] == {'chain {}'.format(i): 'dummy_value' - for i in range(num_chains)} + assert diagnostics["dummy_key"] == { + "chain {}".format(i): "dummy_value" for i in range(num_chains) + } @pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc]) @pytest.mark.filterwarnings("ignore:num_chains") def test_sequential_consistent(run_mcmc_cls, monkeypatch): # test if there is no stuff left from the previous chain - monkeypatch.setattr(torch.multiprocessing, 'cpu_count', lambda: 1) + monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: 1) class FirstKernel(NUTS): def setup(self, warmup_steps, *args, **kwargs): - self._chain_id = 0 if '_chain_id' not in self.__dict__ else 1 + self._chain_id = 0 if "_chain_id" not in self.__dict__ else 1 pyro.set_rng_seed(self._chain_id) super().setup(warmup_steps, *args, **kwargs) class SecondKernel(NUTS): def setup(self, warmup_steps, *args, **kwargs): - self._chain_id = 1 if '_chain_id' not in self.__dict__ else 0 + self._chain_id = 1 if "_chain_id" not in self.__dict__ else 0 pyro.set_rng_seed(self._chain_id) super().setup(warmup_steps, *args, **kwargs) data = torch.tensor([1.0]) kernel = FirstKernel(normal_normal_model) - samples1, _ = run_mcmc_cls(data, kernel, num_samples=100, warmup_steps=100, num_chains=2, group_by_chain=True) + samples1, _ = run_mcmc_cls( + data, + kernel, + num_samples=100, + warmup_steps=100, + num_chains=2, + group_by_chain=True, + ) kernel = SecondKernel(normal_normal_model) - samples2, _ = run_mcmc_cls(data, kernel, num_samples=100, warmup_steps=100, num_chains=2, group_by_chain=True) + samples2, _ = run_mcmc_cls( + data, + kernel, + num_samples=100, + warmup_steps=100, + num_chains=2, + group_by_chain=True, + ) assert_close(samples1["y"][0], samples2["y"][1]) assert_close(samples1["y"][1], samples2["y"][0]) @@ -282,20 +368,28 @@ def setup(self, warmup_steps, *args, **kwargs): @pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc]) def test_model_with_potential_fn(run_mcmc_cls): - init_params = {"z": torch.tensor(0.)} + init_params = {"z": torch.tensor(0.0)} def potential_fn(params): return params["z"] - run_mcmc_cls(data=None, kernel=HMC(potential_fn=potential_fn), num_samples=10, - warmup_steps=10, initial_params=init_params) + run_mcmc_cls( + data=None, + kernel=HMC(potential_fn=potential_fn), + num_samples=10, + warmup_steps=10, + initial_params=init_params, + ) @pytest.mark.parametrize("save_params", ["xy", "x", "y", "xy"]) -@pytest.mark.parametrize("Kernel,options", [ - (HMC, {}), - (NUTS, {"max_tree_depth": 2}), -]) +@pytest.mark.parametrize( + "Kernel,options", + [ + (HMC, {}), + (NUTS, {"max_tree_depth": 2}), + ], +) def test_save_params(save_params, Kernel, options): save_params = list(save_params) diff --git a/tests/infer/mcmc/test_mcmc_util.py b/tests/infer/mcmc/test_mcmc_util.py index b0e34d6d50..afed008993 100644 --- a/tests/infer/mcmc/test_mcmc_util.py +++ b/tests/infer/mcmc/test_mcmc_util.py @@ -32,7 +32,7 @@ def beta_bernoulli(): def model(data=None): with pyro.plate("num_components", 5): - beta = pyro.sample("beta", dist.Beta(1., 1.)) + beta = 
pyro.sample("beta", dist.Beta(1.0, 1.0)) with pyro.plate("data", N): pyro.sample("obs", dist.Bernoulli(beta), obs=data) @@ -43,20 +43,21 @@ def model(data=None): @pytest.mark.parametrize("parallel", [False, True]) def test_predictive(num_samples, parallel): model, data, true_probs = beta_bernoulli() - init_params, potential_fn, transforms, _ = initialize_model(model, - model_args=(data,)) + init_params, potential_fn, transforms, _ = initialize_model( + model, model_args=(data,) + ) nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms) - mcmc = MCMC(nuts_kernel, - 100, - initial_params=init_params, - warmup_steps=100) + mcmc = MCMC(nuts_kernel, 100, initial_params=init_params, warmup_steps=100) mcmc.run(data) samples = mcmc.get_samples() with optional(pytest.warns(UserWarning), num_samples not in (None, 100)): - predictive = Predictive(model, samples, - num_samples=num_samples, - return_sites=["beta", "obs"], - parallel=parallel) + predictive = Predictive( + model, + samples, + num_samples=num_samples, + return_sites=["beta", "obs"], + parallel=parallel, + ) predictive_samples = predictive() # check shapes @@ -64,11 +65,13 @@ def test_predictive(num_samples, parallel): assert predictive_samples["obs"].shape == (100, 1000, 5) # check sample mean - assert_close(predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1) + assert_close( + predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1 + ) def model_with_param(): - x = pyro.param("x", torch.tensor(1.)) + x = pyro.param("x", torch.tensor(1.0)) pyro.sample("y", dist.Normal(x, 1)) @@ -105,25 +108,28 @@ def model(): value = torch.randn(()).exp() * 10 kernel = NUTS(model, init_strategy=partial(init_to_value, values={"x": value})) kernel.setup(warmup_steps=10) - assert_close(value, kernel.initial_params['x'].exp()) - - -@pytest.mark.parametrize("init_strategy", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, - init_to_uniform, - init_to_value, - init_to_feasible(), - init_to_mean(), - init_to_median(num_samples=4), - init_to_sample(), - init_to_uniform(radius=0.1), - init_to_value(values={"x": torch.tensor(3.)}), - init_to_generated( - generate=lambda: init_to_value(values={"x": torch.rand(())})), -], ids=str_erase_pointers) + assert_close(value, kernel.initial_params["x"].exp()) + + +@pytest.mark.parametrize( + "init_strategy", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + init_to_uniform, + init_to_value, + init_to_feasible(), + init_to_mean(), + init_to_median(num_samples=4), + init_to_sample(), + init_to_uniform(radius=0.1), + init_to_value(values={"x": torch.tensor(3.0)}), + init_to_generated(generate=lambda: init_to_value(values={"x": torch.rand(())})), + ], + ids=str_erase_pointers, +) def test_init_strategy_smoke(init_strategy): def model(): pyro.sample("x", dist.LogNormal(0, 1)) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index caa576591c..646f8e3b72 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -29,14 +29,18 @@ logger = logging.getLogger(__name__) -T = namedtuple('TestExample', [ - 'fixture', - 'num_samples', - 'warmup_steps', - 'expected_means', - 'expected_precs', - 'mean_tol', - 'std_tol']) +T = namedtuple( + "TestExample", + [ + "fixture", + "num_samples", + "warmup_steps", + "expected_means", + "expected_precs", + "mean_tol", + "std_tol", + ], +) TEST_CASES = [ T( @@ -74,19 +78,20 @@ expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11], 
mean_tol=0.08, std_tol=0.08, - ) + ), ] -TEST_IDS = [t[0].id_fn() if type(t).__name__ == 'TestExample' - else t[0][0].id_fn() for t in TEST_CASES] +TEST_IDS = [ + t[0].id_fn() if type(t).__name__ == "TestExample" else t[0][0].id_fn() + for t in TEST_CASES +] def mark_jit(*args, **kwargs): jit_markers = kwargs.pop("marks", []) jit_markers += [ - pytest.mark.skipif('CI' in os.environ, - reason='to reduce running time on CI') + pytest.mark.skipif("CI" in os.environ, reason="to reduce running time on CI") ] kwargs["marks"] = jit_markers return pytest.param(*args, **kwargs) @@ -97,25 +102,28 @@ def jit_idfn(param): @pytest.mark.parametrize( - 'fixture, num_samples, warmup_steps, expected_means, expected_precs, mean_tol, std_tol', + "fixture, num_samples, warmup_steps, expected_means, expected_precs, mean_tol, std_tol", TEST_CASES, - ids=TEST_IDS) -@pytest.mark.skip(reason='Slow test (https://github.com/pytorch/pytorch/issues/12190)') + ids=TEST_IDS, +) +@pytest.mark.skip(reason="Slow test (https://github.com/pytorch/pytorch/issues/12190)") @pytest.mark.disable_validation() -def test_nuts_conjugate_gaussian(fixture, - num_samples, - warmup_steps, - expected_means, - expected_precs, - mean_tol, - std_tol): +def test_nuts_conjugate_gaussian( + fixture, + num_samples, + warmup_steps, + expected_means, + expected_precs, + mean_tol, + std_tol, +): pyro.get_param_store().clear() nuts_kernel = NUTS(fixture.model) mcmc = MCMC(nuts_kernel, num_samples, warmup_steps) mcmc.run(fixture.data) samples = mcmc.get_samples() for i in range(1, fixture.chain_len + 1): - param_name = 'loc_' + str(i) + param_name = "loc_" + str(i) latent = samples[param_name] latent_loc = latent.mean(0) latent_std = latent.std(0) @@ -123,16 +131,16 @@ def test_nuts_conjugate_gaussian(fixture, expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1]) # Actual vs expected posterior means for the latents - logger.debug('Posterior mean (actual) - {}'.format(param_name)) + logger.debug("Posterior mean (actual) - {}".format(param_name)) logger.debug(latent_loc) - logger.debug('Posterior mean (expected) - {}'.format(param_name)) + logger.debug("Posterior mean (expected) - {}".format(param_name)) logger.debug(expected_mean) assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol) # Actual vs expected posterior precisions for the latents - logger.debug('Posterior std (actual) - {}'.format(param_name)) + logger.debug("Posterior std (actual) - {}".format(param_name)) logger.debug(latent_std) - logger.debug('Posterior std (expected) - {}'.format(param_name)) + logger.debug("Posterior std (expected) - {}".format(param_name)) logger.debug(expected_std) assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) @@ -142,19 +150,21 @@ def test_nuts_conjugate_gaussian(fixture, def test_logistic_regression(jit, use_multinomial_sampling): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1., dim + 1.) 
+ true_coefs = torch.arange(1.0, dim + 1.0) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): coefs_mean = torch.zeros(dim) - coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(dim))) - y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) + coefs = pyro.sample("beta", dist.Normal(coefs_mean, torch.ones(dim))) + y = pyro.sample("y", dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, - use_multinomial_sampling=use_multinomial_sampling, - jit_compile=jit, - ignore_jit_warnings=True) + nuts_kernel = NUTS( + model, + use_multinomial_sampling=use_multinomial_sampling, + jit_compile=jit, + ignore_jit_warnings=True, + ) mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=100) mcmc.run(data) samples = mcmc.get_samples() @@ -169,7 +179,7 @@ def model(data): (None, True, False, False), (None, True, True, False), (None, True, True, True), - ] + ], ) def test_beta_bernoulli(step_size, adapt_step_size, adapt_mass_matrix, full_mass): def model(data): @@ -181,8 +191,13 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, step_size=step_size, adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, full_mass=full_mass) + nuts_kernel = NUTS( + model, + step_size=step_size, + adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, + full_mass=full_mass, + ) mcmc = MCMC(nuts_kernel, num_samples=400, warmup_steps=200) mcmc.run(data) samples = mcmc.get_samples() @@ -195,16 +210,18 @@ def test_gamma_normal(jit, use_multinomial_sampling): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) - p_latent = pyro.sample('p_latent', dist.Gamma(rate, concentration)) + p_latent = pyro.sample("p_latent", dist.Gamma(rate, concentration)) pyro.sample("obs", dist.Normal(3, p_latent), obs=data) return p_latent true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, - use_multinomial_sampling=use_multinomial_sampling, - jit_compile=jit, - ignore_jit_warnings=True) + nuts_kernel = NUTS( + model, + use_multinomial_sampling=use_multinomial_sampling, + jit_compile=jit, + ignore_jit_warnings=True, + ) mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=100) mcmc.run(data) samples = mcmc.get_samples() @@ -215,7 +232,7 @@ def model(data): def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) - p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) + p_latent = pyro.sample("p_latent", dist.Dirichlet(concentration)) pyro.sample("obs", dist.Categorical(p_latent), obs=data) return p_latent @@ -232,13 +249,19 @@ def model(data): @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) def test_gamma_beta(jit): def model(data): - alpha_prior = pyro.sample('alpha', dist.Gamma(concentration=1., rate=1.)) - beta_prior = pyro.sample('beta', dist.Gamma(concentration=1., rate=1.)) - pyro.sample('x', dist.Beta(concentration1=alpha_prior, concentration0=beta_prior), obs=data) - - true_alpha = torch.tensor(5.) - true_beta = torch.tensor(1.) 
- data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,))) + alpha_prior = pyro.sample("alpha", dist.Gamma(concentration=1.0, rate=1.0)) + beta_prior = pyro.sample("beta", dist.Gamma(concentration=1.0, rate=1.0)) + pyro.sample( + "x", + dist.Beta(concentration1=alpha_prior, concentration0=beta_prior), + obs=data, + ) + + true_alpha = torch.tensor(5.0) + true_beta = torch.tensor(1.0) + data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample( + torch.Size((5000,)) + ) nuts_kernel = NUTS(model, jit_compile=jit, ignore_jit_warnings=True) mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200) mcmc.run(data) @@ -254,40 +277,50 @@ def test_gaussian_mixture_model(jit): def gmm(data): mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K))) with pyro.plate("num_clusters", K): - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) + cluster_means = pyro.sample( + "cluster_means", dist.Normal(torch.arange(float(K)), 1.0) + ) with pyro.plate("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) - pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) + pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data) return cluster_means - true_cluster_means = torch.tensor([1., 5., 10.]) + true_cluster_means = torch.tensor([1.0, 5.0, 10.0]) true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) - cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) + cluster_assignments = dist.Categorical(true_mix_proportions).sample( + torch.Size((N,)) + ) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - nuts_kernel = NUTS(gmm, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True) + nuts_kernel = NUTS( + gmm, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True + ) mcmc = MCMC(nuts_kernel, num_samples=300, warmup_steps=100) mcmc.run(data) samples = mcmc.get_samples() assert_equal(samples["phi"].mean(0).sort()[0], true_mix_proportions, prec=0.05) - assert_equal(samples["cluster_means"].mean(0).sort()[0], true_cluster_means, prec=0.2) + assert_equal( + samples["cluster_means"].mean(0).sort()[0], true_cluster_means, prec=0.2 + ) @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) def test_bernoulli_latent_model(jit): @poutine.broadcast def model(data): - y_prob = pyro.sample("y_prob", dist.Beta(1., 1.)) + y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0)) with pyro.plate("data", data.shape[0]): y = pyro.sample("y", dist.Bernoulli(y_prob)) z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1)) - pyro.sample("obs", dist.Normal(2. * z, 1.), obs=data) + pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data) N = 2000 y_prob = torch.tensor(0.3) y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * y + 0.1).sample() - data = dist.Normal(2. 
* z, 1.0).sample() - nuts_kernel = NUTS(model, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True) + data = dist.Normal(2.0 * z, 1.0).sample() + nuts_kernel = NUTS( + model, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True + ) mcmc = MCMC(nuts_kernel, num_samples=600, warmup_steps=200) mcmc.run(data) samples = mcmc.get_samples() @@ -302,21 +335,36 @@ def model(data): initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim))) with pyro.plate("states", dim): transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim))) - emission_loc = pyro.sample("emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))) - emission_scale = pyro.sample("emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))) + emission_loc = pyro.sample( + "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim)) + ) + emission_scale = pyro.sample( + "emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim)) + ) x = None with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): for t, y in pyro.markov(enumerate(data)): - x = pyro.sample("x_{}".format(t), - dist.Categorical(initialize if x is None else transition[x]), - infer={"enumerate": "parallel"}) - pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(initialize if x is None else transition[x]), + infer={"enumerate": "parallel"}, + ) + pyro.sample( + "y_{}".format(t), + dist.Normal(emission_loc[x], emission_scale[x]), + obs=y, + ) def _get_initial_trace(): - guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: not msg["name"].startswith("x") and - not msg["name"].startswith("y"))) + guide = AutoDelta( + poutine.block( + model, + expose_fn=lambda msg: not msg["name"].startswith("x") + and not msg["name"].startswith("y"), + ) + ) elbo = TraceEnum_ELBO(max_plate_nesting=1) - svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo) + svi = SVI(model, guide, optim.Adam({"lr": 0.01}), elbo) for _ in range(100): svi.step(data) return poutine.trace(guide).get_trace(data) @@ -324,7 +372,7 @@ def _get_initial_trace(): def _generate_data(): transition_probs = torch.rand(dim, dim) emissions_loc = torch.arange(dim, dtype=torch.Tensor().dtype) - emissions_scale = 1. 
+ emissions_scale = 1.0 state = torch.tensor(1) obs = [dist.Normal(emissions_loc[state], emissions_scale).sample()] for _ in range(num_steps): @@ -333,7 +381,9 @@ def _generate_data(): return torch.stack(obs) data = _generate_data() - nuts_kernel = NUTS(model, max_plate_nesting=1, jit_compile=True, ignore_jit_warnings=True) + nuts_kernel = NUTS( + model, max_plate_nesting=1, jit_compile=True, ignore_jit_warnings=True + ) if num_steps == 30: nuts_kernel.initial_trace = _get_initial_trace() mcmc = MCMC(nuts_kernel, num_samples=5, warmup_steps=5) @@ -344,19 +394,35 @@ def _generate_data(): def test_beta_binomial(hyperpriors): def model(data): with pyro.plate("plate_0", data.shape[-1]): - alpha = pyro.sample("alpha", dist.HalfCauchy(1.)) if hyperpriors else torch.tensor([1., 1.]) - beta = pyro.sample("beta", dist.HalfCauchy(1.)) if hyperpriors else torch.tensor([1., 1.]) + alpha = ( + pyro.sample("alpha", dist.HalfCauchy(1.0)) + if hyperpriors + else torch.tensor([1.0, 1.0]) + ) + beta = ( + pyro.sample("beta", dist.HalfCauchy(1.0)) + if hyperpriors + else torch.tensor([1.0, 1.0]) + ) beta_binom = BetaBinomialPair() with pyro.plate("plate_1", data.shape[-2]): probs = pyro.sample("probs", beta_binom.latent(alpha, beta)) with pyro.plate("data", data.shape[0]): - pyro.sample("binomial", beta_binom.conditional(probs=probs, total_count=total_count), obs=data) + pyro.sample( + "binomial", + beta_binom.conditional(probs=probs, total_count=total_count), + obs=data, + ) true_probs = torch.tensor([[0.7, 0.4], [0.6, 0.4]]) total_count = torch.tensor([[1000, 600], [400, 800]]) num_samples = 80 - data = dist.Binomial(total_count=total_count, probs=true_probs).sample(sample_shape=(torch.Size((10,)))) - hmc_kernel = NUTS(collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True) + data = dist.Binomial(total_count=total_count, probs=true_probs).sample( + sample_shape=(torch.Size((10,))) + ) + hmc_kernel = NUTS( + collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True + ) mcmc = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=50) mcmc.run(data) samples = mcmc.get_samples() @@ -368,17 +434,27 @@ def model(data): def test_gamma_poisson(hyperpriors): def model(data): with pyro.plate("latent_dim", data.shape[1]): - alpha = pyro.sample("alpha", dist.HalfCauchy(1.)) if hyperpriors else torch.tensor([1., 1.]) - beta = pyro.sample("beta", dist.HalfCauchy(1.)) if hyperpriors else torch.tensor([1., 1.]) + alpha = ( + pyro.sample("alpha", dist.HalfCauchy(1.0)) + if hyperpriors + else torch.tensor([1.0, 1.0]) + ) + beta = ( + pyro.sample("beta", dist.HalfCauchy(1.0)) + if hyperpriors + else torch.tensor([1.0, 1.0]) + ) gamma_poisson = GammaPoissonPair() rate = pyro.sample("rate", gamma_poisson.latent(alpha, beta)) with pyro.plate("data", data.shape[0]): pyro.sample("obs", gamma_poisson.conditional(rate), obs=data) - true_rate = torch.tensor([3., 10.]) + true_rate = torch.tensor([3.0, 10.0]) num_samples = 100 data = dist.Poisson(rate=true_rate).sample(sample_shape=(torch.Size((100,)))) - hmc_kernel = NUTS(collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True) + hmc_kernel = NUTS( + collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True + ) mcmc = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=50) mcmc.run(data) samples = mcmc.get_samples() @@ -396,7 +472,7 @@ def model(cov): pyro.sample("obs", dist.MultivariateNormal(torch.zeros(5), cov), obs=wxyz) w_cov = torch.tensor([[1.5, 0.5], [0.5, 1.5]]) - xy_cov = torch.tensor([[2., 1.], [1., 3.]]) + xy_cov = 
torch.tensor([[2.0, 1.0], [1.0, 3.0]]) z_var = torch.tensor([2.5]) cov = torch.zeros(5, 5) cov[:2, :2] = w_cov @@ -405,12 +481,21 @@ def model(cov): # smoke tests for dense_mass in [True, False]: - kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=dense_mass) + kernel = NUTS( + model, jit_compile=True, ignore_jit_warnings=True, full_mass=dense_mass + ) mcmc = MCMC(kernel, num_samples=1, warmup_steps=1) mcmc.run(cov) - assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == 1 + int(dense_mass) - - kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=[("w",), ("x", "y")]) + assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == 1 + int( + dense_mass + ) + + kernel = NUTS( + model, + jit_compile=True, + ignore_jit_warnings=True, + full_mass=[("w",), ("x", "y")], + ) mcmc = MCMC(kernel, num_samples=1, warmup_steps=1000) mcmc.run(cov) assert_close(kernel.inverse_mass_matrix[("w",)], w_cov, atol=0.5, rtol=0.5) @@ -425,19 +510,32 @@ def model(prec): y = pyro.sample("y", dist.Normal(0, 1000).expand([1]).to_event(1)) z = pyro.sample("z", dist.Normal(0, 1000).expand([2]).to_event(1)) wyxz = torch.cat([w, y, x, z]) - pyro.sample("obs", dist.MultivariateNormal(torch.zeros(6), precision_matrix=prec), obs=wyxz) + pyro.sample( + "obs", + dist.MultivariateNormal(torch.zeros(6), precision_matrix=prec), + obs=wyxz, + ) A = torch.randn(6, 12) prec = A @ A.t() * 0.1 # smoke tests for dense_mass in [True, False]: - kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=dense_mass) + kernel = NUTS( + model, jit_compile=True, ignore_jit_warnings=True, full_mass=dense_mass + ) mcmc = MCMC(kernel, num_samples=1, warmup_steps=1) mcmc.run(prec) - assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == 1 + int(dense_mass) - - kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=[("w",), ("y", "x")]) + assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == 1 + int( + dense_mass + ) + + kernel = NUTS( + model, + jit_compile=True, + ignore_jit_warnings=True, + full_mass=[("w",), ("y", "x")], + ) kernel.mass_matrix_adapter = ArrowheadMassMatrix() mcmc = MCMC(kernel, num_samples=1, warmup_steps=1000) mcmc.run(prec) @@ -452,7 +550,7 @@ def model(prec): def test_dirichlet_categorical_grad_adapt(): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) - p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) + p_latent = pyro.sample("p_latent", dist.Dirichlet(concentration)) pyro.sample("obs", dist.Categorical(p_latent), obs=data) return p_latent diff --git a/tests/infer/mcmc/test_valid_models.py b/tests/infer/mcmc/test_valid_models.py index 323721db34..c173b2fad8 100644 --- a/tests/infer/mcmc/test_valid_models.py +++ b/tests/infer/mcmc/test_valid_models.py @@ -47,19 +47,23 @@ def print_debug_info(model_trace): logger.debug("prob( {} ):\n {}".format(name, site["log_prob"].exp())) -@pytest.mark.parametrize("kernel, kwargs", [ - (HMC, {"adapt_step_size": True, "num_steps": 3}), - (NUTS, {"adapt_step_size": True}), -]) +@pytest.mark.parametrize( + "kernel, kwargs", + [ + (HMC, {"adapt_step_size": True, "num_steps": 3}), + (NUTS, {"adapt_step_size": True}), + ], +) def test_model_error_stray_batch_dims(kernel, kwargs): - def gmm(): - data = torch.tensor([0., 0., 3., 3., 3., 5., 5.]) + data = torch.tensor([0.0, 0.0, 3.0, 3.0, 3.0, 5.0, 5.0]) mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(3))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3.), 
1.)) + cluster_means = pyro.sample( + "cluster_means", dist.Normal(torch.arange(3.0), 1.0) + ) with pyro.plate("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) - pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) + pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data) return cluster_means mcmc_kernel = kernel(gmm, **kwargs) @@ -70,20 +74,24 @@ def gmm(): assert_error(mcmc_kernel) -@pytest.mark.parametrize("kernel, kwargs", [ - (HMC, {"adapt_step_size": True, "num_steps": 3}), - (NUTS, {"adapt_step_size": True}), -]) +@pytest.mark.parametrize( + "kernel, kwargs", + [ + (HMC, {"adapt_step_size": True, "num_steps": 3}), + (NUTS, {"adapt_step_size": True}), + ], +) def test_model_error_enum_dim_clash(kernel, kwargs): - def gmm(): - data = torch.tensor([0., 0., 3., 3., 3., 5., 5.]) + data = torch.tensor([0.0, 0.0, 3.0, 3.0, 3.0, 5.0, 5.0]) with pyro.plate("num_clusters", 3): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3.), 1.)) + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.0))) + cluster_means = pyro.sample( + "cluster_means", dist.Normal(torch.arange(3.0), 1.0) + ) with pyro.plate("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) - pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) + pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data) return cluster_means mcmc_kernel = kernel(gmm, max_plate_nesting=0, **kwargs) @@ -100,31 +108,37 @@ def model(): inner2 = pyro.plate("inner2", 5, dim=-2) inner3 = pyro.plate("inner3", 6, dim=-4) - p = pyro.sample("p", dist.Uniform(0., 1.)) + p = pyro.sample("p", dist.Uniform(0.0, 1.0)) y = pyro.sample("y", dist.Bernoulli(p)) q = 0.5 + 0.25 * y with outer, inner2: z0 = pyro.sample("z0", dist.Bernoulli(q)) - pyro.sample("obs0", dist.Normal(2 * z0 - 1, 1.), obs=torch.ones(5, 3)) + pyro.sample("obs0", dist.Normal(2 * z0 - 1, 1.0), obs=torch.ones(5, 3)) with outer: v = pyro.sample("v", dist.Bernoulli(q)) r = 0.4 + 0.1 * v with inner1, inner3: z1 = pyro.sample("z1", dist.Bernoulli(r)) - pyro.sample("obs1", dist.Normal(2 * z1 - 1, 1.), obs=torch.ones(6, 4, 1, 3)) + pyro.sample( + "obs1", dist.Normal(2 * z1 - 1, 1.0), obs=torch.ones(6, 4, 1, 3) + ) with inner2: z2 = pyro.sample("z2", dist.Bernoulli(r)) - pyro.sample("obs2", dist.Normal(2 * z2 - 1, 1.), obs=torch.ones(5, 3)) + pyro.sample("obs2", dist.Normal(2 * z2 - 1, 1.0), obs=torch.ones(5, 3)) model_trace = poutine.trace(model).get_trace() trace_prob_evaluator = TraceTreeEvaluator(model_trace, True, 4) trace_prob_evaluator.log_prob(model_trace) plate_dims, enum_dims = [], [] - for key in reversed(sorted(trace_prob_evaluator._log_probs.keys(), key=lambda x: (len(x), x))): + for key in reversed( + sorted(trace_prob_evaluator._log_probs.keys(), key=lambda x: (len(x), x)) + ): plate_dims.append(trace_prob_evaluator._plate_dims[key]) enum_dims.append(trace_prob_evaluator._enum_dims[key]) # The reduction operation returns a singleton with dimensions preserved. 
- assert not any(i != 1 for i in trace_prob_evaluator._aggregate_log_probs(frozenset()).shape) + assert not any( + i != 1 for i in trace_prob_evaluator._aggregate_log_probs(frozenset()).shape + ) assert plate_dims == [[-4, -3], [-2], [-1], []] assert enum_dims, [[-8], [-9, -6], [-7], [-5]] @@ -152,17 +166,26 @@ def model(): print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 3) # all discrete sites enumerated out. - assert_equal(trace_prob_evaluator.log_prob(model_trace), torch.tensor(0.)) + assert_equal(trace_prob_evaluator.log_prob(model_trace), torch.tensor(0.0)) -@pytest.mark.parametrize("Eval", [TraceTreeEvaluator, - xfail_param(TraceEinsumEvaluator, reason="TODO: Debug this failure case.")]) +@pytest.mark.parametrize( + "Eval", + [ + TraceTreeEvaluator, + xfail_param(TraceEinsumEvaluator, reason="TODO: Debug this failure case."), + ], +) def test_enumeration_in_tree(Eval): @poutine.enum(first_available_dim=-5) @config_enumerate - @poutine.condition(data={"sample1": torch.tensor(0.), - "sample2": torch.tensor(1.), - "sample3": torch.tensor(2.)}) + @poutine.condition( + data={ + "sample1": torch.tensor(0.0), + "sample2": torch.tensor(1.0), + "sample3": torch.tensor(2.0), + } + ) def model(): outer = pyro.plate("outer", 2, dim=-1) inner1 = pyro.plate("inner1", 2, dim=-3) @@ -170,7 +193,7 @@ def model(): inner3 = pyro.plate("inner3", 2, dim=-4) d = dist.Bernoulli(0.3) - n = dist.Normal(0., 1.) + n = dist.Normal(0.0, 1.0) pyro.sample("y", d) pyro.sample("sample1", n) with outer, inner2: @@ -188,7 +211,9 @@ def model(): print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 4) # p_n(0.) * p_n(2.)^2 * p_n(1.)^6 - assert_equal(trace_prob_evaluator.log_prob(model_trace), torch.tensor(-15.2704), prec=1e-4) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), torch.tensor(-15.2704), prec=1e-4 + ) @pytest.mark.xfail(reason="Enumeration currently does not work for general DAGs") @@ -216,79 +241,88 @@ def model(): model_trace = poutine.trace(model).get_trace() print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 2) - assert_equal(trace_prob_evaluator.log_prob(model_trace), torch.tensor(0.16196)) # p_beta(0.3)^3 - - -@pytest.mark.parametrize("data, expected_log_prob", [ - (torch.tensor([1.]), torch.tensor(-1.3434)), - (torch.tensor([0.]), torch.tensor(-1.4189)), - (torch.tensor([1., 0., 0.]), torch.tensor(-4.1813)), -]) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), torch.tensor(0.16196) + ) # p_beta(0.3)^3 + + +@pytest.mark.parametrize( + "data, expected_log_prob", + [ + (torch.tensor([1.0]), torch.tensor(-1.3434)), + (torch.tensor([0.0]), torch.tensor(-1.4189)), + (torch.tensor([1.0, 0.0, 0.0]), torch.tensor(-4.1813)), + ], +) @pytest.mark.parametrize("Eval", [TraceTreeEvaluator, TraceEinsumEvaluator]) def test_enum_log_prob_continuous_observed(data, expected_log_prob, Eval): - @poutine.enum(first_available_dim=-2) @config_enumerate @poutine.condition(data={"p": torch.tensor(0.4)}) def model(data): - p = pyro.sample("p", dist.Uniform(0., 1.)) + p = pyro.sample("p", dist.Uniform(0.0, 1.0)) y = pyro.sample("y", dist.Bernoulli(p)) q = 0.5 + 0.25 * y with pyro.plate("data", len(data)): z = pyro.sample("z", dist.Bernoulli(q)) mean = 2 * z - 1 - pyro.sample("obs", dist.Normal(mean, 1.), obs=data) + pyro.sample("obs", dist.Normal(mean, 1.0), obs=data) model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 1) - 
assert_equal(trace_prob_evaluator.log_prob(model_trace), - expected_log_prob, - prec=1e-3) - - -@pytest.mark.parametrize("data, expected_log_prob", [ - (torch.tensor([1.]), torch.tensor(-3.5237)), - (torch.tensor([0.]), torch.tensor(-3.7091)), - (torch.tensor([1., 1.]), torch.tensor(-3.9699)), - (torch.tensor([1., 0., 0.]), torch.tensor(-5.3357)), -]) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), expected_log_prob, prec=1e-3 + ) + + +@pytest.mark.parametrize( + "data, expected_log_prob", + [ + (torch.tensor([1.0]), torch.tensor(-3.5237)), + (torch.tensor([0.0]), torch.tensor(-3.7091)), + (torch.tensor([1.0, 1.0]), torch.tensor(-3.9699)), + (torch.tensor([1.0, 0.0, 0.0]), torch.tensor(-5.3357)), + ], +) @pytest.mark.parametrize("Eval", [TraceTreeEvaluator, TraceEinsumEvaluator]) def test_enum_log_prob_continuous_sampled(data, expected_log_prob, Eval): - @poutine.enum(first_available_dim=-2) @config_enumerate - @poutine.condition(data={"p": torch.tensor(0.4), - "n": torch.tensor([[1.], [-1.]])}) + @poutine.condition( + data={"p": torch.tensor(0.4), "n": torch.tensor([[1.0], [-1.0]])} + ) def model(data): - p = pyro.sample("p", dist.Uniform(0., 1.)) + p = pyro.sample("p", dist.Uniform(0.0, 1.0)) y = pyro.sample("y", dist.Bernoulli(p)) mean = 2 * y - 1 - n = pyro.sample("n", dist.Normal(mean, 1.)) + n = pyro.sample("n", dist.Normal(mean, 1.0)) with pyro.plate("data", len(data)): pyro.sample("obs", dist.Bernoulli(torch.sigmoid(n)), obs=data) model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 1) - assert_equal(trace_prob_evaluator.log_prob(model_trace), - expected_log_prob, - prec=1e-3) - - -@pytest.mark.parametrize("data, expected_log_prob", [ - (torch.tensor([1.]), torch.tensor(-0.5108)), - (torch.tensor([1., 1.]), torch.tensor(-0.9808)), - (torch.tensor([1., 0., 0.]), torch.tensor(-2.3671)), -]) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), expected_log_prob, prec=1e-3 + ) + + +@pytest.mark.parametrize( + "data, expected_log_prob", + [ + (torch.tensor([1.0]), torch.tensor(-0.5108)), + (torch.tensor([1.0, 1.0]), torch.tensor(-0.9808)), + (torch.tensor([1.0, 0.0, 0.0]), torch.tensor(-2.3671)), + ], +) @pytest.mark.parametrize("Eval", [TraceTreeEvaluator, TraceEinsumEvaluator]) def test_enum_log_prob_discrete_observed(data, expected_log_prob, Eval): - @poutine.enum(first_available_dim=-2) @config_enumerate @poutine.condition(data={"p": torch.tensor(0.4)}) def model(data): - p = pyro.sample("p", dist.Uniform(0., 1.)) + p = pyro.sample("p", dist.Uniform(0.0, 1.0)) y = pyro.sample("y", dist.Bernoulli(p)) q = 0.25 * y + 0.5 with pyro.plate("data", len(data)): @@ -297,19 +331,21 @@ def model(data): model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 1) - assert_equal(trace_prob_evaluator.log_prob(model_trace), - expected_log_prob, - prec=1e-3) - - -@pytest.mark.parametrize("data, expected_log_prob", [ - (torch.tensor([1.]), torch.tensor(-1.15)), - (torch.tensor([0.]), torch.tensor(-1.46)), - (torch.tensor([1., 1.]), torch.tensor(-2.1998)), -]) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), expected_log_prob, prec=1e-3 + ) + + +@pytest.mark.parametrize( + "data, expected_log_prob", + [ + (torch.tensor([1.0]), torch.tensor(-1.15)), + (torch.tensor([0.0]), torch.tensor(-1.46)), + (torch.tensor([1.0, 1.0]), torch.tensor(-2.1998)), + ], +) @pytest.mark.parametrize("Eval", [TraceTreeEvaluator, 
TraceEinsumEvaluator]) def test_enum_log_prob_multiple_plate(data, expected_log_prob, Eval): - @poutine.enum(first_available_dim=-2) @config_enumerate @poutine.condition(data={"p": torch.tensor(0.4)}) @@ -326,24 +362,26 @@ def model(data): model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 1) - assert_equal(trace_prob_evaluator.log_prob(model_trace), - expected_log_prob, - prec=1e-3) - - -@pytest.mark.parametrize("data, expected_log_prob", [ - (torch.tensor([1.]), torch.tensor(-1.5478)), - (torch.tensor([0.]), torch.tensor(-1.4189)), - (torch.tensor([1., 0., 0.]), torch.tensor(-4.3857)), -]) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), expected_log_prob, prec=1e-3 + ) + + +@pytest.mark.parametrize( + "data, expected_log_prob", + [ + (torch.tensor([1.0]), torch.tensor(-1.5478)), + (torch.tensor([0.0]), torch.tensor(-1.4189)), + (torch.tensor([1.0, 0.0, 0.0]), torch.tensor(-4.3857)), + ], +) @pytest.mark.parametrize("Eval", [TraceTreeEvaluator, TraceEinsumEvaluator]) def test_enum_log_prob_nested_plate(data, expected_log_prob, Eval): - @poutine.enum(first_available_dim=-3) @config_enumerate @poutine.condition(data={"p": torch.tensor(0.4)}) def model(data): - p = pyro.sample("p", dist.Uniform(0., 1.)) + p = pyro.sample("p", dist.Uniform(0.0, 1.0)) y = pyro.sample("y", dist.Bernoulli(p)) q = 0.5 + 0.25 * y with pyro.plate("intermediate", 1, dim=-2): @@ -351,31 +389,34 @@ def model(data): with pyro.plate("data", len(data), dim=-1): r = 0.4 + 0.1 * v z = pyro.sample("z", dist.Bernoulli(r)) - pyro.sample("obs", dist.Normal(2 * z - 1, 1.), obs=data) + pyro.sample("obs", dist.Normal(2 * z - 1, 1.0), obs=data) model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) trace_prob_evaluator = Eval(model_trace, True, 2) - assert_equal(trace_prob_evaluator.log_prob(model_trace), - expected_log_prob, - prec=1e-3) + assert_equal( + trace_prob_evaluator.log_prob(model_trace), expected_log_prob, prec=1e-3 + ) def _beta_bernoulli(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) - p_latent = pyro.sample('p_latent', dist.Beta(alpha, beta)) - with pyro.plate('data', data.shape[0], dim=-2): - pyro.sample('obs', dist.Bernoulli(p_latent), obs=data) + p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta)) + with pyro.plate("data", data.shape[0], dim=-2): + pyro.sample("obs", dist.Bernoulli(p_latent), obs=data) return p_latent -@pytest.mark.parametrize('jit', [False, True]) +@pytest.mark.parametrize("jit", [False, True]) def test_potential_fn_pickling(jit): - data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample(sample_shape=(torch.Size((1000,)))) - _, potential_fn, _, _ = initialize_model(_beta_bernoulli, (data,), jit_compile=jit, - skip_jit_warnings=True) - test_data = {'p_latent': torch.tensor([0.2, 0.6])} + data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample( + sample_shape=(torch.Size((1000,))) + ) + _, potential_fn, _, _ = initialize_model( + _beta_bernoulli, (data,), jit_compile=jit, skip_jit_warnings=True + ) + test_data = {"p_latent": torch.tensor([0.2, 0.6])} buffer = io.BytesIO() torch.save(potential_fn, buffer) buffer.seek(0) @@ -383,18 +424,20 @@ def test_potential_fn_pickling(jit): assert_close(deser_potential_fn(test_data), potential_fn(test_data)) -@pytest.mark.parametrize("kernel, kwargs", [ - (HMC, {"adapt_step_size": True, "num_steps": 3}), - (NUTS, {"adapt_step_size": True}), -]) +@pytest.mark.parametrize( + "kernel, kwargs", + [ + (HMC, 
{"adapt_step_size": True, "num_steps": 3}), + (NUTS, {"adapt_step_size": True}), + ], +) def test_reparam_stable(kernel, kwargs): - @poutine.reparam(config={"z": LatentStableReparam()}) def model(): - stability = pyro.sample("stability", dist.Uniform(0., 2.)) - skew = pyro.sample("skew", dist.Uniform(-1., 1.)) + stability = pyro.sample("stability", dist.Uniform(0.0, 2.0)) + skew = pyro.sample("skew", dist.Uniform(-1.0, 1.0)) y = pyro.sample("z", dist.Stable(stability, skew)) - pyro.sample("x", dist.Poisson(y.abs()), obs=torch.tensor(1.)) + pyro.sample("x", dist.Poisson(y.abs()), obs=torch.tensor(1.0)) mcmc_kernel = kernel(model, max_plate_nesting=0, **kwargs) assert_ok(mcmc_kernel) @@ -403,40 +446,48 @@ def model(): # Regression test for https://github.com/pyro-ppl/pyro/issues/2627 @pytest.mark.parametrize("Kernel", [HMC, NUTS]) def test_potential_fn_initial_params(Kernel): - target = torch.distributions.Normal(loc=torch.tensor([10., 0.]), - scale=torch.tensor([1., 1.])) + target = torch.distributions.Normal( + loc=torch.tensor([10.0, 0.0]), scale=torch.tensor([1.0, 1.0]) + ) def potential_fn(z): - z = z['points'] + z = z["points"] return -target.log_prob(z).sum(1)[None] - initial_params = {'points': torch.tensor([[0., 0.]])} + initial_params = {"points": torch.tensor([[0.0, 0.0]])} kernel = Kernel(potential_fn=potential_fn) - mcmc = MCMC(kernel=kernel, warmup_steps=20, initial_params=initial_params, num_samples=10) + mcmc = MCMC( + kernel=kernel, warmup_steps=20, initial_params=initial_params, num_samples=10 + ) mcmc.run() - mcmc.get_samples()['points'] - - -@pytest.mark.parametrize("mask", [ - torch.tensor(True), - torch.tensor(False), - torch.tensor([True]), - torch.tensor([False]), - torch.tensor([False, True, False]), -]) -@pytest.mark.parametrize("Kernel, options", [ - (HMC, {"adapt_step_size": True, "num_steps": 3}), - (NUTS, {"adapt_step_size": True, "max_tree_depth": 3}), -]) + mcmc.get_samples()["points"] + + +@pytest.mark.parametrize( + "mask", + [ + torch.tensor(True), + torch.tensor(False), + torch.tensor([True]), + torch.tensor([False]), + torch.tensor([False, True, False]), + ], +) +@pytest.mark.parametrize( + "Kernel, options", + [ + (HMC, {"adapt_step_size": True, "num_steps": 3}), + (NUTS, {"adapt_step_size": True, "max_tree_depth": 3}), + ], +) def test_obs_mask_ok(Kernel, options, mask): - data = torch.tensor([7., 7., 7.]) + data = torch.tensor([7.0, 7.0, 7.0]) def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with pyro.plate("plate", len(data)): - y = pyro.sample("y", dist.Normal(x, 1.), - obs=data, obs_mask=mask) + y = pyro.sample("y", dist.Normal(x, 1.0), obs=data, obs_mask=mask) assert ((y == data) == mask).all() mcmc_kernel = Kernel(model, max_plate_nesting=0, **options) diff --git a/tests/infer/reparam/test_conjugate.py b/tests/infer/reparam/test_conjugate.py index 84d81a237c..301b8e1b0a 100644 --- a/tests/infer/reparam/test_conjugate.py +++ b/tests/infer/reparam/test_conjugate.py @@ -55,9 +55,14 @@ def model(counts): prob = pyro.sample("prob", prior) pyro.sample("counts", dist.Binomial(total, prob), obs=counts) - reparam_model = poutine.reparam(model, { - "prob": ConjugateReparam(lambda counts: dist.Beta(1 + counts, 1 + total - counts)), - }) + reparam_model = poutine.reparam( + model, + { + "prob": ConjugateReparam( + lambda counts: dist.Beta(1 + counts, 1 + total - counts) + ), + }, + ) with poutine.trace() as tr, pyro.plate("particles", 10000): reparam_model(counts) @@ -89,7 +94,9 @@ def guide(): def 
reparam_guide(): pass - elbo = Trace_ELBO(num_particles=10000, vectorize_particles=True, max_plate_nesting=0) + elbo = Trace_ELBO( + num_particles=10000, vectorize_particles=True, max_plate_nesting=0 + ) expected_loss = elbo.differentiable_loss(model, guide) actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide) assert_close(actual_loss, expected_loss, atol=0.01) @@ -106,9 +113,13 @@ def reparam_guide(): @pytest.mark.parametrize("num_steps", range(1, 6)) def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim): init_dist = random_mvn(batch_shape, hidden_dim) - trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim), requires_grad=True) + trans_mat = torch.randn( + batch_shape + (num_steps, hidden_dim, hidden_dim), requires_grad=True + ) trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim) - obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim), requires_grad=True) + obs_mat = torch.randn( + batch_shape + (num_steps, hidden_dim, obs_dim), requires_grad=True + ) obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim) data = obs_dist.sample() @@ -156,26 +167,38 @@ def random_stable(shape): @pytest.mark.parametrize("num_steps", range(1, 6)) def test_stable_hmm_smoke(batch_shape, num_steps, hidden_dim, obs_dim): init_dist = random_stable(batch_shape + (hidden_dim,)).to_event(1) - trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim), requires_grad=True) + trans_mat = torch.randn( + batch_shape + (num_steps, hidden_dim, hidden_dim), requires_grad=True + ) trans_dist = random_stable(batch_shape + (num_steps, hidden_dim)).to_event(1) - obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim), requires_grad=True) + obs_mat = torch.randn( + batch_shape + (num_steps, hidden_dim, obs_dim), requires_grad=True + ) obs_dist = random_stable(batch_shape + (num_steps, obs_dim)).to_event(1) data = obs_dist.sample() assert data.shape == batch_shape + (num_steps, obs_dim) def model(data): - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=num_steps) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=num_steps + ) with pyro.plate_stack("plates", batch_shape): z = pyro.sample("z", hmm) pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data) # Test that we can combine these two reparameterizers. - reparam_model = poutine.reparam(model, { - "z": LinearHMMReparam(StableReparam(), StableReparam(), StableReparam()), - }) - reparam_model = poutine.reparam(reparam_model, { - "z": ConjugateReparam(dist.Normal(data, 1).to_event(2)), - }) + reparam_model = poutine.reparam( + model, + { + "z": LinearHMMReparam(StableReparam(), StableReparam(), StableReparam()), + }, + ) + reparam_model = poutine.reparam( + reparam_model, + { + "z": ConjugateReparam(dist.Normal(data, 1).to_event(2)), + }, + ) reparam_guide = AutoDiagonalNormal(reparam_model) # Models auxiliary variables. # Smoke test only. 
diff --git a/tests/infer/reparam/test_discrete_cosine.py b/tests/infer/reparam/test_discrete_cosine.py index eb9ac0438b..17217e74a1 100644 --- a/tests/infer/reparam/test_discrete_cosine.py +++ b/tests/infer/reparam/test_discrete_cosine.py @@ -26,15 +26,25 @@ def get_moments(x): return torch.cat([mean, std, corr]) -@pytest.mark.parametrize("smooth", [0., 0.5, 1.0, 2.0]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize("smooth", [0.0, 0.5, 1.0, 2.0]) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_normal(shape, dim, smooth): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): @@ -61,15 +71,24 @@ def model(): assert_close(actual_grads[1], expected_grads[1], atol=0.05) -@pytest.mark.parametrize("smooth", [0., 0.5, 1.0, 2.0]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize("smooth", [0.0, 0.5, 1.0, 2.0]) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_uniform(shape, dim, smooth): - def model(): with pyro.plate_stack("plates", shape[:dim]): with pyro.plate("particles", 10000): @@ -78,7 +97,9 @@ def model(): value = poutine.trace(model).get_trace().nodes["x"]["value"] expected_probe = get_moments(value) - reparam_model = poutine.reparam(model, {"x": DiscreteCosineReparam(dim=dim, smooth=smooth)}) + reparam_model = poutine.reparam( + model, {"x": DiscreteCosineReparam(dim=dim, smooth=smooth)} + ) trace = poutine.trace(reparam_model).get_trace() assert isinstance(trace.nodes["x_dct"]["fn"], dist.TransformedDistribution) assert isinstance(trace.nodes["x"]["fn"], dist.Delta) @@ -87,15 +108,25 @@ def model(): assert_close(actual_probe, expected_probe, atol=0.1) -@pytest.mark.parametrize("smooth", [0., 0.5, 1.0, 2.0]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize("smooth", [0.0, 0.5, 1.0, 2.0]) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_init(shape, dim, smooth): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): diff --git a/tests/infer/reparam/test_haar.py b/tests/infer/reparam/test_haar.py index 74fb01679f..f981c9ee54 100644 --- a/tests/infer/reparam/test_haar.py +++ b/tests/infer/reparam/test_haar.py @@ -29,14 +29,24 @@ def get_moments(x): @pytest.mark.parametrize("flip", [False, True]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_normal(shape, dim, flip): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): @@ 
-64,14 +74,23 @@ def model(): @pytest.mark.parametrize("flip", [False, True]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_uniform(shape, dim, flip): - def model(): with pyro.plate_stack("plates", shape[:dim]): with pyro.plate("particles", 10000): @@ -90,14 +109,24 @@ def model(): @pytest.mark.parametrize("flip", [False, True]) -@pytest.mark.parametrize("shape,dim", [ - ((6,), -1), - ((2, 5,), -1), - ((4, 2), -2), - ((2, 3, 1), -2), -], ids=str) +@pytest.mark.parametrize( + "shape,dim", + [ + ((6,), -1), + ( + ( + 2, + 5, + ), + -1, + ), + ((4, 2), -2), + ((2, 3, 1), -2), + ], + ids=str, +) def test_init(shape, dim, flip): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): diff --git a/tests/infer/reparam/test_hmm.py b/tests/infer/reparam/test_hmm.py index 0ce6c0f6a0..719210377e 100644 --- a/tests/infer/reparam/test_hmm.py +++ b/tests/infer/reparam/test_hmm.py @@ -43,9 +43,13 @@ def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim): trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim)) trans_dist = random_mvn(batch_shape + (duration,), hidden_dim) obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim)) - obs_dist = dist.LogNormal(torch.randn(batch_shape + (duration, obs_dim)), - torch.rand(batch_shape + (duration, obs_dim)).exp()).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + obs_dist = dist.LogNormal( + torch.randn(batch_shape + (duration, obs_dim)), + torch.rand(batch_shape + (duration, obs_dim)).exp(), + ).to_event(1) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) def model(data=None): with pyro.plate_stack("plates", batch_shape): @@ -71,7 +75,9 @@ def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim): trans_dist = random_studentt(batch_shape + (duration, hidden_dim)).to_event(1) obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim)) obs_dist = random_studentt(batch_shape + (duration, obs_dim)).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) def model(data=None): with pyro.plate_stack("plates", batch_shape): @@ -97,15 +103,24 @@ def model(data=None): @pytest.mark.parametrize("skew", [0, None], ids=["symmetric", "skewed"]) def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim): stability = dist.Uniform(0.5, 2).sample(batch_shape) - init_dist = random_stable(batch_shape + (hidden_dim,), - stability.unsqueeze(-1), skew=skew).to_event(1) + init_dist = random_stable( + batch_shape + (hidden_dim,), stability.unsqueeze(-1), skew=skew + ).to_event(1) trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim)) - trans_dist = random_stable(batch_shape + (duration, hidden_dim), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) + trans_dist = random_stable( + batch_shape + (duration, hidden_dim), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) obs_mat = torch.randn(batch_shape + (duration, 
hidden_dim, obs_dim)) - obs_dist = random_stable(batch_shape + (duration, obs_dim), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + obs_dist = random_stable( + batch_shape + (duration, obs_dim), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) assert hmm.batch_shape == batch_shape assert hmm.event_shape == (duration, obs_dim) @@ -130,15 +145,24 @@ def model(data=None): def test_independent_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim): base_batch_shape = batch_shape + (obs_dim,) stability = dist.Uniform(0.5, 2).sample(base_batch_shape) - init_dist = random_stable(base_batch_shape + (hidden_dim,), - stability.unsqueeze(-1), skew=skew).to_event(1) + init_dist = random_stable( + base_batch_shape + (hidden_dim,), stability.unsqueeze(-1), skew=skew + ).to_event(1) trans_mat = torch.randn(base_batch_shape + (duration, hidden_dim, hidden_dim)) - trans_dist = random_stable(base_batch_shape + (duration, hidden_dim), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) + trans_dist = random_stable( + base_batch_shape + (duration, hidden_dim), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) obs_mat = torch.randn(base_batch_shape + (duration, hidden_dim, 1)) - obs_dist = random_stable(base_batch_shape + (duration, 1), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + obs_dist = random_stable( + base_batch_shape + (duration, 1), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) assert hmm.batch_shape == base_batch_shape assert hmm.event_shape == (duration, 1) @@ -182,16 +206,22 @@ def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim) trans_dist = random_stable((duration, hidden_dim), stability, skew=skew).to_event(1) obs_mat = torch.randn(duration, hidden_dim, obs_dim) obs_dist = random_stable((duration, obs_dim), stability, skew=skew).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) num_samples = 200000 - expected_samples = hmm.sample([num_samples]).reshape(num_samples, duration * obs_dim) + expected_samples = hmm.sample([num_samples]).reshape( + num_samples, duration * obs_dim + ) expected_loc, expected_scale, expected_corr = get_hmm_moments(expected_samples) rep = SymmetricStableReparam() if skew == 0 else StableReparam() with pyro.plate("samples", num_samples): with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}): - actual_samples = pyro.sample("x", hmm).reshape(num_samples, duration * obs_dim) + actual_samples = pyro.sample("x", hmm).reshape( + num_samples, duration * obs_dim + ) actual_loc, actual_scale, actual_corr = get_hmm_moments(actual_samples) assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05) @@ -205,14 +235,17 @@ def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, obs_dim) @pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) def test_stable_hmm_shape_error(batch_shape, duration, hidden_dim, obs_dim): stability = dist.Uniform(0.5, 
2).sample(batch_shape) - init_dist = random_stable(batch_shape + (hidden_dim,), - stability.unsqueeze(-1)).to_event(1) + init_dist = random_stable( + batch_shape + (hidden_dim,), stability.unsqueeze(-1) + ).to_event(1) trans_mat = torch.randn(batch_shape + (1, hidden_dim, hidden_dim)) - trans_dist = random_stable(batch_shape + (1, hidden_dim), - stability.unsqueeze(-1).unsqueeze(-1)).to_event(1) + trans_dist = random_stable( + batch_shape + (1, hidden_dim), stability.unsqueeze(-1).unsqueeze(-1) + ).to_event(1) obs_mat = torch.randn(batch_shape + (1, hidden_dim, obs_dim)) - obs_dist = random_stable(batch_shape + (1, obs_dim), - stability.unsqueeze(-1).unsqueeze(-1)).to_event(1) + obs_dist = random_stable( + batch_shape + (1, obs_dim), stability.unsqueeze(-1).unsqueeze(-1) + ).to_event(1) hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist) assert hmm.batch_shape == batch_shape assert hmm.event_shape == (1, obs_dim) @@ -235,15 +268,24 @@ def model(data=None): @pytest.mark.parametrize("skew", [0, None], ids=["symmetric", "skewed"]) def test_init_shape(skew, batch_shape, duration, hidden_dim, obs_dim): stability = dist.Uniform(0.5, 2).sample(batch_shape) - init_dist = random_stable(batch_shape + (hidden_dim,), - stability.unsqueeze(-1), skew=skew).to_event(1) + init_dist = random_stable( + batch_shape + (hidden_dim,), stability.unsqueeze(-1), skew=skew + ).to_event(1) trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim)) - trans_dist = random_stable(batch_shape + (duration, hidden_dim), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) + trans_dist = random_stable( + batch_shape + (duration, hidden_dim), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim)) - obs_dist = random_stable(batch_shape + (duration, obs_dim), - stability.unsqueeze(-1).unsqueeze(-1), skew=skew).to_event(1) - hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration) + obs_dist = random_stable( + batch_shape + (duration, obs_dim), + stability.unsqueeze(-1).unsqueeze(-1), + skew=skew, + ).to_event(1) + hmm = dist.LinearHMM( + init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration + ) assert hmm.batch_shape == batch_shape assert hmm.event_shape == (duration, obs_dim) diff --git a/tests/infer/reparam/test_loc_scale.py b/tests/infer/reparam/test_loc_scale.py index 40920fdac8..2a7445afb2 100644 --- a/tests/infer/reparam/test_loc_scale.py +++ b/tests/infer/reparam/test_loc_scale.py @@ -29,10 +29,10 @@ def get_moments(x): @pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("centered", [0., 0.6, 1., torch.tensor(0.4), None]) +@pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, torch.tensor(0.4), None]) @pytest.mark.parametrize("dist_type", ["Normal", "StudentT", "AsymmetricLaplace"]) def test_moments(dist_type, centered, shape): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() if isinstance(centered, torch.Tensor): centered = centered.expand(shape) @@ -73,10 +73,10 @@ def model(): @pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("centered", [0., 0.6, 1., torch.tensor(0.4), None]) +@pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, torch.tensor(0.4), None]) @pytest.mark.parametrize("dist_type", ["Normal", "StudentT", 
"AsymmetricLaplace"]) def test_init(dist_type, centered, shape): - loc = torch.empty(shape).uniform_(-1., 1.) + loc = torch.empty(shape).uniform_(-1.0, 1.0) scale = torch.empty(shape).uniform_(0.5, 1.5) def model(): diff --git a/tests/infer/reparam/test_neutra.py b/tests/infer/reparam/test_neutra.py index ed52aaff56..ce9491ef7f 100644 --- a/tests/infer/reparam/test_neutra.py +++ b/tests/infer/reparam/test_neutra.py @@ -18,28 +18,31 @@ def neals_funnel(dim=10): - y = pyro.sample('y', dist.Normal(0, 3)) - with pyro.plate('D', dim): - return pyro.sample('x', dist.Normal(0, torch.exp(y / 2))) + y = pyro.sample("y", dist.Normal(0, 3)) + with pyro.plate("D", dim): + return pyro.sample("x", dist.Normal(0, torch.exp(y / 2))) def dirichlet_categorical(data): concentration = torch.tensor([1.0, 1.0, 1.0]) - p_latent = pyro.sample('p', dist.Dirichlet(concentration)) - with pyro.plate('N', data.shape[0]): - pyro.sample('obs', dist.Categorical(p_latent), obs=data) + p_latent = pyro.sample("p", dist.Dirichlet(concentration)) + with pyro.plate("N", data.shape[0]): + pyro.sample("obs", dist.Categorical(p_latent), obs=data) return p_latent -@pytest.mark.parametrize('jit', [ - False, - xfail_param(True, reason="https://github.com/pyro-ppl/pyro/issues/2292"), -]) +@pytest.mark.parametrize( + "jit", + [ + False, + xfail_param(True, reason="https://github.com/pyro-ppl/pyro/issues/2292"), + ], +) def test_neals_funnel_smoke(jit): dim = 10 guide = AutoIAFNormal(neals_funnel) - svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO()) + svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO()) for _ in range(1000): svi.step(dim) @@ -51,22 +54,36 @@ def test_neals_funnel_smoke(jit): samples = mcmc.get_samples() # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1; # hence the unsqueeze - transformed_samples = neutra.transform_sample(samples['y_shared_latent'].unsqueeze(-2)) - assert 'x' in transformed_samples - assert 'y' in transformed_samples - - -@pytest.mark.parametrize('model, kwargs', [ - (neals_funnel, {'dim': 10}), - (dirichlet_categorical, {'data': torch.ones(10,)}) -]) + transformed_samples = neutra.transform_sample( + samples["y_shared_latent"].unsqueeze(-2) + ) + assert "x" in transformed_samples + assert "y" in transformed_samples + + +@pytest.mark.parametrize( + "model, kwargs", + [ + (neals_funnel, {"dim": 10}), + ( + dirichlet_categorical, + { + "data": torch.ones( + 10, + ) + }, + ), + ], +) def test_reparam_log_joint(model, kwargs): guide = AutoIAFNormal(model) guide(**kwargs) neutra = NeuTraReparam(guide) reparam_model = neutra.reparam(model) _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs) - init_params, pe_fn_neutra, _, _ = initialize_model(reparam_model, model_kwargs=kwargs) + init_params, pe_fn_neutra, _, _ = initialize_model( + reparam_model, model_kwargs=kwargs + ) latent_x = list(init_params.values())[0] transformed_params = neutra.transform_sample(latent_x) pe_transformed = pe_fn_neutra(init_params) diff --git a/tests/infer/reparam/test_softmax.py b/tests/infer/reparam/test_softmax.py index c1cce054b4..5e933709e4 100644 --- a/tests/infer/reparam/test_softmax.py +++ b/tests/infer/reparam/test_softmax.py @@ -33,8 +33,9 @@ def test_gumbel_softmax(temperature, shape, dim): def model(): with pyro.plate_stack("plates", shape): with pyro.plate("particles", 10000): - pyro.sample("x", dist.RelaxedOneHotCategorical(temperature, - logits=logits)) + pyro.sample( + "x", 
dist.RelaxedOneHotCategorical(temperature, logits=logits) + ) value = poutine.trace(model).get_trace().nodes["x"]["value"] expected_probe = get_moments(value) diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py index 422df270fb..6337069ea0 100644 --- a/tests/infer/reparam/test_split.py +++ b/tests/infer/reparam/test_split.py @@ -14,23 +14,32 @@ from .util import check_init_reparam -@pytest.mark.parametrize("event_shape,splits,dim", [ - ((6,), [2, 1, 3], -1), - ((2, 5,), [2, 3], -1), - ((4, 2), [1, 3], -2), - ((2, 3, 1), [1, 2], -2), -], ids=str) +@pytest.mark.parametrize( + "event_shape,splits,dim", + [ + ((6,), [2, 1, 3], -1), + ( + ( + 2, + 5, + ), + [2, 3], + -1, + ), + ((4, 2), [1, 3], -2), + ((2, 3, 1), [1, 2], -2), + ], + ids=str, +) @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) def test_normal(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): with pyro.plate_stack("plates", batch_shape): - pyro.sample("x", - dist.Normal(loc, scale) - .to_event(len(event_shape))) + pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape))) # Run without reparam. trace = poutine.trace(model).get_trace() @@ -41,7 +50,8 @@ def model(): # Run with reparam. split_values = { "x_split_{}".format(i): xi - for i, xi in enumerate(expected_value.split(splits, dim))} + for i, xi in enumerate(expected_value.split(splits, dim)) + } rep = SplitReparam(splits, dim) reparam_model = poutine.reparam(model, {"x": rep}) reparam_model = poutine.condition(reparam_model, split_values) @@ -62,16 +72,27 @@ def model(): assert_close(actual_grads, expected_grads) -@pytest.mark.parametrize("event_shape,splits,dim", [ - ((6,), [2, 1, 3], -1), - ((2, 5,), [2, 3], -1), - ((4, 2), [1, 3], -2), - ((2, 3, 1), [1, 2], -2), -], ids=str) +@pytest.mark.parametrize( + "event_shape,splits,dim", + [ + ((6,), [2, 1, 3], -1), + ( + ( + 2, + 5, + ), + [2, 3], + -1, + ), + ((4, 2), [1, 3], -2), + ((2, 3, 1), [1, 2], -2), + ], + ids=str, +) @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) def test_init(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape - loc = torch.empty(shape).uniform_(-1., 1.) + loc = torch.empty(shape).uniform_(-1.0, 1.0) scale = torch.empty(shape).uniform_(0.5, 1.5) def model(): diff --git a/tests/infer/reparam/test_stable.py b/tests/infer/reparam/test_stable.py index 7bee4de09f..c225d1d814 100644 --- a/tests/infer/reparam/test_stable.py +++ b/tests/infer/reparam/test_stable.py @@ -25,7 +25,7 @@ # Test helper to extract a few absolute moments from univariate samples. # This uses abs moments because Stable variance is infinite. 
def get_moments(x): - points = torch.tensor([-4., -1., 0., 1., 4.]) + points = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0]) points = points.reshape((-1,) + (1,) * x.dim()) return torch.cat([x.mean(0, keepdim=True), (x - points).abs().mean(1)]) @@ -33,13 +33,13 @@ def get_moments(x): @pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str) @pytest.mark.parametrize("Reparam", [LatentStableReparam, StableReparam]) def test_stable(Reparam, shape): - stability = torch.empty(shape).uniform_(1.5, 2.).requires_grad_() + stability = torch.empty(shape).uniform_(1.5, 2.0).requires_grad_() skew = torch.empty(shape).uniform_(-0.5, 0.5).requires_grad_() # test edge case when skew is 0 if skew.dim() > 0 and skew.shape[-1] > 0: - skew.data[..., 0] = 0. + skew.data[..., 0] = 0.0 scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_() - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() params = [stability, skew, scale, loc] def model(): @@ -76,7 +76,7 @@ def model(): def test_symmetric_stable(shape): stability = torch.empty(shape).uniform_(1.6, 1.9).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_() - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() params = [stability, scale, loc] def model(): @@ -105,7 +105,9 @@ def model(): @pytest.mark.parametrize("skew", [-1.0, -0.5, 0.0, 0.5, 1.0]) @pytest.mark.parametrize("stability", [0.1, 0.4, 0.8, 0.99, 1.0, 1.01, 1.3, 1.7, 2.0]) -@pytest.mark.parametrize("Reparam", [LatentStableReparam, SymmetricStableReparam, StableReparam]) +@pytest.mark.parametrize( + "Reparam", [LatentStableReparam, SymmetricStableReparam, StableReparam] +) def test_distribution(stability, skew, Reparam): if Reparam is SymmetricStableReparam and (skew != 0 or stability == 2): pytest.skip() @@ -123,7 +125,9 @@ def model(): @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) -@pytest.mark.parametrize("Reparam", [LatentStableReparam, SymmetricStableReparam, StableReparam]) +@pytest.mark.parametrize( + "Reparam", [LatentStableReparam, SymmetricStableReparam, StableReparam] +) def test_subsample_smoke(Reparam, subsample): def model(): with poutine.reparam(config={"x": Reparam()}): diff --git a/tests/infer/reparam/test_structured.py b/tests/infer/reparam/test_structured.py index 68f45217a6..1ab9d26065 100644 --- a/tests/infer/reparam/test_structured.py +++ b/tests/infer/reparam/test_structured.py @@ -15,12 +15,12 @@ def neals_funnel(dim=10): - y = pyro.sample('y', dist.Normal(0, 3)) - with pyro.plate('D', dim): - return pyro.sample('x', dist.Normal(0, torch.exp(y / 2))) + y = pyro.sample("y", dist.Normal(0, 3)) + with pyro.plate("D", dim): + return pyro.sample("x", dist.Normal(0, torch.exp(y / 2))) -@pytest.mark.parametrize('jit', [False, True]) +@pytest.mark.parametrize("jit", [False, True]) def test_neals_funnel_smoke(jit): dim = 10 @@ -30,7 +30,7 @@ def test_neals_funnel_smoke(jit): dependencies={"x": {"y": "linear"}}, ) Elbo = JitTrace_ELBO if jit else Trace_ELBO - svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Elbo()) + svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Elbo()) for _ in range(1000): try: svi.step(dim=dim) diff --git a/tests/infer/reparam/test_studentt.py b/tests/infer/reparam/test_studentt.py index 56cdf77c32..6e46eb67d9 100644 --- a/tests/infer/reparam/test_studentt.py +++ b/tests/infer/reparam/test_studentt.py @@ -18,7 +18,7 @@ # Test helper to 
extract a few absolute moments from univariate samples. # This uses abs moments because StudentT variance may be infinite. def get_moments(x): - points = torch.tensor([-4., -1., 0., 1., 4.]) + points = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0]) points = points.reshape((-1,) + (1,) * x.dim()) return torch.cat([x.mean(0, keepdim=True), (x - points).abs().mean(1)]) @@ -26,7 +26,7 @@ def get_moments(x): @pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str) def test_moments(shape): df = torch.empty(shape).uniform_(1.8, 5).requires_grad_() - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_() params = [df, loc, scale] @@ -58,7 +58,6 @@ def model(): @pytest.mark.parametrize("scale", [0.1, 1.0, 2.0]) @pytest.mark.parametrize("loc", [0.0, 1.234]) def test_distribution(df, loc, scale): - def model(): with pyro.plate("particles", 20000): return pyro.sample("x", dist.StudentT(df, loc, scale)) @@ -72,7 +71,7 @@ def model(): @pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str) def test_init(shape): df = torch.empty(shape).uniform_(1.8, 5).requires_grad_() - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_() def model(): diff --git a/tests/infer/reparam/test_transform.py b/tests/infer/reparam/test_transform.py index 0f451389d4..e03434cab6 100644 --- a/tests/infer/reparam/test_transform.py +++ b/tests/infer/reparam/test_transform.py @@ -39,7 +39,8 @@ def test_log_normal(batch_shape, event_shape): def model(): fn = dist.TransformedDistribution( dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)), - [AffineTransform(loc, scale), ExpTransform()]) + [AffineTransform(loc, scale), ExpTransform()], + ) if event_shape: fn = fn.to_event(len(event_shape)) with pyro.plate_stack("plates", batch_shape): @@ -48,8 +49,9 @@ def model(): with poutine.trace() as tr: value = model() - assert isinstance(tr.trace.nodes["x"]["fn"], - (dist.TransformedDistribution, dist.Independent)) + assert isinstance( + tr.trace.nodes["x"]["fn"], (dist.TransformedDistribution, dist.Independent) + ) expected_moments = get_moments(value) with poutine.reparam(config={"x": TransformReparam()}): @@ -70,7 +72,8 @@ def test_init(batch_shape, event_shape): def model(): fn = dist.TransformedDistribution( dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)), - [AffineTransform(loc, scale), ExpTransform()]) + [AffineTransform(loc, scale), ExpTransform()], + ) if event_shape: fn = fn.to_event(len(event_shape)) with pyro.plate_stack("plates", batch_shape): diff --git a/tests/infer/reparam/test_unit_jacobian.py b/tests/infer/reparam/test_unit_jacobian.py index 74d0782caf..ffe9f63e06 100644 --- a/tests/infer/reparam/test_unit_jacobian.py +++ b/tests/infer/reparam/test_unit_jacobian.py @@ -29,7 +29,7 @@ def get_moments(x): @pytest.mark.parametrize("shape", [(6,), (4, 5), (2, 1, 3)], ids=str) def test_normal(shape): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): @@ -59,7 +59,7 @@ def model(): @pytest.mark.parametrize("shape", [(6,), (4, 5), (2, 1, 3)], ids=str) def test_init(shape): - loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() 
scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() def model(): diff --git a/tests/infer/test_abstract_infer.py b/tests/infer/test_abstract_infer.py index bfacd142a2..771c8ddca5 100644 --- a/tests/infer/test_abstract_infer.py +++ b/tests/infer/test_abstract_infer.py @@ -18,7 +18,7 @@ def model(num_trials): with pyro.plate("data", num_trials.size(0)): - phi_prior = dist.Uniform(num_trials.new_tensor(0.), num_trials.new_tensor(1.)) + phi_prior = dist.Uniform(num_trials.new_tensor(0.0), num_trials.new_tensor(1.0)) success_prob = pyro.sample("phi", phi_prior) return pyro.sample("obs", dist.Binomial(num_trials, success_prob)) @@ -41,11 +41,30 @@ def nested(): # TODO: Make this available directly in `SVI` if needed. -@pytest.mark.filterwarnings('ignore::FutureWarning') +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_information_criterion(): # milk dataset: https://github.com/rmcelreath/rethinking/blob/master/data/milk.csv - kcal = torch.tensor([0.49, 0.47, 0.56, 0.89, 0.92, 0.8, 0.46, 0.71, 0.68, - 0.97, 0.84, 0.62, 0.54, 0.49, 0.48, 0.55, 0.71]) + kcal = torch.tensor( + [ + 0.49, + 0.47, + 0.56, + 0.89, + 0.92, + 0.8, + 0.46, + 0.71, + 0.68, + 0.97, + 0.84, + 0.62, + 0.54, + 0.49, + 0.48, + 0.55, + 0.71, + ] + ) kcal_mean = kcal.mean() kcal_logstd = kcal.std().log() @@ -57,8 +76,14 @@ def model(): delta_guide = AutoLaplaceApproximation(model) - svi = SVI(model, delta_guide, optim.Adam({"lr": 0.05}), loss=Trace_ELBO(), - num_steps=0, num_samples=3000) + svi = SVI( + model, + delta_guide, + optim.Adam({"lr": 0.05}), + loss=Trace_ELBO(), + num_steps=0, + num_samples=3000, + ) for i in range(100): svi.step() diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 67c76dcc72..76c98dffba 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -42,13 +42,16 @@ from tests.common import assert_close, assert_equal -@pytest.mark.parametrize("auto_class", [ - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + ], +) def test_scores(auto_class): def model(): if auto_class is AutoIAFNormal: @@ -64,25 +67,29 @@ def model(): model_trace.compute_log_prob() prefix = auto_class.__name__ - if prefix != 'AutoNormal': - assert '_{}_latent'.format(prefix) not in model_trace.nodes - assert guide_trace.nodes['_{}_latent'.format(prefix)]['log_prob_sum'].item() != 0.0 - assert model_trace.nodes['z']['log_prob_sum'].item() != 0.0 - assert guide_trace.nodes['z']['log_prob_sum'].item() == 0.0 + if prefix != "AutoNormal": + assert "_{}_latent".format(prefix) not in model_trace.nodes + assert ( + guide_trace.nodes["_{}_latent".format(prefix)]["log_prob_sum"].item() != 0.0 + ) + assert model_trace.nodes["z"]["log_prob_sum"].item() != 0.0 + assert guide_trace.nodes["z"]["log_prob_sum"].item() == 0.0 @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + AutoLaplaceApproximation, + ], +) def test_factor(auto_class, Elbo): - def 
model(log_factor): pyro.sample("z1", dist.Normal(0.0, 1.0)) pyro.factor("f1", log_factor) @@ -93,19 +100,18 @@ def model(log_factor): guide = auto_class(model) elbo = Elbo(strict_enumeration_warning=False) - elbo.loss(model, guide, torch.tensor(0.)) # initialize param store + elbo.loss(model, guide, torch.tensor(0.0)) # initialize param store pyro.set_rng_seed(123) - loss_5 = elbo.loss(model, guide, torch.tensor(5.)) + loss_5 = elbo.loss(model, guide, torch.tensor(5.0)) pyro.set_rng_seed(123) - loss_4 = elbo.loss(model, guide, torch.tensor(4.)) + loss_4 = elbo.loss(model, guide, torch.tensor(4.0)) assert_close(loss_5 - loss_4, -1 - 3) # helper for test_shapes() class AutoStructured_shapes(AutoStructured): def __init__(self, model, *, init_loc_fn): - def conditional_z4(): return pyro.param("z4_aux", torch.tensor(0.0)) @@ -156,25 +162,30 @@ def dependency_z6_z5(z5): @pytest.mark.parametrize("num_particles", [1, 10]) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) -@pytest.mark.parametrize("init_loc_fn", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, -]) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, - AutoLaplaceApproximation, - AutoStructured_shapes, -]) +@pytest.mark.parametrize( + "init_loc_fn", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + ], +) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + AutoLaplaceApproximation, + AutoStructured_shapes, + ], +) @pytest.mark.filterwarnings("ignore::FutureWarning") def test_shapes(auto_class, init_loc_fn, Elbo, num_particles): - def model(): pyro.sample("z1", dist.Normal(0.0, 1.0)) pyro.sample("z2", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1)) @@ -182,35 +193,44 @@ def model(): pyro.sample("z3", dist.Normal(torch.zeros(3), torch.ones(3))) pyro.sample("z4", dist.MultivariateNormal(torch.zeros(2), torch.eye(2))) pyro.sample("z5", dist.Dirichlet(torch.ones(3))) - pyro.sample("z6", dist.Normal(0, 1).expand((2,)).mask(torch.arange(2) > 0).to_event(1)) - pyro.sample("z7", dist.LKJCholesky(2, torch.tensor(1.))) + pyro.sample( + "z6", dist.Normal(0, 1).expand((2,)).mask(torch.arange(2) > 0).to_event(1) + ) + pyro.sample("z7", dist.LKJCholesky(2, torch.tensor(1.0))) guide = auto_class(model, init_loc_fn=init_loc_fn) - elbo = Elbo(num_particles=num_particles, vectorize_particles=True, - strict_enumeration_warning=False) + elbo = Elbo( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=False, + ) loss = elbo.loss(model, guide) assert np.isfinite(loss), loss @pytest.mark.xfail(reason="sequential plate is not yet supported") -@pytest.mark.parametrize('auto_class', [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + AutoLaplaceApproximation, + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO]) def test_iplate_smoke(auto_class, Elbo): - def model(): x = pyro.sample("x", dist.Normal(0, 1)) assert x.shape == () for i in pyro.plate("plate", 3): - y = pyro.sample("y_{}".format(i), dist.Normal(0, 
1).expand_by([2, 1 + i, 2]).to_event(3)) + y = pyro.sample( + "y_{}".format(i), dist.Normal(0, 1).expand_by([2, 1 + i, 2]).to_event(3) + ) assert y.shape == (2, 1 + i, 2) z = pyro.sample("z", dist.Normal(0, 1).expand_by([2]).to_event(1)) @@ -219,7 +239,9 @@ def model(): pyro.sample("obs", dist.Bernoulli(0.1), obs=torch.tensor(0)) guide = auto_class(model) - infer = SVI(model, guide, Adam({"lr": 1e-6}), Elbo(strict_enumeration_warning=False)) + infer = SVI( + model, guide, Adam({"lr": 1e-6}), Elbo(strict_enumeration_warning=False) + ) infer.step() @@ -232,12 +254,14 @@ def auto_guide_list_x(model): def auto_guide_callable(model): def guide_x(): - x_loc = pyro.param("x_loc", torch.tensor(1.)) - x_scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive) + x_loc = pyro.param("x_loc", torch.tensor(1.0)) + x_scale = pyro.param( + "x_scale", torch.tensor(0.1), constraint=constraints.positive + ) pyro.sample("x", dist.Normal(x_loc, x_scale)) def median_x(): - return {"x": pyro.param("x_loc", torch.tensor(1.))} + return {"x": pyro.param("x_loc", torch.tensor(1.0))} guide = AutoGuideList(model) guide.append(AutoCallable(model, guide_x, median_x)) @@ -249,8 +273,8 @@ def auto_guide_module_callable(model): class GuideX(AutoGuide): def __init__(self, model): super().__init__(model) - self.x_loc = nn.Parameter(torch.tensor(1.)) - self.x_scale = PyroParam(torch.tensor(.1), constraint=constraints.positive) + self.x_loc = nn.Parameter(torch.tensor(1.0)) + self.x_scale = PyroParam(torch.tensor(0.1), constraint=constraints.positive) def forward(self, *args, **kwargs): return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))} @@ -266,9 +290,9 @@ def median(self, *args, **kwargs): def nested_auto_guide_callable(model): guide = AutoGuideList(model) - guide.append(AutoDelta(poutine.block(model, expose=['x']))) - guide_y = AutoGuideList(poutine.block(model, expose=['y'])) - guide_y.z = AutoIAFNormal(poutine.block(model, expose=['y'])) + guide.append(AutoDelta(poutine.block(model, expose=["x"]))) + guide_y = AutoGuideList(poutine.block(model, expose=["y"])) + guide_y.z = AutoIAFNormal(poutine.block(model, expose=["y"])) guide.append(guide_y) return guide @@ -289,34 +313,37 @@ def __init__(self, model): ) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - auto_guide_list_x, - auto_guide_callable, - auto_guide_module_callable, - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_feasible), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), - AutoStructured_median, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + auto_guide_list_x, + auto_guide_callable, + auto_guide_module_callable, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_feasible), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), + AutoStructured_median, + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_median(auto_class, Elbo): - def model(): pyro.sample("x", dist.Normal(0.0, 1.0)) 
pyro.sample("y", dist.LogNormal(0.0, 1.0)) pyro.sample("z", dist.Beta(2.0, 2.0)) guide = auto_class(model) - optim = Adam({'lr': 0.02, 'betas': (0.8, 0.99)}) - elbo = Elbo(strict_enumeration_warning=False, - num_particles=500, vectorize_particles=True) + optim = Adam({"lr": 0.02, "betas": (0.8, 0.99)}) + elbo = Elbo( + strict_enumeration_warning=False, num_particles=500, vectorize_particles=True + ) infer = SVI(model, guide, optim, elbo) for _ in range(100): infer.step() @@ -333,22 +360,25 @@ def model(): assert_equal(median["z"], torch.tensor(0.5), prec=0.1) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - auto_guide_list_x, - auto_guide_module_callable, - nested_auto_guide_callable, - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_feasible), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), - AutoStructured_median, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + auto_guide_list_x, + auto_guide_module_callable, + nested_auto_guide_callable, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_feasible), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), + AutoStructured_median, + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_autoguide_serialization(auto_class, Elbo): def model(): @@ -356,6 +386,7 @@ def model(): with pyro.plate("plate", 2): pyro.sample("y", dist.LogNormal(0.0, 1.0)) pyro.sample("z", dist.Beta(2.0, 2.0)) + guide = auto_class(model) guide() if auto_class is AutoLaplaceApproximation: @@ -391,25 +422,28 @@ def model(): assert_equal(attr_get(guide_deser), attr_get(guide).data) -@pytest.mark.parametrize("auto_class", [ - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_quantiles(auto_class, Elbo): - def model(): pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.LogNormal(0.0, 1.0)) pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1)) guide = auto_class(model) - optim = Adam({'lr': 0.05, 'betas': (0.8, 0.99)}) - elbo = Elbo(strict_enumeration_warning=False, - num_particles=100, vectorize_particles=True) + optim = Adam({"lr": 0.05, "betas": (0.8, 0.99)}) + elbo = Elbo( + strict_enumeration_warning=False, num_particles=100, vectorize_particles=True + ) infer = SVI(model, guide, optim, elbo) for _ in range(100): infer.step() @@ -444,28 +478,31 @@ def model(): assert (quantiles["z"][2] < 0.99).all() -@pytest.mark.parametrize("continuous_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "continuous_class", + [ + AutoDelta, + AutoDiagonalNormal, + 
AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + AutoLaplaceApproximation, + ], +) def test_discrete_parallel(continuous_class): K = 2 - data = torch.tensor([0., 1., 10., 11., 12.]) + data = torch.tensor([0.0, 1.0, 10.0, 11.0, 12.0]) def model(data): - weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K))) - locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1)) - scale = pyro.sample('scale', dist.LogNormal(0, 1)) + weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K))) + locs = pyro.sample("locs", dist.Normal(0, 10).expand_by([K]).to_event(1)) + scale = pyro.sample("scale", dist.LogNormal(0, 1)) - with pyro.plate('data', len(data)): + with pyro.plate("data", len(data)): weights = weights.expand(torch.Size((len(data),)) + weights.shape) - assignment = pyro.sample('assignment', dist.Categorical(weights)) - pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data) + assignment = pyro.sample("assignment", dist.Categorical(weights)) + pyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data) guide = AutoGuideList(model) guide.append(continuous_class(poutine.block(model, hide=["assignment"]))) @@ -476,19 +513,21 @@ def model(data): assert np.isfinite(loss), loss -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoIAFNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoIAFNormal, + AutoLaplaceApproximation, + ], +) def test_guide_list(auto_class): - def model(): - pyro.sample("x", dist.Normal(0., 1.).expand([2])) + pyro.sample("x", dist.Normal(0.0, 1.0).expand([2])) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) guide = AutoGuideList(model) @@ -497,22 +536,24 @@ def model(): guide() -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoNormal, - AutoMultivariateNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) def test_callable(auto_class): - def model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) def guide_x(): - x_loc = pyro.param("x_loc", torch.tensor(0.)) + x_loc = pyro.param("x_loc", torch.tensor(0.0)) pyro.sample("x", dist.Delta(x_loc)) guide = AutoGuideList(model) @@ -522,22 +563,24 @@ def guide_x(): assert set(values) == set(["y"]) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) def test_callable_return_dict(auto_class): - def model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) def guide_x(): - x_loc = pyro.param("x_loc", torch.tensor(0.)) + x_loc = pyro.param("x_loc", torch.tensor(0.0)) x = pyro.sample("x", dist.Delta(x_loc)) 
return {"x": x} @@ -551,6 +594,7 @@ def guide_x(): def test_empty_model_error(): def model(): pass + guide = AutoDiagonalNormal(model) with pytest.raises(RuntimeError): guide() @@ -558,24 +602,26 @@ def model(): def test_unpack_latent(): def model(): - return pyro.sample('x', dist.LKJCholesky(2, torch.tensor(1.))) + return pyro.sample("x", dist.LKJCholesky(2, torch.tensor(1.0))) guide = AutoDiagonalNormal(model) - assert guide()['x'].shape == model().shape + assert guide()["x"].shape == model().shape latent = guide.sample_latent() assert list(guide._unpack_latent(latent))[0][1].shape == (1,) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoLowRankMultivariateNormal, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + ], +) def test_init_loc_fn(auto_class): - def model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) inits = {"x": torch.randn(()), "y": torch.randn(5)} @@ -597,19 +643,21 @@ def __init__(self, *args, **kwargs): @pytest.mark.parametrize("init_scale", [1e-1, 1e-4, 1e-8]) -@pytest.mark.parametrize("auto_class", [ - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoLowRankMultivariateNormal, - AutoLowRankMultivariateNormal_100, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLowRankMultivariateNormal_100, + ], +) def test_init_scale(auto_class, init_scale): - def model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) with pyro.plate("plate", 100): - pyro.sample("z", dist.Normal(0., 1.)) + pyro.sample("z", dist.Normal(0.0, 1.0)) guide = auto_class(model, init_scale=init_scale) guide() @@ -618,34 +666,38 @@ def model(): assert init_scale * 0.5 < scale_rms < 2.0 * init_scale -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - auto_guide_list_x, - auto_guide_callable, - auto_guide_module_callable, - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + auto_guide_list_x, + auto_guide_callable, + auto_guide_module_callable, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_median_module(auto_class, Elbo): - class Model(PyroModule): def __init__(self): super().__init__() - self.x_loc = nn.Parameter(torch.tensor(1.)) + self.x_loc = nn.Parameter(torch.tensor(1.0)) self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive) def forward(self): pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) - pyro.sample("y", dist.Normal(2., 0.1)) + pyro.sample("y", dist.Normal(2.0, 0.1)) model = Model() guide = auto_class(model) - infer = SVI(model, guide, Adam({'lr': 0.005}), Elbo(strict_enumeration_warning=False)) + infer = SVI( + model, guide, Adam({"lr": 0.005}), 
Elbo(strict_enumeration_warning=False) + ) for _ in range(20): infer.step() @@ -659,17 +711,16 @@ def forward(self): @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_nested_autoguide(Elbo): - class Model(PyroModule): def __init__(self): super().__init__() - self.x_loc = nn.Parameter(torch.tensor(1.)) + self.x_loc = nn.Parameter(torch.tensor(1.0)) self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive) def forward(self): pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) with pyro.plate("plate", 2): - pyro.sample("y", dist.Normal(2., 0.1)) + pyro.sample("y", dist.Normal(2.0, 0.1)) model = Model() guide = nested_auto_guide_callable(model) @@ -678,17 +729,23 @@ def forward(self): for _, m in guide.named_modules(): if m is guide: continue - assert m.master is not None and m.master() is guide, "master ref wrong for {}".format(m._pyro_name) + assert ( + m.master is not None and m.master() is guide + ), "master ref wrong for {}".format(m._pyro_name) - infer = SVI(model, guide, Adam({'lr': 0.005}), Elbo(strict_enumeration_warning=False)) + infer = SVI( + model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) + ) for _ in range(20): infer.step() guide_trace = poutine.trace(guide).get_trace() model_trace = poutine.trace(model).get_trace() check_model_guide_match(model_trace, guide_trace) - assert all(p.startswith("AutoGuideList.0") or p.startswith("AutoGuideList.1.z") - for p in guide_trace.param_nodes) + assert all( + p.startswith("AutoGuideList.0") or p.startswith("AutoGuideList.1.z") + for p in guide_trace.param_nodes + ) stochastic_nodes = set(guide_trace.stochastic_nodes) assert "x" in stochastic_nodes assert "y" in stochastic_nodes @@ -696,16 +753,19 @@ def forward(self): assert "_AutoGuideList.1.z_latent" in stochastic_nodes -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + ], +) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_linear_regression_smoke(auto_class, Elbo): N, D = 10, 3 @@ -713,8 +773,12 @@ def test_linear_regression_smoke(auto_class, Elbo): class RandomLinear(nn.Linear, PyroModule): def __init__(self, in_features, out_features): super().__init__(in_features, out_features) - self.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)) - self.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1)) + self.weight = PyroSample( + dist.Normal(0.0, 1.0).expand([out_features, in_features]).to_event(2) + ) + self.bias = PyroSample( + dist.Normal(0.0, 10.0).expand([out_features]).to_event(1) + ) class LinearRegression(PyroModule): def __init__(self): @@ -723,14 +787,16 @@ def __init__(self): def forward(self, x, y=None): mean = self.linear(x).squeeze(-1) - sigma = pyro.sample("sigma", dist.LogNormal(0., 1.)) - with pyro.plate('plate', N): - return pyro.sample('obs', dist.Normal(mean, sigma), obs=y) + sigma = pyro.sample("sigma", dist.LogNormal(0.0, 
1.0)) + with pyro.plate("plate", N): + return pyro.sample("obs", dist.Normal(mean, sigma), obs=y) x, y = torch.randn(N, D), torch.randn(N) model = LinearRegression() guide = auto_class(model) - infer = SVI(model, guide, Adam({'lr': 0.005}), Elbo(strict_enumeration_warning=False)) + infer = SVI( + model, guide, Adam({"lr": 0.005}), Elbo(strict_enumeration_warning=False) + ) infer.step(x, y) @@ -749,25 +815,32 @@ def __init__(self, model): ) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), - functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), - AutoStructured_predictive, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoStructured_predictive, + ], +) def test_predictive(auto_class): N, D = 3, 2 class RandomLinear(nn.Linear, PyroModule): def __init__(self, in_features, out_features): super().__init__(in_features, out_features) - self.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)) - self.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1)) + self.weight = PyroSample( + dist.Normal(0.0, 1.0).expand([out_features, in_features]).to_event(2) + ) + self.bias = PyroSample( + dist.Normal(0.0, 10.0).expand([out_features]).to_event(1) + ) class LinearRegression(PyroModule): def __init__(self): @@ -776,9 +849,9 @@ def __init__(self): def forward(self, x, y=None): mean = self.linear(x).squeeze(-1) - sigma = pyro.sample("sigma", dist.LogNormal(0., 1.)) - with pyro.plate('plate', N): - return pyro.sample('obs', dist.Normal(mean, sigma), obs=y) + sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0)) + with pyro.plate("plate", N): + return pyro.sample("obs", dist.Normal(mean, sigma), obs=y) x, y = torch.randn(N, D), torch.randn(N) model = LinearRegression() @@ -808,16 +881,16 @@ def forward(self, x, y=None): @pytest.mark.parametrize("auto_class", [AutoDelta, AutoNormal]) def test_subsample_model(auto_class): - def model(x, y=None, batch_size=None): - loc = pyro.param("loc", lambda: torch.tensor(0.)) - scale = pyro.param("scale", lambda: torch.tensor(1.), - constraint=constraints.positive) + loc = pyro.param("loc", lambda: torch.tensor(0.0)) + scale = pyro.param( + "scale", lambda: torch.tensor(1.0), constraint=constraints.positive + ) with pyro.plate("batch", len(x), subsample_size=batch_size): batch_x = pyro.subsample(x, event_dim=0) batch_y = pyro.subsample(y, event_dim=0) if y is not None else None mean = loc + scale * batch_x - sigma = pyro.sample("sigma", dist.LogNormal(0., 1.)) + sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0)) return pyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y) guide = auto_class(model) @@ -849,11 +922,12 @@ def model(batch, subsample, full_size): data_plate = pyro.plate("data", full_size, subsample=subsample) assert data_plate.size == 50 with data_plate: - z = 0. 
+ z = 0.0 for t in range(num_time_steps): z = pyro.sample("state_{}".format(t), dist.Normal(z, drift)) - result[t] = pyro.sample("obs_{}".format(t), dist.Bernoulli(logits=z), - obs=batch[t]) + result[t] = pyro.sample( + "obs_{}".format(t), dist.Bernoulli(logits=z), obs=batch[t] + ) return torch.stack(result) @@ -920,80 +994,96 @@ def create_plates(data): svi.step(data) -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) -@pytest.mark.parametrize("init_loc_fn", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) +@pytest.mark.parametrize( + "init_loc_fn", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + ], +) def test_discrete_helpful_error(auto_class, init_loc_fn): - def model(): - p = pyro.sample("p", dist.Beta(2., 2.)) + p = pyro.sample("p", dist.Beta(2.0, 2.0)) x = pyro.sample("x", dist.Bernoulli(p)) - pyro.sample("obs", dist.Bernoulli(p * x + (1 - p) * (1 - x)), - obs=torch.tensor([1., 0.])) + pyro.sample( + "obs", + dist.Bernoulli(p * x + (1 - p) * (1 - x)), + obs=torch.tensor([1.0, 0.0]), + ) guide = auto_class(model, init_loc_fn=init_loc_fn) with pytest.raises(ValueError, match=".*enumeration.html.*"): guide() -@pytest.mark.parametrize("auto_class", [ - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) -@pytest.mark.parametrize("init_loc_fn", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) +@pytest.mark.parametrize( + "init_loc_fn", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + ], +) def test_sphere_helpful_error(auto_class, init_loc_fn): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.).expand([2]).to_event(1)) + x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([2]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), - obs=torch.tensor([1., 0.])) + pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) guide = auto_class(model, init_loc_fn=init_loc_fn) with pytest.raises(ValueError, match=".*ProjectedNormalReparam.*"): guide() -@pytest.mark.parametrize("auto_class", [ - AutoDelta, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoNormal, - AutoLowRankMultivariateNormal, - AutoLaplaceApproximation, -]) -@pytest.mark.parametrize("init_loc_fn", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, -]) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + ], +) +@pytest.mark.parametrize( + "init_loc_fn", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + ], +) def test_sphere_reparam_ok(auto_class, init_loc_fn): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.).expand([3]).to_event(1)) + x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), - 
obs=torch.tensor([1., 0.])) + pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) model = poutine.reparam(model, {"y": ProjectedNormalReparam()}) guide = auto_class(model) @@ -1001,19 +1091,20 @@ def model(): @pytest.mark.parametrize("auto_class", [AutoDelta]) -@pytest.mark.parametrize("init_loc_fn", [ - init_to_feasible, - init_to_mean, - init_to_median, - init_to_sample, -]) +@pytest.mark.parametrize( + "init_loc_fn", + [ + init_to_feasible, + init_to_mean, + init_to_median, + init_to_sample, + ], +) def test_sphere_raw_ok(auto_class, init_loc_fn): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.).expand([3]).to_event(1)) + x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), - obs=torch.tensor([1., 0.])) + pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) guide = auto_class(model, init_loc_fn=init_loc_fn) poutine.trace(guide).get_trace().compute_log_prob() @@ -1037,15 +1128,17 @@ def __init__(self, model): ) -@pytest.mark.parametrize("Guide", [ - AutoNormal, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoStructured_exact_normal, - AutoStructured_exact_mvn, -]) +@pytest.mark.parametrize( + "Guide", + [ + AutoNormal, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoStructured_exact_normal, + AutoStructured_exact_mvn, + ], +) def test_exact(Guide): - def model(data): loc = pyro.sample("loc", dist.Normal(0, 1)) with pyro.plate("data", len(data)): @@ -1074,15 +1167,17 @@ def model(data): assert_close(actual_std, expected_std, rtol=0.05) -@pytest.mark.parametrize("Guide", [ - AutoNormal, - AutoDiagonalNormal, - AutoMultivariateNormal, - AutoStructured_exact_normal, - AutoStructured_exact_mvn, -]) +@pytest.mark.parametrize( + "Guide", + [ + AutoNormal, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoStructured_exact_normal, + AutoStructured_exact_mvn, + ], +) def test_exact_batch(Guide): - def model(data): with pyro.plate("data", len(data)): loc = pyro.sample("loc", dist.Normal(0, 1)) diff --git a/tests/infer/test_compute_downstream_costs.py b/tests/infer/test_compute_downstream_costs.py index d9ae968b41..2853679e5b 100644 --- a/tests/infer/test_compute_downstream_costs.py +++ b/tests/infer/test_compute_downstream_costs.py @@ -15,25 +15,36 @@ from tests.common import assert_equal -def _brute_force_compute_downstream_costs(model_trace, guide_trace, # - non_reparam_nodes): +def _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes # +): - guide_nodes = [x for x in guide_trace.nodes if guide_trace.nodes[x]["type"] == "sample"] + guide_nodes = [ + x for x in guide_trace.nodes if guide_trace.nodes[x]["type"] == "sample" + ] downstream_costs, downstream_guide_cost_nodes = {}, {} stacks = get_plate_stacks(model_trace) for node in guide_nodes: - downstream_costs[node] = MultiFrameTensor((stacks[node], - model_trace.nodes[node]['log_prob'] - - guide_trace.nodes[node]['log_prob'])) + downstream_costs[node] = MultiFrameTensor( + ( + stacks[node], + model_trace.nodes[node]["log_prob"] + - guide_trace.nodes[node]["log_prob"], + ) + ) downstream_guide_cost_nodes[node] = set([node]) descendants = guide_trace.successors(node) for desc in descendants: - desc_mft = MultiFrameTensor((stacks[desc], - model_trace.nodes[desc]['log_prob'] - - guide_trace.nodes[desc]['log_prob'])) + desc_mft = MultiFrameTensor( + ( + stacks[desc], + model_trace.nodes[desc]["log_prob"] + - guide_trace.nodes[desc]["log_prob"], + ) + ) 
downstream_costs[node].add(*desc_mft.items()) downstream_guide_cost_nodes[node].update([desc]) @@ -43,20 +54,29 @@ def _brute_force_compute_downstream_costs(model_trace, guide_trace, # children_in_model.update(model_trace.successors(node)) children_in_model.difference_update(downstream_guide_cost_nodes[site]) for child in children_in_model: - assert (model_trace.nodes[child]["type"] == "sample") - child_mft = MultiFrameTensor((stacks[child], - model_trace.nodes[child]['log_prob'])) + assert model_trace.nodes[child]["type"] == "sample" + child_mft = MultiFrameTensor( + (stacks[child], model_trace.nodes[child]["log_prob"]) + ) downstream_costs[site].add(*child_mft.items()) downstream_guide_cost_nodes[site].update([child]) for k in non_reparam_nodes: - downstream_costs[k] = downstream_costs[k].sum_to(guide_trace.nodes[k]["cond_indep_stack"]) + downstream_costs[k] = downstream_costs[k].sum_to( + guide_trace.nodes[k]["cond_indep_stack"] + ) return downstream_costs, downstream_guide_cost_nodes -def big_model_guide(include_obs=True, include_single=False, include_inner_1=False, flip_c23=False, - include_triple=False, include_z1=False): +def big_model_guide( + include_obs=True, + include_single=False, + include_inner_1=False, + flip_c23=False, + include_triple=False, + include_z1=False, +): p0 = torch.tensor(math.exp(-0.20), requires_grad=True) p1 = torch.tensor(math.exp(-0.33), requires_grad=True) p2 = torch.tensor(math.exp(-0.70), requires_grad=True) @@ -64,11 +84,19 @@ def big_model_guide(include_obs=True, include_single=False, include_inner_1=Fals with pyro.plate("plate_triple1", 6) as ind_triple1: with pyro.plate("plate_triple2", 7) as ind_triple2: if include_z1: - pyro.sample("z1", dist.Bernoulli(p2).expand_by([len(ind_triple2), len(ind_triple1)])) + pyro.sample( + "z1", + dist.Bernoulli(p2).expand_by( + [len(ind_triple2), len(ind_triple1)] + ), + ) with pyro.plate("plate_triple3", 9) as ind_triple3: - pyro.sample("z0", - dist.Bernoulli(p2).expand_by( - [len(ind_triple3), len(ind_triple2), len(ind_triple1)])) + pyro.sample( + "z0", + dist.Bernoulli(p2).expand_by( + [len(ind_triple3), len(ind_triple2), len(ind_triple1)] + ), + ) pyro.sample("a1", dist.Bernoulli(p0)) if include_single: with pyro.plate("plate_single", 5) as ind_single: @@ -78,20 +106,41 @@ def big_model_guide(include_obs=True, include_single=False, include_inner_1=Fals pyro.sample("b1", dist.Bernoulli(p0).expand_by([len(ind_outer)])) if include_inner_1: with pyro.plate("plate_inner_1", 3) as ind_inner: - pyro.sample("c1", dist.Bernoulli(p1).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "c1", dist.Bernoulli(p1).expand_by([len(ind_inner), len(ind_outer)]) + ) if flip_c23 and not include_obs: - pyro.sample("c3", dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)])) - pyro.sample("c2", dist.Bernoulli(p1).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "c3", + dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)]), + ) + pyro.sample( + "c2", + dist.Bernoulli(p1).expand_by([len(ind_inner), len(ind_outer)]), + ) else: - pyro.sample("c2", dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)])) - pyro.sample("c3", dist.Bernoulli(p2).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "c2", + dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)]), + ) + pyro.sample( + "c3", + dist.Bernoulli(p2).expand_by([len(ind_inner), len(ind_outer)]), + ) with pyro.plate("plate_inner_2", 4) as ind_inner: - pyro.sample("d1", dist.Bernoulli(p0).expand_by([len(ind_inner), 
len(ind_outer)])) - d2 = pyro.sample("d2", dist.Bernoulli(p2).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "d1", dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)]) + ) + d2 = pyro.sample( + "d2", dist.Bernoulli(p2).expand_by([len(ind_inner), len(ind_outer)]) + ) assert d2.shape == (4, 2) if include_obs: - pyro.sample("obs", dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)]), - obs=torch.ones(d2.size())) + pyro.sample( + "obs", + dist.Bernoulli(p0).expand_by([len(ind_inner), len(ind_outer)]), + obs=torch.ones(d2.size()), + ) @pytest.mark.parametrize("include_inner_1", [True, False]) @@ -99,16 +148,27 @@ def big_model_guide(include_obs=True, include_single=False, include_inner_1=Fals @pytest.mark.parametrize("flip_c23", [True, False]) @pytest.mark.parametrize("include_triple", [True, False]) @pytest.mark.parametrize("include_z1", [True, False]) -def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23, - include_triple, include_z1): - guide_trace = poutine.trace(big_model_guide, - graph_type="dense").get_trace(include_obs=False, include_inner_1=include_inner_1, - include_single=include_single, flip_c23=flip_c23, - include_triple=include_triple, include_z1=include_z1) - model_trace = poutine.trace(poutine.replay(big_model_guide, trace=guide_trace), - graph_type="dense").get_trace(include_obs=True, include_inner_1=include_inner_1, - include_single=include_single, flip_c23=flip_c23, - include_triple=include_triple, include_z1=include_z1) +def test_compute_downstream_costs_big_model_guide_pair( + include_inner_1, include_single, flip_c23, include_triple, include_z1 +): + guide_trace = poutine.trace(big_model_guide, graph_type="dense").get_trace( + include_obs=False, + include_inner_1=include_inner_1, + include_single=include_single, + flip_c23=flip_c23, + include_triple=include_triple, + include_z1=include_z1, + ) + model_trace = poutine.trace( + poutine.replay(big_model_guide, trace=guide_trace), graph_type="dense" + ).get_trace( + include_obs=True, + include_inner_1=include_inner_1, + include_single=include_single, + flip_c23=flip_c23, + include_triple=include_triple, + include_z1=include_z1, + ) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) @@ -116,89 +176,157 @@ def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_ guide_trace.compute_log_prob() non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) - dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) + dc, dc_nodes = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) - dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) + dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) assert dc_nodes == dc_nodes_brute - expected_nodes_full_model = {'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'}, 'd2': {'obs', 'd2'}, - 'd1': {'obs', 'd1', 'd2'}, 'c3': {'d2', 'obs', 'd1', 'c3'}, - 'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'}, - 'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'}, - 'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'}, - 'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}} + expected_nodes_full_model = { + "a1": {"c2", "a1", "d1", "c1", "obs", "b1", "d2", "c3", "b0"}, + "d2": {"obs", "d2"}, + "d1": {"obs", "d1", "d2"}, + "c3": {"d2", "obs", "d1", "c3"}, + "b0": {"b0", "d1", "c1", "obs", 
"b1", "d2", "c3", "c2"}, + "b1": {"obs", "b1", "d1", "d2", "c3", "c1", "c2"}, + "c1": {"d1", "c1", "obs", "d2", "c3", "c2"}, + "c2": {"obs", "d1", "c3", "d2", "c2"}, + } if not include_triple and include_inner_1 and include_single and not flip_c23: - assert(dc_nodes == expected_nodes_full_model) - - expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']) - expected_b1 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0) - expected_b1 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0) - expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False) + assert dc_nodes == expected_nodes_full_model + + expected_b1 = ( + model_trace.nodes["b1"]["log_prob"] - guide_trace.nodes["b1"]["log_prob"] + ) + expected_b1 += ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ).sum(0) + expected_b1 += ( + model_trace.nodes["d1"]["log_prob"] - guide_trace.nodes["d1"]["log_prob"] + ).sum(0) + expected_b1 += model_trace.nodes["obs"]["log_prob"].sum(0, keepdim=False) if include_inner_1: - expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum(0) - expected_b1 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum(0) - expected_b1 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum(0) - assert_equal(expected_b1, dc['b1'], prec=1.0e-6) + expected_b1 += ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ).sum(0) + expected_b1 += ( + model_trace.nodes["c2"]["log_prob"] - guide_trace.nodes["c2"]["log_prob"] + ).sum(0) + expected_b1 += ( + model_trace.nodes["c3"]["log_prob"] - guide_trace.nodes["c3"]["log_prob"] + ).sum(0) + assert_equal(expected_b1, dc["b1"], prec=1.0e-6) if include_single: - expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob']) - expected_b0 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum() - expected_b0 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum() - expected_b0 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum() - expected_b0 += model_trace.nodes['obs']['log_prob'].sum() + expected_b0 = ( + model_trace.nodes["b0"]["log_prob"] - guide_trace.nodes["b0"]["log_prob"] + ) + expected_b0 += ( + model_trace.nodes["b1"]["log_prob"] - guide_trace.nodes["b1"]["log_prob"] + ).sum() + expected_b0 += ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ).sum() + expected_b0 += ( + model_trace.nodes["d1"]["log_prob"] - guide_trace.nodes["d1"]["log_prob"] + ).sum() + expected_b0 += model_trace.nodes["obs"]["log_prob"].sum() if include_inner_1: - expected_b0 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum() - expected_b0 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum() - expected_b0 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum() - assert_equal(expected_b0, dc['b0'], prec=1.0e-6) - assert dc['b0'].size() == (5,) + expected_b0 += ( + model_trace.nodes["c1"]["log_prob"] + - guide_trace.nodes["c1"]["log_prob"] + ).sum() + expected_b0 += ( + model_trace.nodes["c2"]["log_prob"] + - guide_trace.nodes["c2"]["log_prob"] + ).sum() + expected_b0 += ( + model_trace.nodes["c3"]["log_prob"] + - guide_trace.nodes["c3"]["log_prob"] + ).sum() + assert_equal(expected_b0, dc["b0"], 
prec=1.0e-6) + assert dc["b0"].size() == (5,) if include_inner_1: - expected_c3 = (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']) - expected_c3 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0) - expected_c3 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0) - expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0) - - expected_c2 = (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']) - expected_c2 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0) - expected_c2 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0) - expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0) - - expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']) + expected_c3 = ( + model_trace.nodes["c3"]["log_prob"] - guide_trace.nodes["c3"]["log_prob"] + ) + expected_c3 += ( + model_trace.nodes["d1"]["log_prob"] - guide_trace.nodes["d1"]["log_prob"] + ).sum(0) + expected_c3 += ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ).sum(0) + expected_c3 += model_trace.nodes["obs"]["log_prob"].sum(0) + + expected_c2 = ( + model_trace.nodes["c2"]["log_prob"] - guide_trace.nodes["c2"]["log_prob"] + ) + expected_c2 += ( + model_trace.nodes["d1"]["log_prob"] - guide_trace.nodes["d1"]["log_prob"] + ).sum(0) + expected_c2 += ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ).sum(0) + expected_c2 += model_trace.nodes["obs"]["log_prob"].sum(0) + + expected_c1 = ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) if flip_c23: - expected_c3 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'] - expected_c2 += model_trace.nodes['c3']['log_prob'] + expected_c3 += ( + model_trace.nodes["c2"]["log_prob"] + - guide_trace.nodes["c2"]["log_prob"] + ) + expected_c2 += model_trace.nodes["c3"]["log_prob"] else: - expected_c2 += model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob'] - expected_c2 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'] + expected_c2 += ( + model_trace.nodes["c3"]["log_prob"] + - guide_trace.nodes["c3"]["log_prob"] + ) + expected_c2 += ( + model_trace.nodes["c2"]["log_prob"] + - guide_trace.nodes["c2"]["log_prob"] + ) expected_c1 += expected_c3 - assert_equal(expected_c1, dc['c1'], prec=1.0e-6) - assert_equal(expected_c2, dc['c2'], prec=1.0e-6) - assert_equal(expected_c3, dc['c3'], prec=1.0e-6) + assert_equal(expected_c1, dc["c1"], prec=1.0e-6) + assert_equal(expected_c2, dc["c2"], prec=1.0e-6) + assert_equal(expected_c3, dc["c3"], prec=1.0e-6) - expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob'] - expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob'] - expected_d1 += model_trace.nodes['obs']['log_prob'] + expected_d1 = ( + model_trace.nodes["d1"]["log_prob"] - guide_trace.nodes["d1"]["log_prob"] + ) + expected_d1 += ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ) + expected_d1 += model_trace.nodes["obs"]["log_prob"] - expected_d2 = (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']) - expected_d2 += model_trace.nodes['obs']['log_prob'] + expected_d2 = ( + model_trace.nodes["d2"]["log_prob"] - guide_trace.nodes["d2"]["log_prob"] + ) + expected_d2 += model_trace.nodes["obs"]["log_prob"] if include_triple: - 
expected_z0 = dc['a1'] + model_trace.nodes['z0']['log_prob'] - guide_trace.nodes['z0']['log_prob'] - assert_equal(expected_z0, dc['z0'], prec=1.0e-6) - assert_equal(expected_d2, dc['d2'], prec=1.0e-6) - assert_equal(expected_d1, dc['d1'], prec=1.0e-6) - - assert dc['b1'].size() == (2,) - assert dc['d2'].size() == (4, 2) + expected_z0 = ( + dc["a1"] + + model_trace.nodes["z0"]["log_prob"] + - guide_trace.nodes["z0"]["log_prob"] + ) + assert_equal(expected_z0, dc["z0"], prec=1.0e-6) + assert_equal(expected_d2, dc["d2"], prec=1.0e-6) + assert_equal(expected_d1, dc["d1"], prec=1.0e-6) + + assert dc["b1"].size() == (2,) + assert dc["d2"].size() == (4, 2) for k in dc: - assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size()) + assert guide_trace.nodes[k]["log_prob"].size() == dc[k].size() assert_equal(dc[k], dc_brute[k]) @@ -224,10 +352,10 @@ def diamond_guide(dim): @pytest.mark.parametrize("dim", [2, 3, 7, 11]) def test_compute_downstream_costs_duplicates(dim): - guide_trace = poutine.trace(diamond_guide, - graph_type="dense").get_trace(dim=dim) - model_trace = poutine.trace(poutine.replay(diamond_model, trace=guide_trace), - graph_type="dense").get_trace(dim=dim) + guide_trace = poutine.trace(diamond_guide, graph_type="dense").get_trace(dim=dim) + model_trace = poutine.trace( + poutine.replay(diamond_model, trace=guide_trace), graph_type="dense" + ).get_trace(dim=dim) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) @@ -236,37 +364,47 @@ def test_compute_downstream_costs_duplicates(dim): non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) - dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) - dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) + dc, dc_nodes = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) + dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) assert dc_nodes == dc_nodes_brute - expected_a1 = (model_trace.nodes['a1']['log_prob'] - guide_trace.nodes['a1']['log_prob']) + expected_a1 = ( + model_trace.nodes["a1"]["log_prob"] - guide_trace.nodes["a1"]["log_prob"] + ) for d in range(dim): - expected_a1 += model_trace.nodes['b{}'.format(d)]['log_prob'] - expected_a1 -= guide_trace.nodes['b{}'.format(d)]['log_prob'] - expected_a1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']) - expected_a1 += model_trace.nodes['obs']['log_prob'] - - expected_b1 = - guide_trace.nodes['b1']['log_prob'] + expected_a1 += model_trace.nodes["b{}".format(d)]["log_prob"] + expected_a1 -= guide_trace.nodes["b{}".format(d)]["log_prob"] + expected_a1 += ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) + expected_a1 += model_trace.nodes["obs"]["log_prob"] + + expected_b1 = -guide_trace.nodes["b1"]["log_prob"] for d in range(dim): - expected_b1 += model_trace.nodes['b{}'.format(d)]['log_prob'] - expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']) - expected_b1 += model_trace.nodes['obs']['log_prob'] - - expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']) + expected_b1 += model_trace.nodes["b{}".format(d)]["log_prob"] + expected_b1 += ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) + expected_b1 += model_trace.nodes["obs"]["log_prob"] + + expected_c1 = ( + 
model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) for d in range(dim): - expected_c1 += model_trace.nodes['b{}'.format(d)]['log_prob'] - expected_c1 += model_trace.nodes['obs']['log_prob'] + expected_c1 += model_trace.nodes["b{}".format(d)]["log_prob"] + expected_c1 += model_trace.nodes["obs"]["log_prob"] - assert_equal(expected_a1, dc['a1'], prec=1.0e-6) - assert_equal(expected_b1, dc['b1'], prec=1.0e-6) - assert_equal(expected_c1, dc['c1'], prec=1.0e-6) + assert_equal(expected_a1, dc["a1"], prec=1.0e-6) + assert_equal(expected_b1, dc["b1"], prec=1.0e-6) + assert_equal(expected_c1, dc["c1"], prec=1.0e-6) for k in dc: - assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size()) + assert guide_trace.nodes[k]["log_prob"].size() == dc[k].size() assert_equal(dc[k], dc_brute[k]) @@ -280,16 +418,20 @@ def nested_model_guide(include_obs=True, dim1=11, dim2=7): c_i = pyro.sample("c{}".format(i), dist.Bernoulli(p1).expand_by([len(ind)])) assert c_i.shape == (dim2 + i,) if include_obs: - obs_i = pyro.sample("obs{}".format(i), dist.Bernoulli(c_i), obs=torch.ones(c_i.size())) + obs_i = pyro.sample( + "obs{}".format(i), dist.Bernoulli(c_i), obs=torch.ones(c_i.size()) + ) assert obs_i.shape == (dim2 + i,) @pytest.mark.parametrize("dim1", [2, 5, 9]) def test_compute_downstream_costs_plate_in_iplate(dim1): - guide_trace = poutine.trace(nested_model_guide, - graph_type="dense").get_trace(include_obs=False, dim1=dim1) - model_trace = poutine.trace(poutine.replay(nested_model_guide, trace=guide_trace), - graph_type="dense").get_trace(include_obs=True, dim1=dim1) + guide_trace = poutine.trace(nested_model_guide, graph_type="dense").get_trace( + include_obs=False, dim1=dim1 + ) + model_trace = poutine.trace( + poutine.replay(nested_model_guide, trace=guide_trace), graph_type="dense" + ).get_trace(include_obs=True, dim1=dim1) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) @@ -298,34 +440,48 @@ def test_compute_downstream_costs_plate_in_iplate(dim1): non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) - dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) - dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) + dc, dc_nodes = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) + dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) assert dc_nodes == dc_nodes_brute - expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']) - expected_c1 += model_trace.nodes['obs1']['log_prob'] - - expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']) - expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum() - expected_b1 += model_trace.nodes['obs1']['log_prob'].sum() - - expected_c0 = (model_trace.nodes['c0']['log_prob'] - guide_trace.nodes['c0']['log_prob']) - expected_c0 += model_trace.nodes['obs0']['log_prob'] - - expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob']) - expected_b0 += (model_trace.nodes['c0']['log_prob'] - guide_trace.nodes['c0']['log_prob']).sum() - expected_b0 += model_trace.nodes['obs0']['log_prob'].sum() - - assert_equal(expected_c1, dc['c1'], prec=1.0e-6) - assert_equal(expected_b1, dc['b1'], prec=1.0e-6) - assert_equal(expected_c0, dc['c0'], prec=1.0e-6) - assert_equal(expected_b0, 
dc['b0'], prec=1.0e-6) + expected_c1 = ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) + expected_c1 += model_trace.nodes["obs1"]["log_prob"] + + expected_b1 = ( + model_trace.nodes["b1"]["log_prob"] - guide_trace.nodes["b1"]["log_prob"] + ) + expected_b1 += ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ).sum() + expected_b1 += model_trace.nodes["obs1"]["log_prob"].sum() + + expected_c0 = ( + model_trace.nodes["c0"]["log_prob"] - guide_trace.nodes["c0"]["log_prob"] + ) + expected_c0 += model_trace.nodes["obs0"]["log_prob"] + + expected_b0 = ( + model_trace.nodes["b0"]["log_prob"] - guide_trace.nodes["b0"]["log_prob"] + ) + expected_b0 += ( + model_trace.nodes["c0"]["log_prob"] - guide_trace.nodes["c0"]["log_prob"] + ).sum() + expected_b0 += model_trace.nodes["obs0"]["log_prob"].sum() + + assert_equal(expected_c1, dc["c1"], prec=1.0e-6) + assert_equal(expected_b1, dc["b1"], prec=1.0e-6) + assert_equal(expected_c0, dc["c0"], prec=1.0e-6) + assert_equal(expected_b0, dc["b0"], prec=1.0e-6) for k in dc: - assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size()) + assert guide_trace.nodes[k]["log_prob"].size() == dc[k].size() assert_equal(dc[k], dc_brute[k]) @@ -340,17 +496,21 @@ def nested_model_guide2(include_obs=True, dim1=3, dim2=2): b_i = pyro.sample("b{}".format(i), dist.Bernoulli(p0).expand_by([len(ind)])) assert b_i.shape == (dim1,) if include_obs: - obs_i = pyro.sample("obs{}".format(i), dist.Bernoulli(b_i), obs=torch.ones(b_i.size())) + obs_i = pyro.sample( + "obs{}".format(i), dist.Bernoulli(b_i), obs=torch.ones(b_i.size()) + ) assert obs_i.shape == (dim1,) @pytest.mark.parametrize("dim1", [2, 5]) @pytest.mark.parametrize("dim2", [3, 4]) def test_compute_downstream_costs_iplate_in_plate(dim1, dim2): - guide_trace = poutine.trace(nested_model_guide2, - graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2) - model_trace = poutine.trace(poutine.replay(nested_model_guide2, trace=guide_trace), - graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2) + guide_trace = poutine.trace(nested_model_guide2, graph_type="dense").get_trace( + include_obs=False, dim1=dim1, dim2=dim2 + ) + model_trace = poutine.trace( + poutine.replay(nested_model_guide2, trace=guide_trace), graph_type="dense" + ).get_trace(include_obs=True, dim1=dim1, dim2=dim2) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) @@ -358,29 +518,39 @@ def test_compute_downstream_costs_iplate_in_plate(dim1, dim2): guide_trace.compute_log_prob() non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) - dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes) - dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes) + dc, dc_nodes = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) + dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) assert dc_nodes == dc_nodes_brute for k in dc: - assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size()) + assert guide_trace.nodes[k]["log_prob"].size() == dc[k].size() assert_equal(dc[k], dc_brute[k]) - expected_b1 = model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob'] - expected_b1 += model_trace.nodes['obs1']['log_prob'] - assert_equal(expected_b1, dc['b1']) + expected_b1 = ( + model_trace.nodes["b1"]["log_prob"] - 
guide_trace.nodes["b1"]["log_prob"] + ) + expected_b1 += model_trace.nodes["obs1"]["log_prob"] + assert_equal(expected_b1, dc["b1"]) - expected_c = model_trace.nodes['c']['log_prob'] - guide_trace.nodes['c']['log_prob'] + expected_c = model_trace.nodes["c"]["log_prob"] - guide_trace.nodes["c"]["log_prob"] for i in range(dim2): - expected_c += model_trace.nodes['b{}'.format(i)]['log_prob'] - \ - guide_trace.nodes['b{}'.format(i)]['log_prob'] - expected_c += model_trace.nodes['obs{}'.format(i)]['log_prob'] - assert_equal(expected_c, dc['c']) - - expected_a1 = model_trace.nodes['a1']['log_prob'] - guide_trace.nodes['a1']['log_prob'] + expected_c += ( + model_trace.nodes["b{}".format(i)]["log_prob"] + - guide_trace.nodes["b{}".format(i)]["log_prob"] + ) + expected_c += model_trace.nodes["obs{}".format(i)]["log_prob"] + assert_equal(expected_c, dc["c"]) + + expected_a1 = ( + model_trace.nodes["a1"]["log_prob"] - guide_trace.nodes["a1"]["log_prob"] + ) expected_a1 += expected_c.sum() - assert_equal(expected_a1, dc['a1']) + assert_equal(expected_a1, dc["a1"]) def plate_reuse_model_guide(include_obs=True, dim1=3, dim2=2): @@ -403,10 +573,12 @@ def plate_reuse_model_guide(include_obs=True, dim1=3, dim2=2): @pytest.mark.parametrize("dim1", [2, 5]) @pytest.mark.parametrize("dim2", [3, 4]) def test_compute_downstream_costs_plate_reuse(dim1, dim2): - guide_trace = poutine.trace(plate_reuse_model_guide, - graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2) - model_trace = poutine.trace(poutine.replay(plate_reuse_model_guide, trace=guide_trace), - graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2) + guide_trace = poutine.trace(plate_reuse_model_guide, graph_type="dense").get_trace( + include_obs=False, dim1=dim1, dim2=dim2 + ) + model_trace = poutine.trace( + poutine.replay(plate_reuse_model_guide, trace=guide_trace), graph_type="dense" + ).get_trace(include_obs=True, dim1=dim1, dim2=dim2) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) @@ -414,17 +586,26 @@ def test_compute_downstream_costs_plate_reuse(dim1, dim2): guide_trace.compute_log_prob() non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) - dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, - non_reparam_nodes) - dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes) + dc, dc_nodes = _compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) + dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs( + model_trace, guide_trace, non_reparam_nodes + ) assert dc_nodes == dc_nodes_brute for k in dc: - assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size()) + assert guide_trace.nodes[k]["log_prob"].size() == dc[k].size() assert_equal(dc[k], dc_brute[k]) - expected_c1 = model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'] - expected_c1 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum() - expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'] - expected_c1 += model_trace.nodes['obs']['log_prob'] - assert_equal(expected_c1, dc['c1']) + expected_c1 = ( + model_trace.nodes["c1"]["log_prob"] - guide_trace.nodes["c1"]["log_prob"] + ) + expected_c1 += ( + model_trace.nodes["b1"]["log_prob"] - guide_trace.nodes["b1"]["log_prob"] + ).sum() + expected_c1 += ( + model_trace.nodes["c2"]["log_prob"] - guide_trace.nodes["c2"]["log_prob"] + ) + expected_c1 += 
model_trace.nodes["obs"]["log_prob"] + assert_equal(expected_c1, dc["c1"]) diff --git a/tests/infer/test_conjugate_gradients.py b/tests/infer/test_conjugate_gradients.py index aba5ba97c5..4d5ef116b8 100644 --- a/tests/infer/test_conjugate_gradients.py +++ b/tests/infer/test_conjugate_gradients.py @@ -8,7 +8,6 @@ class ConjugateChainGradientTests(GaussianChain): - def test_gradients(self): for N in [3, 5]: for reparameterized in [True, False]: @@ -18,15 +17,27 @@ def do_test_gradients(self, N, reparameterized): pyro.clear_param_store() self.setup_chain(N) - elbo = TraceGraph_ELBO(num_particles=100000, vectorize_particles=True, max_plate_nesting=1) + elbo = TraceGraph_ELBO( + num_particles=100000, vectorize_particles=True, max_plate_nesting=1 + ) elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized) for i in range(1, N + 1): for param_prefix in ["loc_q_%d", "log_sig_q_%d", "kappa_q_%d"]: - if i == N and param_prefix == 'kappa_q_%d': + if i == N and param_prefix == "kappa_q_%d": continue actual_grad = pyro.param(param_prefix % i).grad - assert_equal(actual_grad, 0.0 * actual_grad, prec=0.10, msg="".join([ - "parameter %s%d" % (param_prefix[:-2], i), - "\nexpected = zero vector", - "\n actual = {}".format(actual_grad.detach().cpu().numpy())])) + assert_equal( + actual_grad, + 0.0 * actual_grad, + prec=0.10, + msg="".join( + [ + "parameter %s%d" % (param_prefix[:-2], i), + "\nexpected = zero vector", + "\n actual = {}".format( + actual_grad.detach().cpu().numpy() + ), + ] + ), + ) diff --git a/tests/infer/test_csis.py b/tests/infer/test_csis.py index b9769a9d4d..bb4ea906e5 100644 --- a/tests/infer/test_csis.py +++ b/tests/infer/test_csis.py @@ -13,9 +13,9 @@ def model(observations={"y1": 0, "y2": 0}): - x = pyro.sample("x", dist.Normal(torch.tensor(0.), torch.tensor(5**0.5))) - pyro.sample("y1", dist.Normal(x, torch.tensor(2**0.5)), obs=observations["y1"]) - pyro.sample("y2", dist.Normal(x, torch.tensor(2**0.5)), obs=observations["y2"]) + x = pyro.sample("x", dist.Normal(torch.tensor(0.0), torch.tensor(5 ** 0.5))) + pyro.sample("y1", dist.Normal(x, torch.tensor(2 ** 0.5)), obs=observations["y1"]) + pyro.sample("y2", dist.Normal(x, torch.tensor(2 ** 0.5)), obs=observations["y2"]) return x @@ -23,7 +23,7 @@ class Guide(nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(1, 1, bias=False) - self.std = torch.nn.Parameter(torch.tensor(1.)) + self.std = torch.nn.Parameter(torch.tensor(1.0)) def forward(self, observations={"y1": 0, "y2": 0}): pyro.module("guide", self) @@ -36,13 +36,9 @@ def forward(self, observations={"y1": 0, "y2": 0}): def test_csis_sampling(): pyro.clear_param_store() guide = Guide() - csis = pyro.infer.CSIS(model, - guide, - pyro.optim.Adam({}), - num_inference_samples=500) + csis = pyro.infer.CSIS(model, guide, pyro.optim.Adam({}), num_inference_samples=500) # observations chosen so that proposal distribution and true posterior will both have zero mean - posterior = csis.run({"y1": torch.tensor(-1.0), - "y2": torch.tensor(1.0)}) + posterior = csis.run({"y1": torch.tensor(-1.0), "y2": torch.tensor(1.0)}) assert_equal(len(posterior.exec_traces), 500) marginal = pyro.infer.EmpiricalMarginal(posterior, "x") assert_equal(marginal.mean, torch.tensor(0.0), prec=0.1) @@ -53,9 +49,7 @@ def test_csis_parameter_update(): pyro.clear_param_store() guide = Guide() initial_parameters = {k: v.item() for k, v in guide.named_parameters()} - csis = pyro.infer.CSIS(model, - guide, - pyro.optim.Adam({'lr': 1e-2})) + csis = pyro.infer.CSIS(model, 
guide, pyro.optim.Adam({"lr": 1e-2})) csis.step() updated_parameters = {k: v.item() for k, v in guide.named_parameters()} for k, init_v in initial_parameters.items(): @@ -66,10 +60,7 @@ def test_csis_parameter_update(): def test_csis_validation_batch(): pyro.clear_param_store() guide = Guide() - csis = pyro.infer.CSIS(model, - guide, - pyro.optim.Adam({}), - validation_batch_size=5) + csis = pyro.infer.CSIS(model, guide, pyro.optim.Adam({}), validation_batch_size=5) init_loss_1 = csis.validation_loss() init_loss_2 = csis.validation_loss() csis.step() diff --git a/tests/infer/test_discrete.py b/tests/infer/test_discrete.py index e57b04eb48..67f5dde732 100644 --- a/tests/infer/test_discrete.py +++ b/tests/infer/test_discrete.py @@ -49,7 +49,7 @@ def log_mean_prob(trace, particle_dim): """ assert particle_dim < 0 trace.compute_log_prob() - total = 0. + total = 0.0 for node in trace.nodes.values(): if node["type"] == "sample" and type(node["fn"]).__name__ != "_Subsample": log_prob = node["log_prob"] @@ -59,12 +59,16 @@ def log_mean_prob(trace, particle_dim): return total.logsumexp(0) - math.log(num_particles) -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) -@pytest.mark.parametrize('plate_size', [2]) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) +@pytest.mark.parametrize("plate_size", [2]) def test_plate_smoke(infer, temperature, plate_size): # +-----------------+ # z1 --|--> z2 ---> x2 | @@ -75,27 +79,31 @@ def test_plate_smoke(infer, temperature, plate_size): def model(): p = pyro.param("p", torch.tensor([0.25, 0.75])) q = pyro.param("q", torch.tensor([[0.25, 0.75], [0.75, 0.25]])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) z1 = pyro.sample("z1", dist.Categorical(p)) with pyro.plate("plate", plate_size): z2 = pyro.sample("z2", dist.Categorical(q[z1])) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=torch.ones(plate_size)) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=torch.ones(plate_size)) first_available_dim = -2 infer(model, first_available_dim, temperature)() -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) def test_distribution_1(infer, temperature): # +-------+ # z --|--> x | # +-------+ num_particles = 10000 - data = torch.tensor([1., 2., 3.]) + data = torch.tensor([1.0, 2.0, 3.0]) @config_enumerate def model(num_particles=1, z=None): @@ -104,29 +112,41 @@ def model(num_particles=1, z=None): z = pyro.sample("z", dist.Bernoulli(p), obs=z) logger.info("z.shape = {}".format(z.shape)) with pyro.plate("data", 3): - pyro.sample("x", dist.Normal(z, 1.), obs=data) + pyro.sample("x", dist.Normal(z, 1.0), obs=data) first_available_dim = -3 sampled_model = infer(model, first_available_dim, temperature) sampled_trace = poutine.trace(sampled_model).get_trace(num_particles) - conditioned_traces = {z: poutine.trace(model).get_trace(z=torch.tensor(z)) for z in [0., 1.]} + conditioned_traces = { + z: poutine.trace(model).get_trace(z=torch.tensor(z)) for z in [0.0, 1.0] + } # Check posterior over z. 
actual_z_mean = sampled_trace.nodes["z"]["value"].mean() if temperature: - expected_z_mean = 1 / (1 + (conditioned_traces[0].log_prob_sum() - - conditioned_traces[1].log_prob_sum()).exp()) + expected_z_mean = 1 / ( + 1 + + ( + conditioned_traces[0].log_prob_sum() + - conditioned_traces[1].log_prob_sum() + ).exp() + ) else: - expected_z_mean = (conditioned_traces[1].log_prob_sum() > - conditioned_traces[0].log_prob_sum()).float() + expected_z_mean = ( + conditioned_traces[1].log_prob_sum() > conditioned_traces[0].log_prob_sum() + ).float() assert_equal(actual_z_mean, expected_z_mean, prec=1e-2) -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) def test_distribution_2(infer, temperature): # +--------+ # z1 --|--> x1 | @@ -135,36 +155,45 @@ def test_distribution_2(infer, temperature): # z2 --|--> x2 | # +--------+ num_particles = 10000 - data = torch.tensor([[-1., -1., 0.], [-1., 1., 1.]]) + data = torch.tensor([[-1.0, -1.0, 0.0], [-1.0, 1.0, 1.0]]) @config_enumerate def model(num_particles=1, z1=None, z2=None): p = pyro.param("p", torch.tensor([[0.25, 0.75], [0.1, 0.9]])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) with pyro.plate("num_particles", num_particles, dim=-2): z1 = pyro.sample("z1", dist.Categorical(p[0]), obs=z1) z2 = pyro.sample("z2", dist.Categorical(p[z1]), obs=z2) logger.info("z1.shape = {}".format(z1.shape)) logger.info("z2.shape = {}".format(z2.shape)) with pyro.plate("data", 3): - pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0]) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1]) + pyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1]) first_available_dim = -3 sampled_model = infer(model, first_available_dim, temperature) - sampled_trace = poutine.trace( - sampled_model).get_trace(num_particles) - conditioned_traces = {(z1, z2): poutine.trace(model).get_trace(z1=torch.tensor(z1), - z2=torch.tensor(z2)) - for z1 in [0, 1] for z2 in [0, 1]} + sampled_trace = poutine.trace(sampled_model).get_trace(num_particles) + conditioned_traces = { + (z1, z2): poutine.trace(model).get_trace( + z1=torch.tensor(z1), z2=torch.tensor(z2) + ) + for z1 in [0, 1] + for z2 in [0, 1] + } # Check joint posterior over (z1, z2). 
actual_probs = torch.empty(2, 2) expected_probs = torch.empty(2, 2) for (z1, z2), tr in conditioned_traces.items(): expected_probs[z1, z2] = tr.log_prob_sum().exp() - actual_probs[z1, z2] = ((sampled_trace.nodes["z1"]["value"] == z1) & - (sampled_trace.nodes["z2"]["value"] == z2)).float().mean() + actual_probs[z1, z2] = ( + ( + (sampled_trace.nodes["z1"]["value"] == z1) + & (sampled_trace.nodes["z2"]["value"] == z2) + ) + .float() + .mean() + ) if temperature: expected_probs = expected_probs / expected_probs.sum() else: @@ -174,47 +203,61 @@ def model(num_particles=1, z1=None, z2=None): assert_equal(expected_probs, actual_probs, prec=1e-2) -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) def test_distribution_3(infer, temperature): # +---------+ +---------------+ # z1 --|--> x1 | | z2 ---> x2 | # | 3 | | 2 | # +---------+ +---------------+ num_particles = 10000 - data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])] + data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])] @config_enumerate def model(num_particles=1, z1=None, z2=None): p = pyro.param("p", torch.tensor([0.25, 0.75])) - loc = pyro.param("loc", torch.tensor([-1., 1.])) + loc = pyro.param("loc", torch.tensor([-1.0, 1.0])) with pyro.plate("num_particles", num_particles, dim=-2): z1 = pyro.sample("z1", dist.Categorical(p), obs=z1) with pyro.plate("data[0]", 3): - pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0]) + pyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) with pyro.plate("data[1]", 2): z2 = pyro.sample("z2", dist.Categorical(p), obs=z2) - pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1]) + pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1]) first_available_dim = -3 sampled_model = infer(model, first_available_dim, temperature) - sampled_trace = poutine.trace( - sampled_model).get_trace(num_particles) - conditioned_traces = {(z1, z20, z21): poutine.trace(model).get_trace(z1=torch.tensor(z1), - z2=torch.tensor([z20, z21])) - for z1 in [0, 1] for z20 in [0, 1] for z21 in [0, 1]} + sampled_trace = poutine.trace(sampled_model).get_trace(num_particles) + conditioned_traces = { + (z1, z20, z21): poutine.trace(model).get_trace( + z1=torch.tensor(z1), z2=torch.tensor([z20, z21]) + ) + for z1 in [0, 1] + for z20 in [0, 1] + for z21 in [0, 1] + } # Check joint posterior over (z1, z2[0], z2[1]). 
actual_probs = torch.empty(2, 2, 2) expected_probs = torch.empty(2, 2, 2) for (z1, z20, z21), tr in conditioned_traces.items(): expected_probs[z1, z20, z21] = tr.log_prob_sum().exp() - actual_probs[z1, z20, z21] = ((sampled_trace.nodes["z1"]["value"] == z1) & - (sampled_trace.nodes["z2"]["value"][..., :1] == z20) & - (sampled_trace.nodes["z2"]["value"][..., 1:] == z21)).float().mean() + actual_probs[z1, z20, z21] = ( + ( + (sampled_trace.nodes["z1"]["value"] == z1) + & (sampled_trace.nodes["z2"]["value"][..., :1] == z20) + & (sampled_trace.nodes["z2"]["value"][..., 1:] == z21) + ) + .float() + .mean() + ) if temperature: expected_probs = expected_probs / expected_probs.sum() else: @@ -224,17 +267,21 @@ def model(num_particles=1, z1=None, z2=None): assert_equal(expected_probs.reshape(-1), actual_probs.reshape(-1), prec=1e-2) -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) def test_distribution_masked(infer, temperature): # +-------+ # z --|--> x | # +-------+ num_particles = 10000 - data = torch.tensor([1., 2., 3.]) + data = torch.tensor([1.0, 2.0, 3.0]) mask = torch.tensor([True, False, False]) @config_enumerate @@ -244,30 +291,42 @@ def model(num_particles=1, z=None): z = pyro.sample("z", dist.Bernoulli(p), obs=z) logger.info("z.shape = {}".format(z.shape)) with pyro.plate("data", 3), poutine.mask(mask=mask): - pyro.sample("x", dist.Normal(z, 1.), obs=data) + pyro.sample("x", dist.Normal(z, 1.0), obs=data) first_available_dim = -3 sampled_model = infer(model, first_available_dim, temperature) sampled_trace = poutine.trace(sampled_model).get_trace(num_particles) - conditioned_traces = {z: poutine.trace(model).get_trace(z=torch.tensor(z)) for z in [0., 1.]} + conditioned_traces = { + z: poutine.trace(model).get_trace(z=torch.tensor(z)) for z in [0.0, 1.0] + } # Check posterior over z. actual_z_mean = sampled_trace.nodes["z"]["value"].mean() if temperature: - expected_z_mean = 1 / (1 + (conditioned_traces[0].log_prob_sum() - - conditioned_traces[1].log_prob_sum()).exp()) + expected_z_mean = 1 / ( + 1 + + ( + conditioned_traces[0].log_prob_sum() + - conditioned_traces[1].log_prob_sum() + ).exp() + ) else: - expected_z_mean = (conditioned_traces[1].log_prob_sum() > - conditioned_traces[0].log_prob_sum()).float() + expected_z_mean = ( + conditioned_traces[1].log_prob_sum() > conditioned_traces[0].log_prob_sum() + ).float() assert_equal(actual_z_mean, expected_z_mean, prec=1e-2) -@pytest.mark.parametrize('length', [1, 2, 10, 100]) -@pytest.mark.parametrize('infer,temperature', [ - (infer_discrete, 0), - (infer_discrete, 1), - (elbo_infer_discrete, 1), -], ids=['map', 'sample', 'sample-elbo']) +@pytest.mark.parametrize("length", [1, 2, 10, 100]) +@pytest.mark.parametrize( + "infer,temperature", + [ + (infer_discrete, 0), + (infer_discrete, 1), + (elbo_infer_discrete, 1), + ], + ids=["map", "sample", "sample-elbo"], +) def test_hmm_smoke(infer, temperature, length): # This should match the example in the infer_discrete docstring. 
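The assertions above compare infer_discrete against the exact posterior computed from the two conditioned traces. A minimal standalone sketch of that check, using plain PyTorch rather than Pyro's handlers and assuming the same prior p = 0.25 and unit-variance Normal likelihood that the tests use:

import torch
import torch.distributions as dist

p = torch.tensor(0.25)                    # prior P(z = 1), matching the tests
data = torch.tensor([1.0, 2.0, 3.0])

def log_joint(z):
    # log p(z) + sum_i log N(data_i | z, 1)
    prior = dist.Bernoulli(p).log_prob(torch.tensor(z))
    likelihood = dist.Normal(torch.tensor(z), 1.0).log_prob(data).sum()
    return prior + likelihood

logp0, logp1 = log_joint(0.0), log_joint(1.0)
posterior_mean = torch.sigmoid(logp1 - logp0)  # = 1 / (1 + (logp0 - logp1).exp()), the temperature=1 target
map_value = (logp1 > logp0).float()            # the temperature=0 (MAP) target

With temperature=1 the sampled values of z should average to posterior_mean up to Monte Carlo error; with temperature=0 every sample should equal map_value.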
@@ -276,19 +335,23 @@ def hmm(data, hidden_dim=10): means = torch.arange(float(hidden_dim)) states = [0] for t in pyro.markov(range(len(data))): - states.append(pyro.sample("states_{}".format(t), - dist.Categorical(transition[states[-1]]))) - data[t] = pyro.sample("obs_{}".format(t), - dist.Normal(means[states[-1]], 1.), - obs=data[t]) + states.append( + pyro.sample( + "states_{}".format(t), dist.Categorical(transition[states[-1]]) + ) + ) + data[t] = pyro.sample( + "obs_{}".format(t), dist.Normal(means[states[-1]], 1.0), obs=data[t] + ) return states, data true_states, data = hmm([None] * length) assert len(data) == length assert len(true_states) == 1 + len(data) - decoder = infer(config_enumerate(hmm), - first_available_dim=-1, temperature=temperature) + decoder = infer( + config_enumerate(hmm), first_available_dim=-1, temperature=temperature + ) inferred_states, _ = decoder(data) assert len(inferred_states) == len(true_states) @@ -296,14 +359,14 @@ def hmm(data, hidden_dim=10): logger.info("inferred states: {}".format(list(map(int, inferred_states)))) -@pytest.mark.xfail(reason='infer_discrete log_prob is incorrect') -@pytest.mark.parametrize('nderivs', [0, 1], ids=['value', 'grad']) +@pytest.mark.xfail(reason="infer_discrete log_prob is incorrect") +@pytest.mark.parametrize("nderivs", [0, 1], ids=["value", "grad"]) def test_prob(nderivs): # +-------+ # z --|--> x | # +-------+ num_particles = 10000 - data = torch.tensor([0.5, 1., 1.5]) + data = torch.tensor([0.5, 1.0, 1.5]) p = pyro.param("p", torch.tensor(0.25)) @config_enumerate @@ -312,7 +375,7 @@ def model(num_particles): with pyro.plate("num_particles", num_particles, dim=-2): z = pyro.sample("z", dist.Bernoulli(p)) with pyro.plate("data", 3): - pyro.sample("x", dist.Normal(z, 1.), obs=data) + pyro.sample("x", dist.Normal(z, 1.0), obs=data) def guide(num_particles): pass @@ -320,9 +383,12 @@ def guide(num_particles): elbo = TraceEnum_ELBO(max_plate_nesting=2) expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1) - posterior_model = infer_discrete(config_enumerate(model, "parallel"), - first_available_dim=-3) - posterior_trace = poutine.trace(posterior_model).get_trace(num_particles=num_particles) + posterior_model = infer_discrete( + config_enumerate(model, "parallel"), first_available_dim=-3 + ) + posterior_trace = poutine.trace(posterior_model).get_trace( + num_particles=num_particles + ) actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2) if nderivs == 0: @@ -339,20 +405,19 @@ def test_warning(): def model(): x = pyro.sample("x", dist.Categorical(torch.ones(3))) with pyro.plate("data", len(data)): - pyro.sample("obs", dist.Normal(x.float(), 1), - obs=data) + pyro.sample("obs", dist.Normal(x.float(), 1), obs=data) model_1 = infer_discrete(model, first_available_dim=-2) - model_2 = infer_discrete(model, first_available_dim=-2, - strict_enumeration_warning=False) - model_3 = infer_discrete(config_enumerate(model), - first_available_dim=-2) + model_2 = infer_discrete( + model, first_available_dim=-2, strict_enumeration_warning=False + ) + model_3 = infer_discrete(config_enumerate(model), first_available_dim=-2) # model_1 should raise warnings. with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") model_1() - assert w, 'No warnings were raised' + assert w, "No warnings were raised" # model_2 and model_3 should both be valid. 
model_2() diff --git a/tests/infer/test_elbo_mapdata.py b/tests/infer/test_elbo_mapdata.py index 0fcdef2117..17ce72b123 100644 --- a/tests/infer/test_elbo_mapdata.py +++ b/tests/infer/test_elbo_mapdata.py @@ -17,15 +17,24 @@ @pytest.mark.stage("integration", "integration_batch_1") @pytest.mark.init(rng_seed=161) -@pytest.mark.parametrize("map_type,batch_size,n_steps,lr", [("iplate", 3, 7000, 0.0008), ("iplate", 8, 100, 0.018), - ("iplate", None, 100, 0.013), ("range", 3, 100, 0.018), - ("range", 8, 100, 0.01), ("range", None, 100, 0.011), - ("plate", 3, 7000, 0.0008), ("plate", 8, 7000, 0.0008), - ("plate", None, 7000, 0.0008)]) +@pytest.mark.parametrize( + "map_type,batch_size,n_steps,lr", + [ + ("iplate", 3, 7000, 0.0008), + ("iplate", 8, 100, 0.018), + ("iplate", None, 100, 0.013), + ("range", 3, 100, 0.018), + ("range", 8, 100, 0.01), + ("range", None, 100, 0.011), + ("plate", 3, 7000, 0.0008), + ("plate", 8, 7000, 0.0008), + ("plate", None, 7000, 0.0008), + ], +) def test_elbo_mapdata(map_type, batch_size, n_steps, lr): # normal-normal: known covariance - lam0 = torch.tensor([0.1, 0.1]) # precision of prior - loc0 = torch.tensor([0.0, 0.5]) # prior mean + lam0 = torch.tensor([0.1, 0.1]) # precision of prior + loc0 = torch.tensor([0.0, 0.5]) # prior mean # known precision of observation noise lam = torch.tensor([6.0, 4.0]) data = [] @@ -48,34 +57,48 @@ def add_data_point(x, y): n_data = torch.tensor([float(len(data))]) analytic_lam_n = lam0 + n_data.expand_as(lam) * lam analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n) - analytic_loc_n = sum_data * (lam / analytic_lam_n) +\ - loc0 * (lam0 / analytic_lam_n) + analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (lam0 / analytic_lam_n) - logger.debug("DOING ELBO TEST [bs = {}, map_type = {}]".format(batch_size, map_type)) + logger.debug( + "DOING ELBO TEST [bs = {}, map_type = {}]".format(batch_size, map_type) + ) pyro.clear_param_store() def model(): - loc_latent = pyro.sample("loc_latent", - dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1)) + loc_latent = pyro.sample( + "loc_latent", dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1) + ) if map_type == "iplate": for i in pyro.plate("aaa", len(data), batch_size): - pyro.sample("obs_%d" % i, dist.Normal(loc_latent, torch.pow(lam, -0.5)) .to_event(1), - obs=data[i]), + pyro.sample( + "obs_%d" % i, + dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1), + obs=data[i], + ), elif map_type == "plate": with pyro.plate("aaa", len(data), batch_size) as ind: - pyro.sample("obs", dist.Normal(loc_latent, torch.pow(lam, -0.5)) .to_event(1), - obs=data[ind]), + pyro.sample( + "obs", + dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1), + obs=data[ind], + ), else: for i, x in enumerate(data): - pyro.sample('obs_%d' % i, - dist.Normal(loc_latent, torch.pow(lam, -0.5)) - .to_event(1), - obs=x) + pyro.sample( + "obs_%d" % i, + dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1), + obs=x, + ) return loc_latent def guide(): - loc_q = pyro.param("loc_q", analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23])) - log_sig_q = pyro.param("log_sig_q", analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23])) + loc_q = pyro.param( + "loc_q", analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23]) + ) + log_sig_q = pyro.param( + "log_sig_q", + analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23]), + ) sig_q = torch.exp(log_sig_q) pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).to_event(1)) if map_type == "iplate" or map_type is None: @@ 
-94,16 +117,10 @@ def guide(): for k in range(n_steps): svi.step() - loc_error = torch.sum( - torch.pow( - analytic_loc_n - - pyro.param("loc_q"), - 2.0)) + loc_error = torch.sum(torch.pow(analytic_loc_n - pyro.param("loc_q"), 2.0)) log_sig_error = torch.sum( - torch.pow( - analytic_log_sig_n - - pyro.param("log_sig_q"), - 2.0)) + torch.pow(analytic_log_sig_n - pyro.param("log_sig_q"), 2.0) + ) if k % 500 == 0: logger.debug("errors - {}, {}".format(loc_error, log_sig_error)) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index a9a66a2b09..f034e80582 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -38,15 +38,16 @@ def _skip_cuda(*args): - return skipif_param(*args, - condition="CUDA_TEST" in os.environ, - reason="https://github.com/pyro-ppl/pyro/issues/1380") + return skipif_param( + *args, + condition="CUDA_TEST" in os.environ, + reason="https://github.com/pyro-ppl/pyro/issues/1380" + ) @pytest.mark.parametrize("depth", [1, 2, 3, 4, 5]) @pytest.mark.parametrize("graph_type", ["flat", "dense"]) def test_iter_discrete_traces_order(depth, graph_type): - @config_enumerate(default="sequential") def model(depth): for i in range(depth): @@ -86,8 +87,9 @@ def test_iter_discrete_traces_vector(expand, graph_type): @config_enumerate(default="sequential", expand=expand) def model(): p = pyro.param("p", torch.tensor([0.05, 0.15])) - probs = pyro.param("probs", torch.tensor([[0.1, 0.2, 0.3, 0.4], - [0.4, 0.3, 0.2, 0.1]])) + probs = pyro.param( + "probs", torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]) + ) with pyro.plate("plate", 2): x = pyro.sample("x", dist.Bernoulli(p)) y = pyro.sample("y", dist.Categorical(probs)) @@ -114,7 +116,7 @@ def log_prob(self, value): return torch.stack([(-self.probs).log1p(), self.probs.log()])[i, j] -@pytest.mark.parametrize('sample_shape', [(), (2,), (3, 4)]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (3, 4)]) def test_unsafe_bernoulli(sample_shape): logits = torch.randn(10) p = dist.Bernoulli(logits=logits) @@ -174,11 +176,11 @@ def gmm_guide(data, verbose=False): @pytest.mark.parametrize("model", [gmm_model, gmm_guide]) def test_gmm_iter_discrete_traces(data_size, graph_type, model): pyro.clear_param_store() - data = torch.arange(0., float(data_size)) + data = torch.arange(0.0, float(data_size)) model = config_enumerate(model, "sequential") traces = list(iter_discrete_traces(graph_type, model, data=data, verbose=True)) # This non-vectorized version is exponential in data_size: - assert len(traces) == 2**data_size + assert len(traces) == 2 ** data_size # A Gaussian mixture model, with vectorized batching. 
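The two assertions around this point contrast sequential enumeration of a per-datum Bernoulli (exponential in data_size) with enumeration of a single plated site (constant). A small illustrative sketch of the counting argument, in plain Python with hypothetical helper names rather than Pyro's iter_discrete_traces:

import itertools

def sequential_assignments(data_size):
    # enumerating a separate Bernoulli per data point multiplies the choices,
    # giving 2 ** data_size joint assignments
    return list(itertools.product([0.0, 1.0], repeat=data_size))

def plated_assignments(data_size):
    # enumerating one plated Bernoulli site visits only its two support values,
    # each broadcast across the plate, independent of data_size
    return [[z] * data_size for z in (0.0, 1.0)]

assert len(sequential_assignments(4)) == 2 ** 4
assert len(plated_assignments(4)) == 2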
@@ -209,65 +211,83 @@ def gmm_batch_guide(data): @pytest.mark.parametrize("model", [gmm_batch_model, gmm_batch_guide]) def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type): pyro.clear_param_store() - data = torch.arange(0., float(data_size)) + data = torch.arange(0.0, float(data_size)) model = config_enumerate(model, "sequential") traces = list(iter_discrete_traces(graph_type, model, data=data)) # This vectorized version is independent of data_size: assert len(traces) == 2 -@pytest.mark.parametrize("model,guide", [ - (gmm_model, gmm_guide), - (gmm_batch_model, gmm_batch_guide), -], ids=["single", "batch"]) +@pytest.mark.parametrize( + "model,guide", + [ + (gmm_model, gmm_guide), + (gmm_batch_model, gmm_batch_guide), + ], + ids=["single", "batch"], +) @pytest.mark.parametrize("enumerate1", [None, "sequential", "parallel"]) def test_svi_step_smoke(model, guide, enumerate1): pyro.clear_param_store() data = torch.tensor([0.0, 1.0, 9.0]) guide = config_enumerate(guide, default=enumerate1) - optimizer = pyro.optim.Adam({"lr": .001}) + optimizer = pyro.optim.Adam({"lr": 0.001}) elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1])) inference = SVI(model, guide, optimizer, loss=elbo) inference.step(data) -@pytest.mark.parametrize("model,guide", [ - (gmm_model, gmm_guide), - (gmm_batch_model, gmm_batch_guide), -], ids=["single", "batch"]) +@pytest.mark.parametrize( + "model,guide", + [ + (gmm_model, gmm_guide), + (gmm_batch_model, gmm_batch_guide), + ], + ids=["single", "batch"], +) @pytest.mark.parametrize("enumerate1", [None, "sequential", "parallel"]) def test_differentiable_loss(model, guide, enumerate1): pyro.clear_param_store() data = torch.tensor([0.0, 1.0, 9.0]) guide = config_enumerate(guide, default=enumerate1) - elbo = TraceEnum_ELBO(max_plate_nesting=1, - strict_enumeration_warning=any([enumerate1])) + elbo = TraceEnum_ELBO( + max_plate_nesting=1, strict_enumeration_warning=any([enumerate1]) + ) pyro.set_rng_seed(0) loss = elbo.differentiable_loss(model, guide, data) param_names = sorted(pyro.get_param_store()) actual_loss = loss.item() - actual_grads = grad(loss, [pyro.param(name).unconstrained() for name in param_names]) + actual_grads = grad( + loss, [pyro.param(name).unconstrained() for name in param_names] + ) pyro.set_rng_seed(0) expected_loss = elbo.loss_and_grads(model, guide, data) expected_grads = [pyro.param(name).unconstrained().grad for name in param_names] assert_equal(actual_loss, expected_loss) - for name, actual_grad, expected_grad in zip(param_names, actual_grads, expected_grads): - assert_equal(actual_grad, expected_grad, msg='bad {} gradient. Expected:\n{}\nActual:\n{}'.format( - name, expected_grad, actual_grad)) + for name, actual_grad, expected_grad in zip( + param_names, actual_grads, expected_grads + ): + assert_equal( + actual_grad, + expected_grad, + msg="bad {} gradient. 
Expected:\n{}\nActual:\n{}".format( + name, expected_grad, actual_grad + ), + ) @pytest.mark.parametrize("enumerate1", [None, "sequential", "parallel"]) def test_svi_step_guide_uses_grad(enumerate1): - data = torch.tensor([0., 1., 3.]) + data = torch.tensor([0.0, 1.0, 3.0]) def model(): scale = pyro.param("scale") - loc = pyro.sample("loc", dist.Normal(0., 10.)) + loc = pyro.sample("loc", dist.Normal(0.0, 10.0)) pyro.sample("b", dist.Bernoulli(0.5)) with pyro.plate("data", len(data)): pyro.sample("obs", dist.Normal(loc, scale), obs=data) @@ -278,8 +298,8 @@ def guide(): scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) var = pyro.param("var", torch.tensor(1.0), constraint=constraints.positive) - x = torch.tensor(0., requires_grad=True) - prior = dist.Normal(0., 10.).log_prob(x) + x = torch.tensor(0.0, requires_grad=True) + prior = dist.Normal(0.0, 10.0).log_prob(x) likelihood = dist.Normal(x, scale).log_prob(data).sum() loss = -(prior + likelihood) g = grad(loss, [x], create_graph=True)[0] @@ -293,7 +313,7 @@ def guide(): inference.step() -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) @pytest.mark.parametrize("method", ["loss", "differentiable_loss", "loss_and_grads"]) @pytest.mark.parametrize("enumerate1", [None, "sequential", "parallel"]) def test_elbo_bern(method, enumerate1, scale): @@ -320,10 +340,17 @@ def guide(): if method == "loss": actual = elbo.loss(model, guide) / num_particles expected = kl.item() * scale - assert_equal(actual, expected, prec=prec, msg="".join([ - "\nexpected = {}".format(expected), - "\n actual = {}".format(actual), - ])) + assert_equal( + actual, + expected, + prec=prec, + msg="".join( + [ + "\nexpected = {}".format(expected), + "\n actual = {}".format(actual), + ] + ), + ) else: if method == "differentiable_loss": loss = elbo.differentiable_loss(model, guide) @@ -332,10 +359,17 @@ def guide(): elbo.loss_and_grads(model, guide) actual = q.grad / num_particles expected = grad(kl, [q])[0] * scale - assert_equal(actual, expected, prec=prec, msg="".join([ - "\nexpected = {}".format(expected.detach().cpu().numpy()), - "\n actual = {}".format(actual.detach().cpu().numpy()), - ])) + assert_equal( + actual, + expected, + prec=prec, + msg="".join( + [ + "\nexpected = {}".format(expected.detach().cpu().numpy()), + "\n actual = {}".format(actual.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("method", ["loss", "differentiable_loss", "loss_and_grads"]) @@ -344,28 +378,35 @@ def test_elbo_normal(method, enumerate1): pyro.clear_param_store() num_particles = 1 if enumerate1 else 10000 prec = 0.01 - q = pyro.param("q", torch.tensor(1., requires_grad=True)) - kl = kl_divergence(dist.Normal(q, 1.), dist.Normal(0., 1.)) + q = pyro.param("q", torch.tensor(1.0, requires_grad=True)) + kl = kl_divergence(dist.Normal(q, 1.0), dist.Normal(0.0, 1.0)) def model(): with pyro.plate("particles", num_particles): - pyro.sample("z", dist.Normal(0., 1.).expand_by([num_particles])) + pyro.sample("z", dist.Normal(0.0, 1.0).expand_by([num_particles])) @config_enumerate(default=enumerate1, num_samples=20000) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("z", dist.Normal(q, 1.).expand_by([num_particles])) + pyro.sample("z", dist.Normal(q, 1.0).expand_by([num_particles])) elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1])) if method == "loss": actual = elbo.loss(model, guide) / num_particles expected = kl.item() - assert_equal(actual, expected, 
prec=prec, msg="".join([ - "\nexpected = {}".format(expected), - "\n actual = {}".format(actual), - ])) + assert_equal( + actual, + expected, + prec=prec, + msg="".join( + [ + "\nexpected = {}".format(expected), + "\n actual = {}".format(actual), + ] + ), + ) else: if method == "differentiable_loss": loss = elbo.differentiable_loss(model, guide) @@ -374,24 +415,37 @@ def guide(): elbo.loss_and_grads(model, guide) actual = q.grad / num_particles expected = grad(kl, [q])[0] - assert_equal(actual, expected, prec=prec, msg="".join([ - "\nexpected = {}".format(expected.detach().cpu().numpy()), - "\n actual = {}".format(actual.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,num_samples1", [ - (None, None), - ("sequential", None), - ("parallel", None), - ("parallel", 300), -]) -@pytest.mark.parametrize("enumerate2,num_samples2", [ - (None, None), - ("sequential", None), - ("parallel", None), - ("parallel", 300), -]) + assert_equal( + actual, + expected, + prec=prec, + msg="".join( + [ + "\nexpected = {}".format(expected.detach().cpu().numpy()), + "\n actual = {}".format(actual.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,num_samples1", + [ + (None, None), + ("sequential", None), + ("parallel", None), + ("parallel", 300), + ], +) +@pytest.mark.parametrize( + "enumerate2,num_samples2", + [ + (None, None), + ("sequential", None), + ("parallel", None), + ("parallel", 300), + ], +) @pytest.mark.parametrize("method", ["differentiable_loss", "loss_and_grads"]) def test_elbo_bern_bern(method, enumerate1, enumerate2, num_samples1, num_samples2): pyro.clear_param_store() @@ -413,16 +467,26 @@ def model(): def guide(): q = pyro.param("q") - pyro.sample("x1", dist.Bernoulli(q), infer={"enumerate": enumerate1, "num_samples": num_samples1}) - pyro.sample("x2", dist.Bernoulli(q), infer={"enumerate": enumerate2, "num_samples": num_samples2}) + pyro.sample( + "x1", + dist.Bernoulli(q), + infer={"enumerate": enumerate1, "num_samples": num_samples1}, + ) + pyro.sample( + "x2", + dist.Bernoulli(q), + infer={"enumerate": enumerate2, "num_samples": num_samples2}, + ) kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.2, 0.4]) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=any([enumerate1, enumerate2])) + elbo = TraceEnum_ELBO( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=any([enumerate1, enumerate2]), + ) if method == "differentiable_loss": loss = elbo.differentiable_loss(model, guide) actual_loss = loss.item() @@ -431,29 +495,48 @@ def guide(): actual_loss = elbo.loss_and_grads(model, guide) actual_grad = q.grad - assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([ - "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,enumerate2,enumerate3,num_samples", [ - (e1, e2, e3, num_samples) - for e1 in [None, "sequential", "parallel"] - for e2 in [None, "sequential", "parallel"] - for e3 in [None, "sequential", "parallel"] - for num_samples in [None, 10000] - if num_samples is None or (e1, e2, e3) == ("parallel", "parallel", "parallel") -]) + assert_equal( + 
actual_loss, + expected_loss, + prec=prec, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=prec, + msg="".join( + [ + "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,enumerate2,enumerate3,num_samples", + [ + (e1, e2, e3, num_samples) + for e1 in [None, "sequential", "parallel"] + for e2 in [None, "sequential", "parallel"] + for e3 in [None, "sequential", "parallel"] + for num_samples in [None, 10000] + if num_samples is None or (e1, e2, e3) == ("parallel", "parallel", "parallel") + ], +) @pytest.mark.parametrize("method", ["differentiable_loss", "loss_and_grads"]) def test_elbo_berns(method, enumerate1, enumerate2, enumerate3, num_samples): pyro.clear_param_store() num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 10000 - prec = 0.001 if all([enumerate1, enumerate2, enumerate3]) and not num_samples else 0.1 + prec = ( + 0.001 if all([enumerate1, enumerate2, enumerate3]) and not num_samples else 0.1 + ) q = pyro.param("q", torch.tensor(0.75, requires_grad=True)) def model(): @@ -463,17 +546,33 @@ def model(): def guide(): q = pyro.param("q") - pyro.sample("x1", dist.Bernoulli(q), infer={"enumerate": enumerate1, "num_samples": num_samples}) - pyro.sample("x2", dist.Bernoulli(q), infer={"enumerate": enumerate2, "num_samples": num_samples}) - pyro.sample("x3", dist.Bernoulli(q), infer={"enumerate": enumerate3, "num_samples": num_samples}) - - kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3]) + pyro.sample( + "x1", + dist.Bernoulli(q), + infer={"enumerate": enumerate1, "num_samples": num_samples}, + ) + pyro.sample( + "x2", + dist.Bernoulli(q), + infer={"enumerate": enumerate2, "num_samples": num_samples}, + ) + pyro.sample( + "x3", + dist.Bernoulli(q), + infer={"enumerate": enumerate3, "num_samples": num_samples}, + ) + + kl = sum( + kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3] + ) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]), + ) if method == "differentiable_loss": loss = elbo.differentiable_loss(model, guide) actual_loss = loss.item() @@ -482,14 +581,28 @@ def guide(): actual_loss = elbo.loss_and_grads(model, guide) actual_grad = q.grad - assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([ - "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_loss, + expected_loss, + prec=prec, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=prec, + msg="".join( + [ + "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), + ] 
+ ), + ) @pytest.mark.parametrize("num_samples", [None, 2000]) @@ -497,7 +610,9 @@ def guide(): @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate3", ["sequential", "parallel"]) -def test_elbo_categoricals(enumerate1, enumerate2, enumerate3, max_plate_nesting, num_samples): +def test_elbo_categoricals( + enumerate1, enumerate2, enumerate3, max_plate_nesting, num_samples +): pyro.clear_param_store() p1 = torch.tensor([0.6, 0.4]) p2 = torch.tensor([0.3, 0.3, 0.4]) @@ -512,36 +627,69 @@ def model(): pyro.sample("x3", dist.Categorical(p3)) def guide(): - pyro.sample("x1", dist.Categorical(pyro.param("q1")), - infer={"enumerate": enumerate1, - "num_samples": num_samples if enumerate1 == "parallel" else None}) - pyro.sample("x2", dist.Categorical(pyro.param("q2")), - infer={"enumerate": enumerate2, - "num_samples": num_samples if enumerate2 == "parallel" else None}) - pyro.sample("x3", dist.Categorical(pyro.param("q3")), - infer={"enumerate": enumerate3, - "num_samples": num_samples if enumerate3 == "parallel" else None}) - - kl = (kl_divergence(dist.Categorical(q1), dist.Categorical(p1)) + - kl_divergence(dist.Categorical(q2), dist.Categorical(p2)) + - kl_divergence(dist.Categorical(q3), dist.Categorical(p3))) + pyro.sample( + "x1", + dist.Categorical(pyro.param("q1")), + infer={ + "enumerate": enumerate1, + "num_samples": num_samples if enumerate1 == "parallel" else None, + }, + ) + pyro.sample( + "x2", + dist.Categorical(pyro.param("q2")), + infer={ + "enumerate": enumerate2, + "num_samples": num_samples if enumerate2 == "parallel" else None, + }, + ) + pyro.sample( + "x3", + dist.Categorical(pyro.param("q3")), + infer={ + "enumerate": enumerate3, + "num_samples": num_samples if enumerate3 == "parallel" else None, + }, + ) + + kl = ( + kl_divergence(dist.Categorical(q1), dist.Categorical(p1)) + + kl_divergence(dist.Categorical(q2), dist.Categorical(p2)) + + kl_divergence(dist.Categorical(q3), dist.Categorical(p3)) + ) expected_loss = kl.item() expected_grads = grad(kl, [q1, q2, q3]) - elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting, - strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + max_plate_nesting=max_plate_nesting, + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]), + ) actual_loss = elbo.loss_and_grads(model, guide) actual_grads = [q1.grad, q2.grad, q3.grad] - assert_equal(actual_loss, expected_loss, prec=0.001 if not num_samples else 0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) + assert_equal( + actual_loss, + expected_loss, + prec=0.001 if not num_samples else 0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_equal(actual_grad, expected_grad, prec=0.001 if not num_samples else 0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=0.001 if not num_samples else 0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("enumerate1", [None, "parallel"]) @@ -550,28 
+698,47 @@ def guide(): @pytest.mark.parametrize("method", ["differentiable_loss", "loss_and_grads"]) def test_elbo_normals(method, enumerate1, enumerate2, enumerate3): pyro.clear_param_store() - num_particles = 100 * 10 ** sum(1 for e in [enumerate1, enumerate2, enumerate3] if not e) + num_particles = 100 * 10 ** sum( + 1 for e in [enumerate1, enumerate2, enumerate3] if not e + ) prec = 0.1 q = pyro.param("q", torch.tensor(0.0, requires_grad=True)) def model(): - pyro.sample("x1", dist.Normal(0.25, 1.)) - pyro.sample("x2", dist.Normal(0.5, 1.)) - pyro.sample("x3", dist.Normal(1., 1.)) + pyro.sample("x1", dist.Normal(0.25, 1.0)) + pyro.sample("x2", dist.Normal(0.5, 1.0)) + pyro.sample("x3", dist.Normal(1.0, 1.0)) def guide(): q = pyro.param("q") - pyro.sample("x1", dist.Normal(q, 1.), infer={"enumerate": enumerate1, "num_samples": 10}) - pyro.sample("x2", dist.Normal(q, 1.), infer={"enumerate": enumerate2, "num_samples": 10}) - pyro.sample("x3", dist.Normal(q, 1.), infer={"enumerate": enumerate3, "num_samples": 10}) - - kl = sum(kl_divergence(dist.Normal(q, 1.), dist.Normal(p, 1.)) for p in [0.25, 0.5, 1.]) + pyro.sample( + "x1", + dist.Normal(q, 1.0), + infer={"enumerate": enumerate1, "num_samples": 10}, + ) + pyro.sample( + "x2", + dist.Normal(q, 1.0), + infer={"enumerate": enumerate2, "num_samples": 10}, + ) + pyro.sample( + "x3", + dist.Normal(q, 1.0), + infer={"enumerate": enumerate3, "num_samples": 10}, + ) + + kl = sum( + kl_divergence(dist.Normal(q, 1.0), dist.Normal(p, 1.0)) + for p in [0.25, 0.5, 1.0] + ) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]), + ) if method == "differentiable_loss": loss = elbo.differentiable_loss(model, guide) actual_loss = loss.item() @@ -580,23 +747,40 @@ def guide(): actual_loss = elbo.loss_and_grads(model, guide) actual_grad = q.grad - assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([ - "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,enumerate2,num_samples", [ - (e1, e2, num_samples) - for e1 in [None, "sequential", "parallel"] - for e2 in [None, "sequential", "parallel"] - for num_samples in [None, 10000] - if num_samples is None or (e1, e2) == ("parallel", "parallel") -]) + assert_equal( + actual_loss, + expected_loss, + prec=prec, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=prec, + msg="".join( + [ + "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,enumerate2,num_samples", + [ + (e1, e2, num_samples) + for e1 in [None, "sequential", "parallel"] + for e2 in [None, "sequential", "parallel"] + for num_samples in [None, 10000] + if num_samples is None or (e1, e2) == ("parallel", "parallel") + ], +) @pytest.mark.parametrize("plate_dim", [1, 
2]) def test_elbo_plate(plate_dim, enumerate1, enumerate2, num_samples): pyro.clear_param_store() @@ -608,16 +792,24 @@ def model(): with pyro.plate("particles", num_particles): pyro.sample("y", dist.Bernoulli(p).expand_by([num_particles])) with pyro.plate("plate", plate_dim): - pyro.sample("z", dist.Bernoulli(p).expand_by([plate_dim, num_particles])) + pyro.sample( + "z", dist.Bernoulli(p).expand_by([plate_dim, num_particles]) + ) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("y", dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate1, "num_samples": num_samples}) + pyro.sample( + "y", + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate1, "num_samples": num_samples}, + ) with pyro.plate("plate", plate_dim): - pyro.sample("z", dist.Bernoulli(q).expand_by([plate_dim, num_particles]), - infer={"enumerate": enumerate2, "num_samples": num_samples}) + pyro.sample( + "z", + dist.Bernoulli(q).expand_by([plate_dim, num_particles]), + infer={"enumerate": enumerate2, "num_samples": num_samples}, + ) kl = (1 + plate_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) expected_loss = kl.item() @@ -625,16 +817,30 @@ def guide(): elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1, enumerate2])) actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = pyro.param('q').grad / num_particles - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + actual_grad = pyro.param("q").grad / num_particles + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("enumerate2", [None, "sequential", "parallel"]) @@ -650,16 +856,24 @@ def model(): with pyro.plate("particles", num_particles): pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles])) for i in pyro.plate("plate", plate_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles])) + pyro.sample( + "y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles]) + ) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate1}) + pyro.sample( + "x", + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate1}, + ) for i in pyro.plate("plate", plate_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate2}) + pyro.sample( + "y_{}".format(i), + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate2}, + ) kl = (1 + plate_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) expected_loss = kl.item() @@ -667,32 +881,53 @@ def guide(): elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1, enumerate2])) actual_loss = elbo.loss_and_grads(model, guide) / num_particles - 
actual_grad = pyro.param('q').grad / num_particles - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,enumerate2,enumerate3,enumerate4,num_samples", [ - (e1, e2, e3, e4, num_samples) - for e1 in [None, "sequential", "parallel"] - for e2 in [None, "sequential", "parallel"] - for e3 in [None, "sequential", "parallel"] - for e4 in [None, "sequential", "parallel"] - for num_samples in [None, 10000] - if num_samples is None or (e1, e2, e3, e4) == ("parallel",) * 4 -]) + actual_grad = pyro.param("q").grad / num_particles + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,enumerate2,enumerate3,enumerate4,num_samples", + [ + (e1, e2, e3, e4, num_samples) + for e1 in [None, "sequential", "parallel"] + for e2 in [None, "sequential", "parallel"] + for e3 in [None, "sequential", "parallel"] + for e4 in [None, "sequential", "parallel"] + for num_samples in [None, 10000] + if num_samples is None or (e1, e2, e3, e4) == ("parallel",) * 4 + ], +) @pytest.mark.parametrize("inner_dim", [2]) @pytest.mark.parametrize("outer_dim", [2]) -def test_elbo_plate_plate(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, enumerate4, num_samples): +def test_elbo_plate_plate( + outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, enumerate4, num_samples +): pyro.clear_param_store() - num_particles = 1 if all([enumerate1, enumerate2, enumerate3, enumerate4]) else 100000 + num_particles = ( + 1 if all([enumerate1, enumerate2, enumerate3, enumerate4]) else 100000 + ) q = pyro.param("q", torch.tensor(0.75, requires_grad=True)) p = 0.2693204236205713 # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5 @@ -714,44 +949,71 @@ def guide(): context2 = pyro.plate("inner", inner_dim, dim=-2) pyro.sample("w", d, infer={"enumerate": enumerate1, "num_samples": num_samples}) with context1: - pyro.sample("x", d, infer={"enumerate": enumerate2, "num_samples": num_samples}) + pyro.sample( + "x", d, infer={"enumerate": enumerate2, "num_samples": num_samples} + ) with context2: - pyro.sample("y", d, infer={"enumerate": enumerate3, "num_samples": num_samples}) + pyro.sample( + "y", d, infer={"enumerate": enumerate3, "num_samples": num_samples} + ) with context1, context2: - pyro.sample("z", d, infer={"enumerate": enumerate4, "num_samples": num_samples}) + pyro.sample( + "z", d, infer={"enumerate": enumerate4, "num_samples": num_samples} + ) kl_node = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + num_particles=num_particles, + vectorize_particles=True, + 
strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]), + ) actual_loss = elbo.loss_and_grads(model, guide) - actual_grad = pyro.param('q').grad - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,enumerate2,enumerate3,num_samples", [ - (e1, e2, e3, num_samples) - for e1 in [None, "sequential", "parallel"] - for e2 in [None, "sequential", "parallel"] - for e3 in [None, "sequential", "parallel"] - for num_samples in [None, 2000] - if num_samples is None or (e1, e2, e3) == ("parallel",) * 3 -]) + actual_grad = pyro.param("q").grad + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,enumerate2,enumerate3,num_samples", + [ + (e1, e2, e3, num_samples) + for e1 in [None, "sequential", "parallel"] + for e2 in [None, "sequential", "parallel"] + for e3 in [None, "sequential", "parallel"] + for num_samples in [None, 2000] + if num_samples is None or (e1, e2, e3) == ("parallel",) * 3 + ], +) @pytest.mark.parametrize("inner_dim", [2]) @pytest.mark.parametrize("outer_dim", [3]) -def test_elbo_plate_iplate(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, num_samples): +def test_elbo_plate_iplate( + outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, num_samples +): pyro.clear_param_store() num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 100000 q = pyro.param("q", torch.tensor(0.75, requires_grad=True)) @@ -761,38 +1023,70 @@ def model(): with pyro.plate("particles", num_particles): pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles])) with pyro.plate("outer", outer_dim): - pyro.sample("y", dist.Bernoulli(p).expand_by([outer_dim, num_particles])) + pyro.sample( + "y", dist.Bernoulli(p).expand_by([outer_dim, num_particles]) + ) for i in pyro.plate("inner", inner_dim): - pyro.sample("z_{}".format(i), dist.Bernoulli(p).expand_by([outer_dim, num_particles])) + pyro.sample( + "z_{}".format(i), + dist.Bernoulli(p).expand_by([outer_dim, num_particles]), + ) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate1, "num_samples": num_samples}) + pyro.sample( + "x", + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate1, "num_samples": num_samples}, + ) with pyro.plate("outer", outer_dim): - pyro.sample("y", dist.Bernoulli(q).expand_by([outer_dim, num_particles]), - infer={"enumerate": enumerate2, "num_samples": num_samples}) + pyro.sample( + "y", + dist.Bernoulli(q).expand_by([outer_dim, num_particles]), + infer={"enumerate": enumerate2, "num_samples": num_samples}, + ) for i in pyro.plate("inner", inner_dim): - pyro.sample("z_{}".format(i), dist.Bernoulli(q).expand_by([outer_dim, num_particles]), - infer={"enumerate": enumerate3, "num_samples": 
num_samples}) - - kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) + pyro.sample( + "z_{}".format(i), + dist.Bernoulli(q).expand_by([outer_dim, num_particles]), + infer={"enumerate": enumerate3, "num_samples": num_samples}, + ) + + kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence( + dist.Bernoulli(q), dist.Bernoulli(p) + ) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]) + ) actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = pyro.param('q').grad / num_particles - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + actual_grad = pyro.param("q").grad / num_particles + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("enumerate3", [None, "sequential", "parallel"]) @@ -811,39 +1105,71 @@ def model(): pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles])) inner_plate = pyro.plate("inner", inner_dim) for i in pyro.plate("outer", outer_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles])) + pyro.sample( + "y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles]) + ) with inner_plate: - pyro.sample("z_{}".format(i), dist.Bernoulli(p).expand_by([inner_dim, num_particles])) + pyro.sample( + "z_{}".format(i), + dist.Bernoulli(p).expand_by([inner_dim, num_particles]), + ) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate1}) + pyro.sample( + "x", + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate1}, + ) inner_plate = pyro.plate("inner", inner_dim) for i in pyro.plate("outer", outer_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate2}) + pyro.sample( + "y_{}".format(i), + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate2}, + ) with inner_plate: - pyro.sample("z_{}".format(i), dist.Bernoulli(q).expand_by([inner_dim, num_particles]), - infer={"enumerate": enumerate3}) - - kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) + pyro.sample( + "z_{}".format(i), + dist.Bernoulli(q).expand_by([inner_dim, num_particles]), + infer={"enumerate": enumerate3}, + ) + + kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence( + dist.Bernoulli(q), dist.Bernoulli(p) + ) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]) + 
) actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = pyro.param('q').grad / num_particles - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + actual_grad = pyro.param("q").grad / num_particles + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.1, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("enumerate3", [None, "sequential", "parallel"]) @@ -862,39 +1188,71 @@ def model(): pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles])) inner_iplate = pyro.plate("inner", outer_dim) for i in pyro.plate("outer", inner_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles])) + pyro.sample( + "y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles]) + ) for j in inner_iplate: - pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(p).expand_by([num_particles])) + pyro.sample( + "z_{}_{}".format(i, j), + dist.Bernoulli(p).expand_by([num_particles]), + ) def guide(): q = pyro.param("q") with pyro.plate("particles", num_particles): - pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate1}) + pyro.sample( + "x", + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate1}, + ) inner_iplate = pyro.plate("inner", inner_dim) for i in pyro.plate("outer", outer_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate2}) + pyro.sample( + "y_{}".format(i), + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate2}, + ) for j in inner_iplate: - pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(q).expand_by([num_particles]), - infer={"enumerate": enumerate3}) - - kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) + pyro.sample( + "z_{}_{}".format(i, j), + dist.Bernoulli(q).expand_by([num_particles]), + infer={"enumerate": enumerate3}, + ) + + kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence( + dist.Bernoulli(q), dist.Bernoulli(p) + ) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) + elbo = TraceEnum_ELBO( + strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]) + ) actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = pyro.param('q').grad / num_particles - - assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.2, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + actual_grad = pyro.param("q").grad / num_particles + + assert_equal( + actual_loss, + expected_loss, + prec=0.1, + msg="".join( + [ 
+ "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.2, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("pi1", [0.33, 0.43]) @@ -919,8 +1277,8 @@ def guide(): logger.info("Computing gradients using surrogate loss") elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1])) elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1)) - actual_grad_q1 = pyro.param('q1').grad / num_particles - actual_grad_q2 = pyro.param('q2').grad / num_particles + actual_grad_q1 = pyro.param("q1").grad / num_particles + actual_grad_q2 = pyro.param("q2").grad / num_particles logger.info("Computing analytic gradients") q1 = torch.tensor(pi1, requires_grad=True) @@ -932,26 +1290,45 @@ def guide(): prec = 0.03 if enumerate1 is None else 0.001 - assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([ - "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()), - "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()), - ])) - assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([ - "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()), - "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()), - ])) + assert_equal( + actual_grad_q1, + expected_grad_q1, + prec=prec, + msg="".join( + [ + "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()), + "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_q2, + expected_grad_q2, + prec=prec, + msg="".join( + [ + "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()), + "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("pi1", [0.33, 0.44]) @pytest.mark.parametrize("pi2", [0.55, 0.39]) @pytest.mark.parametrize("pi3", [0.22, 0.29]) -@pytest.mark.parametrize("enumerate1,num_samples", [ - (None, None), - ("sequential", None), - ("parallel", None), - ("parallel", 2), -]) -def test_non_mean_field_bern_normal_elbo_gradient(enumerate1, pi1, pi2, pi3, num_samples): +@pytest.mark.parametrize( + "enumerate1,num_samples", + [ + (None, None), + ("sequential", None), + ("parallel", None), + ("parallel", 2), + ], +) +def test_non_mean_field_bern_normal_elbo_gradient( + enumerate1, pi1, pi2, pi3, num_samples +): pyro.clear_param_store() include_z = True num_particles = 10000 @@ -967,17 +1344,21 @@ def guide(): q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True)) q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True)) with pyro.plate("particles", num_particles): - y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]), infer={"enumerate": enumerate1}) + y = pyro.sample( + "y", + dist.Bernoulli(q1).expand_by([num_particles]), + infer={"enumerate": enumerate1}, + ) if include_z: pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0)) logger.info("Computing gradients using surrogate loss") elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1])) elbo.loss_and_grads(model, guide) - actual_grad_q1 = pyro.param('q1').grad / num_particles + actual_grad_q1 = pyro.param("q1").grad / num_particles if include_z: - actual_grad_q2 = pyro.param('q2').grad / num_particles - actual_grad_q3 = pyro.param('q3').grad / num_particles + actual_grad_q2 = pyro.param("q2").grad / num_particles + actual_grad_q3 = pyro.param("q3").grad 
/ num_particles logger.info("Computing analytic gradients") q1 = torch.tensor(pi1, requires_grad=True) @@ -985,34 +1366,58 @@ def guide(): q3 = torch.tensor(pi3, requires_grad=True) elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(q3)) if include_z: - elbo = elbo + q1 * kl_divergence(dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0)) - elbo = elbo + (1.0 - q1) * kl_divergence(dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0)) + elbo = elbo + q1 * kl_divergence( + dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0) + ) + elbo = elbo + (1.0 - q1) * kl_divergence( + dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0) + ) expected_grad_q1, expected_grad_q2, expected_grad_q3 = grad(elbo, [q1, q2, q3]) else: expected_grad_q1, expected_grad_q3 = grad(elbo, [q1, q3]) prec = 0.04 if enumerate1 is None else 0.02 - assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([ - "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()), - "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()), - ])) + assert_equal( + actual_grad_q1, + expected_grad_q1, + prec=prec, + msg="".join( + [ + "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()), + "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()), + ] + ), + ) if include_z: - assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([ - "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()), - "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()), - ])) - assert_equal(actual_grad_q3, expected_grad_q3, prec=prec, msg="".join([ - "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()), - "\nq3 actual = {}".format(actual_grad_q3.data.cpu().numpy()), - ])) + assert_equal( + actual_grad_q2, + expected_grad_q2, + prec=prec, + msg="".join( + [ + "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()), + "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_q3, + expected_grad_q3, + prec=prec, + msg="".join( + [ + "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()), + "\nq3 actual = {}".format(actual_grad_q3.data.cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("pi1", [0.33, 0.41]) @pytest.mark.parametrize("pi2", [0.44, 0.17]) @pytest.mark.parametrize("pi3", [0.22, 0.29]) def test_non_mean_field_normal_bern_elbo_gradient(pi1, pi2, pi3): - def model(num_particles): with pyro.plate("particles", num_particles): q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True)) @@ -1029,27 +1434,41 @@ def guide(num_particles): zz = torch.exp(z) / (1.0 + torch.exp(z)) pyro.sample("y", dist.Bernoulli(q1 * zz)) - qs = ['q1', 'q2', 'q3', 'q4'] + qs = ["q1", "q2", "q3", "q4"] results = {} - for ed, num_particles in zip([None, 'parallel', 'sequential'], [30000, 20000, 20000]): + for ed, num_particles in zip( + [None, "parallel", "sequential"], [30000, 20000, 20000] + ): pyro.clear_param_store() elbo = TraceEnum_ELBO(strict_enumeration_warning=any([ed])) elbo.loss_and_grads(model, config_enumerate(guide, default=ed), num_particles) results[str(ed)] = {} for q in qs: - results[str(ed)]['actual_grad_%s' % q] = pyro.param(q).grad.detach().cpu().numpy() / num_particles + results[str(ed)]["actual_grad_%s" % q] = ( + pyro.param(q).grad.detach().cpu().numpy() / num_particles + ) prec = 0.03 - for ed in ['parallel', 'sequential']: - logger.info('\n*** {} ***'.format(ed)) + for ed in ["parallel", "sequential"]: + logger.info("\n*** {} ***".format(ed)) for q in qs: - logger.info("[{}] actual: {}".format(q, 
results[ed]['actual_grad_%s' % q])) - assert_equal(results[ed]['actual_grad_%s' % q], results['None']['actual_grad_%s' % q], prec=prec, - msg="".join([ - "\nexpected (MC estimate) = {}".format(results['None']['actual_grad_%s' % q]), - "\n actual ({} estimate) = {}".format(ed, results[ed]['actual_grad_%s' % q]), - ])) + logger.info("[{}] actual: {}".format(q, results[ed]["actual_grad_%s" % q])) + assert_equal( + results[ed]["actual_grad_%s" % q], + results["None"]["actual_grad_%s" % q], + prec=prec, + msg="".join( + [ + "\nexpected (MC estimate) = {}".format( + results["None"]["actual_grad_%s" % q] + ), + "\n actual ({} estimate) = {}".format( + ed, results[ed]["actual_grad_%s" % q] + ), + ] + ), + ) @pytest.mark.parametrize("enumerate1", [None, "sequential", "parallel"]) @@ -1073,50 +1492,73 @@ def guide(): a = pyro.param("a") with pyro.plate("particles", num_particles): pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles])) - pyro.sample("y", ShapeAugmentedGamma(a, torch.tensor(1.0)).expand_by([num_particles])) + pyro.sample( + "y", + ShapeAugmentedGamma(a, torch.tensor(1.0)).expand_by([num_particles]), + ) elbo = TraceEnum_ELBO(strict_enumeration_warning=any([enumerate1])) elbo.loss_and_grads(model, guide) actual_q = q.grad / num_particles expected_q = grad(kl1, [q])[0] - assert_equal(actual_q, expected_q, prec=prec, msg="".join([ - "\nexpected q.grad = {}".format(expected_q.detach().cpu().numpy()), - "\n actual q.grad = {}".format(actual_q.detach().cpu().numpy()), - ])) + assert_equal( + actual_q, + expected_q, + prec=prec, + msg="".join( + [ + "\nexpected q.grad = {}".format(expected_q.detach().cpu().numpy()), + "\n actual q.grad = {}".format(actual_q.detach().cpu().numpy()), + ] + ), + ) actual_a = a.grad / num_particles expected_a = grad(kl2, [a])[0] - assert_equal(actual_a, expected_a, prec=prec, msg="".join([ - "\nexpected a.grad= {}".format(expected_a.detach().cpu().numpy()), - "\n actual a.grad = {}".format(actual_a.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,num_steps,expand", [ - ("sequential", 2, True), - ("sequential", 2, False), - ("sequential", 3, True), - ("sequential", 3, False), - ("parallel", 2, True), - ("parallel", 2, False), - ("parallel", 3, True), - ("parallel", 3, False), - ("parallel", 10, False), - ("parallel", 20, False), - _skip_cuda("parallel", 30, False), -]) + assert_equal( + actual_a, + expected_a, + prec=prec, + msg="".join( + [ + "\nexpected a.grad= {}".format(expected_a.detach().cpu().numpy()), + "\n actual a.grad = {}".format(actual_a.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,num_steps,expand", + [ + ("sequential", 2, True), + ("sequential", 2, False), + ("sequential", 3, True), + ("sequential", 3, False), + ("parallel", 2, True), + ("parallel", 2, False), + ("parallel", 3, True), + ("parallel", 3, False), + ("parallel", 10, False), + ("parallel", 20, False), + _skip_cuda("parallel", 30, False), + ], +) def test_elbo_hmm_in_model(enumerate1, num_steps, expand): pyro.clear_param_store() data = torch.ones(num_steps) init_probs = torch.tensor([0.5, 0.5]) def model(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.9, 0.1], [0.1, 0.9]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.9, 0.1], [0.1, 0.9]]), + constraint=constraints.simplex, + ) locs = pyro.param("obs_locs", torch.tensor([-1.0, 1.0])) - scale = pyro.param("obs_scale", torch.tensor(1.0), - constraint=constraints.positive) + scale 
= pyro.param( + "obs_scale", torch.tensor(1.0), constraint=constraints.positive + ) x = None for i, y in pyro.markov(enumerate(data)): @@ -1126,8 +1568,11 @@ def model(data): @config_enumerate(default=enumerate1, expand=expand) def guide(data): - mean_field_probs = pyro.param("mean_field_probs", torch.ones(num_steps, 2) / 2, - constraint=constraints.simplex) + mean_field_probs = pyro.param( + "mean_field_probs", + torch.ones(num_steps, 2) / 2, + constraint=constraints.simplex, + ) for i in pyro.markov(range(num_steps)): pyro.sample("x_{}".format(i), dist.Categorical(mean_field_probs[i])) @@ -1144,39 +1589,54 @@ def guide(data): for name, value in pyro.get_param_store().named_parameters(): actual = value.grad expected = expected_unconstrained_grads[name] - assert_equal(actual, expected, msg=''.join([ - '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()), - '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize("enumerate1,num_steps,expand", [ - ("sequential", 2, True), - ("sequential", 2, False), - ("sequential", 3, True), - ("sequential", 3, False), - ("parallel", 2, True), - ("parallel", 2, False), - ("parallel", 3, True), - ("parallel", 3, False), - ("parallel", 10, False), - ("parallel", 20, False), - _skip_cuda("parallel", 30, False), - _skip_cuda("parallel", 40, False), - _skip_cuda("parallel", 50, False), -]) + assert_equal( + actual, + expected, + msg="".join( + [ + "\nexpected {}.grad = {}".format(name, expected.cpu().numpy()), + "\n actual {}.grad = {}".format( + name, actual.detach().cpu().numpy() + ), + ] + ), + ) + + +@pytest.mark.parametrize( + "enumerate1,num_steps,expand", + [ + ("sequential", 2, True), + ("sequential", 2, False), + ("sequential", 3, True), + ("sequential", 3, False), + ("parallel", 2, True), + ("parallel", 2, False), + ("parallel", 3, True), + ("parallel", 3, False), + ("parallel", 10, False), + ("parallel", 20, False), + _skip_cuda("parallel", 30, False), + _skip_cuda("parallel", 40, False), + _skip_cuda("parallel", 50, False), + ], +) def test_elbo_hmm_in_guide(enumerate1, num_steps, expand): pyro.clear_param_store() data = torch.ones(num_steps) init_probs = torch.tensor([0.5, 0.5]) def model(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - emission_probs = pyro.param("emission_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) + emission_probs = pyro.param( + "emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = None for i, y in pyro.markov(enumerate(data)): probs = init_probs if x is None else transition_probs[x] @@ -1185,9 +1645,11 @@ def model(data): @config_enumerate(default=enumerate1, expand=expand) def guide(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = None for i, y in pyro.markov(enumerate(data)): probs = init_probs if x is None else transition_probs[x] @@ -1229,29 +1691,41 @@ def guide(data): for name, value in pyro.get_param_store().named_parameters(): actual = value.grad expected = torch.tensor(expected_grads[num_steps][name]) - 
assert_equal(actual, expected, msg=''.join([ - '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()), - '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize('num_steps', [2, 3, 4, 5, 10, 20, _skip_cuda(30)]) + assert_equal( + actual, + expected, + msg="".join( + [ + "\nexpected {}.grad = {}".format(name, expected.cpu().numpy()), + "\n actual {}.grad = {}".format( + name, actual.detach().cpu().numpy() + ), + ] + ), + ) + + +@pytest.mark.parametrize("num_steps", [2, 3, 4, 5, 10, 20, _skip_cuda(30)]) def test_hmm_enumerate_model(num_steps): data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,)) @config_enumerate def model(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - emission_probs = pyro.param("emission_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) + emission_probs = pyro.param( + "emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = 0 for t, y in pyro.markov(enumerate(data)): x = pyro.sample("x_{}".format(t), dist.Categorical(transition_probs[x])) pyro.sample("y_{}".format(t), dist.Categorical(emission_probs[x]), obs=y) - logger.debug('{}\t{}'.format(t, tuple(x.shape))) + logger.debug("{}\t{}".format(t, tuple(x.shape))) def guide(data): pass @@ -1260,40 +1734,50 @@ def guide(data): elbo.differentiable_loss(model, guide, data) -@pytest.mark.parametrize('num_steps', [2, 3, 4, 5, 10, 20, _skip_cuda(30)]) +@pytest.mark.parametrize("num_steps", [2, 3, 4, 5, 10, 20, _skip_cuda(30)]) def test_hmm_enumerate_model_and_guide(num_steps): data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,)) def model(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - emission_probs = pyro.param("emission_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) + emission_probs = pyro.param( + "emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5]))) - logger.debug('-1\t{}'.format(tuple(x.shape))) + logger.debug("-1\t{}".format(tuple(x.shape))) for t, y in pyro.markov(enumerate(data)): - x = pyro.sample("x_{}".format(t), dist.Categorical(transition_probs[x]), - infer={"enumerate": "parallel"}) + x = pyro.sample( + "x_{}".format(t), + dist.Categorical(transition_probs[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("y_{}".format(t), dist.Categorical(emission_probs[x]), obs=y) - logger.debug('{}\t{}'.format(t, tuple(x.shape))) + logger.debug("{}\t{}".format(t, tuple(x.shape))) def guide(data): - init_probs = pyro.param("init_probs", - torch.tensor([0.75, 0.25]), - constraint=constraints.simplex) - pyro.sample("x", dist.Categorical(init_probs), - infer={"enumerate": "parallel"}) + init_probs = pyro.param( + "init_probs", torch.tensor([0.75, 0.25]), constraint=constraints.simplex + ) + pyro.sample("x", dist.Categorical(init_probs), infer={"enumerate": "parallel"}) elbo = TraceEnum_ELBO() elbo.differentiable_loss(model, guide, data) def 
_check_loss_and_grads(expected_loss, actual_loss): - assert_equal(actual_loss, expected_loss, - msg='Expected:\n{}\nActual:\n{}'.format(expected_loss.detach().cpu().numpy(), - actual_loss.detach().cpu().numpy())) + assert_equal( + actual_loss, + expected_loss, + msg="Expected:\n{}\nActual:\n{}".format( + expected_loss.detach().cpu().numpy(), actual_loss.detach().cpu().numpy() + ), + ) names = pyro.get_param_store().keys() params = [pyro.param(name).unconstrained() for name in names] @@ -1302,26 +1786,33 @@ def _check_loss_and_grads(expected_loss, actual_loss): for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads): if actual_grad is None or expected_grad is None: continue - assert_equal(actual_grad, expected_grad, - msg='{}\nExpected:\n{}\nActual:\n{}'.format(name, - expected_grad.detach().cpu().numpy(), - actual_grad.detach().cpu().numpy())) - - -@pytest.mark.parametrize('scale', [1, 10]) + assert_equal( + actual_grad, + expected_grad, + msg="{}\nExpected:\n{}\nActual:\n{}".format( + name, + expected_grad.detach().cpu().numpy(), + actual_grad.detach().cpu().numpy(), + ), + ) + + +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_1(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([0.3, 0.7]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", torch.tensor([0.3, 0.7]), constraint=constraints.simplex + ) @poutine.scale(scale=scale) def auto_model(): @@ -1329,8 +1820,7 @@ def auto_model(): probs_y = pyro.param("model_probs_y") probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) - pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + pyro.sample("y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"}) pyro.sample("z", dist.Categorical(probs_z), obs=torch.tensor(0)) @poutine.scale(scale=scale) @@ -1352,20 +1842,24 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_2(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) @poutine.scale(scale=scale) def auto_model(): @@ -1373,8 +1867,9 @@ def 
auto_model(): probs_y = pyro.param("model_probs_y") probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=torch.tensor(0)) @poutine.scale(scale=scale) @@ -1398,20 +1893,24 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_3(scale): - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(): probs_x = pyro.param("model_probs_x") @@ -1419,8 +1918,9 @@ def auto_model(): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with poutine.scale(scale=scale): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=torch.tensor(0)) def hand_model(): @@ -1443,27 +1943,33 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(1, 1), (2, 2), (3, 2)], - ids=["single", "batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", + [(1, 1), (2, 2), (3, 2)], + ids=["single", "batch", "masked"], +) def test_elbo_enumerate_plate_1(num_samples, num_masked, scale): # +---------+ # x ----> y ----> z | # | N | # +---------+ - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(data): probs_x = pyro.param("model_probs_x") @@ -1471,8 +1977,9 @@ def auto_model(data): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with poutine.scale(scale=scale): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - 
infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) if num_masked == num_samples: with pyro.plate("data", len(data)): pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) @@ -1487,8 +1994,9 @@ def hand_model(data): probs_z = pyro.param("model_probs_z") x = pyro.sample("x", dist.Categorical(probs_x)) with poutine.scale(scale=scale): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) for i in pyro.plate("data", num_masked): pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @@ -1505,27 +2013,33 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(1, 1), (2, 2), (3, 2)], - ids=["single", "batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", + [(1, 1), (2, 2), (3, 2)], + ids=["single", "batch", "masked"], +) def test_elbo_enumerate_plate_2(num_samples, num_masked, scale): # +-----------------+ # x ----> y ----> z | # | N | # +-----------------+ - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) def auto_model(data): probs_x = pyro.param("model_probs_x") @@ -1535,13 +2049,19 @@ def auto_model(data): with poutine.scale(scale=scale): with pyro.plate("data", len(data)): if num_masked == num_samples: - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) else: with poutine.mask(mask=torch.arange(num_samples) < num_masked): - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) def hand_model(data): @@ -1551,8 +2071,11 @@ def hand_model(data): x = pyro.sample("x", dist.Categorical(probs_x)) with poutine.scale(scale=scale): for i in pyro.plate("data", num_masked): - y = pyro.sample("y_{}".format(i), dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y_{}".format(i), + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @config_enumerate @@ -1567,28 +2090,34 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('num_samples,num_masked', - [(1, 1), (2, 2), (3, 2)], - ids=["single", 
"batch", "masked"]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", + [(1, 1), (2, 2), (3, 2)], + ids=["single", "batch", "masked"], +) def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): # +-----------------------+ # | x ----> y ----> z | # | N | # +-----------------------+ # This plate should remain unreduced since all enumeration is in a single plate. - pyro.param("guide_probs_x", - torch.tensor([0.1, 0.9]), - constraint=constraints.simplex) - pyro.param("model_probs_x", - torch.tensor([0.4, 0.6]), - constraint=constraints.simplex) - pyro.param("model_probs_y", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_z", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) + pyro.param( + "guide_probs_x", torch.tensor([0.1, 0.9]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_x", torch.tensor([0.4, 0.6]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_y", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_z", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) @poutine.scale(scale=scale) def auto_model(data): @@ -1598,14 +2127,18 @@ def auto_model(data): with pyro.plate("data", len(data)): if num_masked == num_samples: x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) else: with poutine.mask(mask=torch.arange(num_samples) < num_masked): x = pyro.sample("x", dist.Categorical(probs_x)) - y = pyro.sample("y", dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) @poutine.scale(scale=scale) @@ -1626,8 +2159,11 @@ def hand_model(data): probs_z = pyro.param("model_probs_z") for i in pyro.plate("data", num_masked): x = pyro.sample("x_{}".format(i), dist.Categorical(probs_x)) - y = pyro.sample("y_{}".format(i), dist.Categorical(probs_y[x]), - infer={"enumerate": "parallel"}) + y = pyro.sample( + "y_{}".format(i), + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) pyro.sample("z_{}".format(i), dist.Categorical(probs_z[y]), obs=data[i]) @poutine.scale(scale=scale) @@ -1644,9 +2180,10 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) -@pytest.mark.parametrize('outer_obs,inner_obs', - [(False, True), (True, False), (True, True)]) +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "outer_obs,inner_obs", [(False, True), (True, False), (True, True)] +) def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): # a ---> outer_obs # \ @@ -1657,8 +2194,8 @@ def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): # This tests two different observations, one outside and one inside an plate. 
pyro.param("probs_a", torch.tensor([0.4, 0.6]), constraint=constraints.simplex) pyro.param("probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex) - pyro.param("locs", torch.tensor([-1., 1.])) - pyro.param("scales", torch.tensor([1., 2.]), constraint=constraints.positive) + pyro.param("locs", torch.tensor([-1.0, 1.0])) + pyro.param("scales", torch.tensor([1.0, 2.0]), constraint=constraints.positive) outer_data = torch.tensor(2.0) inner_data = torch.tensor([0.5, 1.5]) @@ -1668,17 +2205,17 @@ def auto_model(): probs_b = pyro.param("probs_b") locs = pyro.param("locs") scales = pyro.param("scales") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) if outer_obs: - pyro.sample("outer_obs", dist.Normal(0., scales[a]), - obs=outer_data) + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) with pyro.plate("inner", 2): - b = pyro.sample("b", dist.Categorical(probs_b), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b), infer={"enumerate": "parallel"} + ) if inner_obs: - pyro.sample("inner_obs", dist.Normal(locs[b], scales[a]), - obs=inner_data) + pyro.sample( + "inner_obs", dist.Normal(locs[b], scales[a]), obs=inner_data + ) @poutine.scale(scale=scale) def hand_model(): @@ -1686,17 +2223,21 @@ def hand_model(): probs_b = pyro.param("probs_b") locs = pyro.param("locs") scales = pyro.param("scales") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) if outer_obs: - pyro.sample("outer_obs", dist.Normal(0., scales[a]), - obs=outer_data) + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) for i in pyro.plate("inner", 2): - b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b_{}".format(i), + dist.Categorical(probs_b), + infer={"enumerate": "parallel"}, + ) if inner_obs: - pyro.sample("inner_obs_{}".format(i), dist.Normal(locs[b], scales[a]), - obs=inner_data[i]) + pyro.sample( + "inner_obs_{}".format(i), + dist.Normal(locs[b], scales[a]), + obs=inner_data[i], + ) def guide(): pass @@ -1715,19 +2256,22 @@ def test_elbo_enumerate_plate_5(): # | M=2 V | # | b ----> c | # +------------------+ - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([1, 2]) @config_enumerate @@ -1738,8 +2282,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("b_axis", 2): b = pyro.sample("b", dist.Categorical(probs_b)) - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, 
b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @config_enumerate def guide_plate(): @@ -1755,9 +2298,9 @@ def model_iplate(): a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("b_axis", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) - pyro.sample("c_{}".format(i), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i]) + pyro.sample( + "c_{}".format(i), dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i] + ) @config_enumerate def guide_iplate(): @@ -1768,13 +2311,15 @@ def guide_iplate(): elbo = TraceEnum_ELBO(max_plate_nesting=0) expected_loss = elbo.differentiable_loss(model_iplate, guide_iplate) elbo = TraceEnum_ELBO(max_plate_nesting=1) - with pytest.raises(ValueError, match="Expected model enumeration to be no more global than guide"): + with pytest.raises( + ValueError, match="Expected model enumeration to be no more global than guide" + ): actual_loss = elbo.differentiable_loss(model_plate, guide_plate) # This never gets run because we don't support this yet. _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.parametrize('enumerate1', ['parallel', 'sequential']) +@pytest.mark.parametrize("enumerate1", ["parallel", "sequential"]) def test_elbo_enumerate_plate_6(enumerate1): # Guide Model # +-------+ @@ -1783,19 +2328,22 @@ def test_elbo_enumerate_plate_6(enumerate1): # +-------+ # This tests that sequential enumeration over b works, even though # model-side enumeration moves c into b's plate via contraction. - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([1, 2]) @config_enumerate @@ -1806,8 +2354,7 @@ def model_plate(): a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample("b", dist.Categorical(probs_b)) with pyro.plate("b_axis", 2): - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @config_enumerate def model_iplate(): @@ -1817,9 +2364,9 @@ def model_iplate(): a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample("b", dist.Categorical(probs_b)) for i in pyro.plate("b_axis", 2): - pyro.sample("c_{}".format(i), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i]) + pyro.sample( + "c_{}".format(i), dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i] + ) @config_enumerate(default=enumerate1) def guide(): @@ -1833,7 +2380,7 @@ def guide(): _check_loss_and_grads(expected_loss, actual_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plate_7(scale): # Guide Model # a -----> b @@ -1843,27 +2390,37 @@ def test_elbo_enumerate_plate_7(scale): # | c -----> d -----> e N=2 | # 
+---------------------------+ # This tests a mixture of model and guide enumeration. - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("model_probs_d", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), - constraint=constraints.simplex) - pyro.param("model_probs_e", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("guide_probs_a", - torch.tensor([0.35, 0.64]), - constraint=constraints.simplex) - pyro.param("guide_probs_c", - torch.tensor([[0., 1.], [1., 0.]]), # deterministic - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + constraint=constraints.simplex, + ) + pyro.param( + "model_probs_e", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "guide_probs_a", torch.tensor([0.35, 0.64]), constraint=constraints.simplex + ) + pyro.param( + "guide_probs_c", + torch.tensor([[0.0, 1.0], [1.0, 0.0]]), # deterministic + constraint=constraints.simplex, + ) @poutine.scale(scale=scale) def auto_model(data): @@ -1873,20 +2430,23 @@ def auto_model(data): probs_d = pyro.param("model_probs_d") probs_e = pyro.param("model_probs_e") a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample("b", dist.Categorical(probs_b[a]), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) with pyro.plate("data", 2): c = pyro.sample("c", dist.Categorical(probs_c[a])) - d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}) + d = pyro.sample( + "d", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) @poutine.scale(scale=scale) def auto_guide(data): probs_a = pyro.param("guide_probs_a") probs_c = pyro.param("guide_probs_c") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) with pyro.plate("data", 2): pyro.sample("c", dist.Categorical(probs_c[a])) @@ -1898,21 +2458,23 @@ def hand_model(data): probs_d = pyro.param("model_probs_d") probs_e = pyro.param("model_probs_e") a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample("b", dist.Categorical(probs_b[a]), - infer={"enumerate": "parallel"}) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) for i in pyro.plate("data", 2): c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) - d = pyro.sample("d_{}".format(i), - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}) + d = pyro.sample( + "d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]), obs=data[i]) 
@poutine.scale(scale=scale) def hand_guide(data): probs_a = pyro.param("guide_probs_a") probs_c = pyro.param("guide_probs_c") - a = pyro.sample("a", dist.Categorical(probs_a), - infer={"enumerate": "parallel"}) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) for i in pyro.plate("data", 2): pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a])) @@ -1924,7 +2486,7 @@ def hand_guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_1(scale): # +-----------------+ # | a ----> b M=2 | @@ -1934,18 +2496,18 @@ def test_elbo_enumerate_plates_1(scale): # +-----------------+ # This tests two unrelated plates. # Each should remain uncontracted. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([0.75, 0.25]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param("probs_c", torch.tensor([0.75, 0.25]), constraint=constraints.simplex) + pyro.param( + "probs_d", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) b_data = torch.tensor([0, 1]) d_data = torch.tensor([0, 0, 1]) @@ -1987,22 +2549,24 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_2(scale): # +---------+ +---------+ # | b <---- a ----> c | # | M=2 | | N=3 | # +---------+ +---------+ # This tests two different plates with recycled dimension. 
- pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) b_data = torch.tensor([0, 1]) c_data = torch.tensor([0, 0, 1]) @@ -2014,11 +2578,9 @@ def auto_model(): probs_c = pyro.param("probs_c") a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("b_axis", 2): - pyro.sample("b", dist.Categorical(probs_b[a]), - obs=b_data) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data) with pyro.plate("c_axis", 3): - pyro.sample("c", dist.Categorical(probs_c[a]), - obs=c_data) + pyro.sample("c", dist.Categorical(probs_c[a]), obs=c_data) @config_enumerate @poutine.scale(scale=scale) @@ -2028,11 +2590,9 @@ def hand_model(): probs_c = pyro.param("probs_c") a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("b_axis", 2): - pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]), - obs=b_data[i]) + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]), obs=b_data[i]) for j in pyro.plate("c_axis", 3): - pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a]), - obs=c_data[j]) + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a]), obs=c_data[j]) def guide(): pass @@ -2044,7 +2604,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_3(scale): # +--------------------+ # | +----------+ | @@ -2054,12 +2614,12 @@ def test_elbo_enumerate_plates_3(scale): # +--------------------+ # This is tests the case of multiple plate contractions in # a single step. 
- pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1], [0, 0]]) @config_enumerate @@ -2070,8 +2630,7 @@ def auto_model(): a = pyro.sample("a", dist.Categorical(probs_a)) with pyro.plate("outer", 2): with pyro.plate("inner", 2): - pyro.sample("b", dist.Categorical(probs_b[a]), - obs=data) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=data) @config_enumerate @poutine.scale(scale=scale) @@ -2082,8 +2641,9 @@ def hand_model(): a = pyro.sample("a", dist.Categorical(probs_a)) for i in pyro.plate("outer", 2): for j in inner: - pyro.sample("b_{}_{}".format(i, j), dist.Categorical(probs_b[a]), - obs=data[i, j]) + pyro.sample( + "b_{}_{}".format(i, j), dist.Categorical(probs_b[a]), obs=data[i, j] + ) def guide(): pass @@ -2095,7 +2655,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_4(scale): # +--------------------+ # | +----------+ | @@ -2103,15 +2663,17 @@ def test_elbo_enumerate_plates_4(scale): # | | N=2 | | # | M=2 +----------+ | # +--------------------+ - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) @config_enumerate @poutine.scale(scale=scale) @@ -2136,8 +2698,9 @@ def hand_model(data): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for j in inner: - pyro.sample("c_{}_{}".format(i, j), dist.Categorical(probs_c[b]), - obs=data[i, j]) + pyro.sample( + "c_{}_{}".format(i, j), dist.Categorical(probs_c[b]), obs=data[i, j] + ) def guide(data): pass @@ -2150,7 +2713,7 @@ def guide(data): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_5(scale): # a # | \ @@ -2160,16 +2723,17 @@ def test_elbo_enumerate_plates_5(scale): # | | N=2 | | # | M=2 +----------+ | # +-------------------+ - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], - [[0.2, 0.8], [0.1, 0.9]]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.2, 0.8], [0.1, 0.9]]]), + constraint=constraints.simplex, + ) data = torch.tensor([[0, 1], [0, 0]]) @config_enumerate @@ -2182,8 +2746,7 @@ def auto_model(): with 
pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b[a])) with pyro.plate("inner", 2): - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @config_enumerate @poutine.scale(scale=scale) @@ -2196,9 +2759,11 @@ def hand_model(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for j in inner: - pyro.sample("c_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[i, j]) + pyro.sample( + "c_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[i, j], + ) def guide(): pass @@ -2210,7 +2775,7 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_6(scale): # +----------+ # | M=2 | @@ -2224,18 +2789,22 @@ def test_elbo_enumerate_plates_6(scale): # +-------------+ # This tests different ways of mixing two independence contexts, # where each can be either sequential or vectorized plate. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + constraint=constraints.simplex, + ) @config_enumerate @poutine.scale(scale=scale) @@ -2247,13 +2816,19 @@ def model_iplate_iplate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] for i in b_axis: for j in c_axis: - pyro.sample("d_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_d)[b[i], c[j]]), - obs=data[i, j]) + pyro.sample( + "d_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_d)[b[i], c[j]]), + obs=data[i, j], + ) @config_enumerate @poutine.scale(scale=scale) @@ -2265,14 +2840,18 @@ def model_iplate_plate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] with c_axis: c = pyro.sample("c", dist.Categorical(probs_c[a])) for i in b_axis: with c_axis: - pyro.sample("d_{}".format(i), - dist.Categorical(Vindex(probs_d)[b[i], c]), - obs=data[i]) + pyro.sample( + "d_{}".format(i), + dist.Categorical(Vindex(probs_d)[b[i], c]), + obs=data[i], + ) @config_enumerate @poutine.scale(scale=scale) @@ -2286,12 +2865,16 @@ def model_plate_iplate(data): 
a = pyro.sample("a", dist.Categorical(probs_a)) with b_axis: b = pyro.sample("b", dist.Categorical(probs_b[a])) - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] with b_axis: for j in c_axis: - pyro.sample("d_{}".format(j), - dist.Categorical(Vindex(probs_d)[b, c[j]]), - obs=data[:, j]) + pyro.sample( + "d_{}".format(j), + dist.Categorical(Vindex(probs_d)[b, c[j]]), + obs=data[:, j], + ) @config_enumerate @poutine.scale(scale=scale) @@ -2308,9 +2891,7 @@ def model_plate_plate(data): with c_axis: c = pyro.sample("c", dist.Categorical(probs_c[a])) with b_axis, c_axis: - pyro.sample("d", - dist.Categorical(Vindex(probs_d)[b, c]), - obs=data) + pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), obs=data) def guide(data): pass @@ -2327,11 +2908,13 @@ def guide(data): # But promoting both to plates should result in an error. elbo = TraceEnum_ELBO(max_plate_nesting=2) - with pytest.raises(NotImplementedError, match="Expected tree-structured plate nesting.*"): + with pytest.raises( + NotImplementedError, match="Expected tree-structured plate nesting.*" + ): elbo.differentiable_loss(model_plate_plate, guide, data) -@pytest.mark.parametrize('scale', [1, 10]) +@pytest.mark.parametrize("scale", [1, 10]) def test_elbo_enumerate_plates_7(scale): # +-------------+ # | N=2 | @@ -2346,21 +2929,27 @@ def test_elbo_enumerate_plates_7(scale): # +----------------+ # This tests tree-structured dependencies among variables but # non-tree dependencies among plate nestings. - pyro.param("probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("probs_b", - torch.tensor([[0.6, 0.4], [0.4, 0.6]]), - constraint=constraints.simplex) - pyro.param("probs_c", - torch.tensor([[0.75, 0.25], [0.55, 0.45]]), - constraint=constraints.simplex) - pyro.param("probs_d", - torch.tensor([[0.3, 0.7], [0.2, 0.8]]), - constraint=constraints.simplex) - pyro.param("probs_e", - torch.tensor([[0.4, 0.6], [0.3, 0.7]]), - constraint=constraints.simplex) + pyro.param("probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex) + pyro.param( + "probs_b", + torch.tensor([[0.6, 0.4], [0.4, 0.6]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_c", + torch.tensor([[0.75, 0.25], [0.55, 0.45]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_d", + torch.tensor([[0.3, 0.7], [0.2, 0.8]]), + constraint=constraints.simplex, + ) + pyro.param( + "probs_e", + torch.tensor([[0.4, 0.6], [0.3, 0.7]]), + constraint=constraints.simplex, + ) @config_enumerate @poutine.scale(scale=scale) @@ -2373,14 +2962,24 @@ def model_iplate_iplate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] for i in b_axis: for j in c_axis: - pyro.sample("d_{}_{}".format(i, j), dist.Categorical(probs_d[b[i]]), - obs=data[i, j]) - pyro.sample("e_{}_{}".format(i, j), dist.Categorical(probs_e[c[j]]), - obs=data[i, j]) + pyro.sample( + "d_{}_{}".format(i, j), + dist.Categorical(probs_d[b[i]]), + obs=data[i, j], + ) + pyro.sample( + "e_{}_{}".format(i, j), + 
dist.Categorical(probs_e[c[j]]), + obs=data[i, j], + ) @config_enumerate @poutine.scale(scale=scale) @@ -2393,15 +2992,17 @@ def model_iplate_plate(data): b_axis = pyro.plate("b_axis", 2) c_axis = pyro.plate("c_axis", 2) a = pyro.sample("a", dist.Categorical(probs_a)) - b = [pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis] + b = [ + pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a])) for i in b_axis + ] with c_axis: c = pyro.sample("c", dist.Categorical(probs_c[a])) for i in b_axis: with c_axis: - pyro.sample("d_{}".format(i), dist.Categorical(probs_d[b[i]]), - obs=data[i]) - pyro.sample("e_{}".format(i), dist.Categorical(probs_e[c]), - obs=data[i]) + pyro.sample( + "d_{}".format(i), dist.Categorical(probs_d[b[i]]), obs=data[i] + ) + pyro.sample("e_{}".format(i), dist.Categorical(probs_e[c]), obs=data[i]) @config_enumerate @poutine.scale(scale=scale) @@ -2416,13 +3017,17 @@ def model_plate_iplate(data): a = pyro.sample("a", dist.Categorical(probs_a)) with b_axis: b = pyro.sample("b", dist.Categorical(probs_b[a])) - c = [pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis] + c = [ + pyro.sample("c_{}".format(j), dist.Categorical(probs_c[a])) for j in c_axis + ] with b_axis: for j in c_axis: - pyro.sample("d_{}".format(j), dist.Categorical(probs_d[b]), - obs=data[:, j]) - pyro.sample("e_{}".format(j), dist.Categorical(probs_e[c[j]]), - obs=data[:, j]) + pyro.sample( + "d_{}".format(j), dist.Categorical(probs_d[b]), obs=data[:, j] + ) + pyro.sample( + "e_{}".format(j), dist.Categorical(probs_e[c[j]]), obs=data[:, j] + ) @config_enumerate @poutine.scale(scale=scale) @@ -2460,12 +3065,16 @@ def guide(data): _check_loss_and_grads(loss_iplate_iplate, loss_plate_plate) -@pytest.mark.parametrize('guide_scale', [1]) -@pytest.mark.parametrize('model_scale', [1]) -@pytest.mark.parametrize('outer_vectorized,inner_vectorized,xfail', - [(False, True, False), (True, False, True), (True, True, True)], - ids=['iplate-plate', 'plate-iplate', 'plate-plate']) -def test_elbo_enumerate_plates_8(model_scale, guide_scale, inner_vectorized, outer_vectorized, xfail): +@pytest.mark.parametrize("guide_scale", [1]) +@pytest.mark.parametrize("model_scale", [1]) +@pytest.mark.parametrize( + "outer_vectorized,inner_vectorized,xfail", + [(False, True, False), (True, False, True), (True, True, True)], + ids=["iplate-plate", "plate-iplate", "plate-plate"], +) +def test_elbo_enumerate_plates_8( + model_scale, guide_scale, inner_vectorized, outer_vectorized, xfail +): # Guide Model # a # +-----------|--------+ @@ -2474,19 +3083,22 @@ def test_elbo_enumerate_plates_8(model_scale, guide_scale, inner_vectorized, out # | b ----> c | | # | +----------+ | # +--------------------+ - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + pyro.param( + 
"guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([[0, 1], [0, 2]]) @config_enumerate @@ -2499,9 +3111,7 @@ def model_plate_plate(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b)) with pyro.plate("inner", 2): - pyro.sample("c", - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) @config_enumerate @poutine.scale(scale=model_scale) @@ -2514,9 +3124,11 @@ def model_iplate_plate(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) with inner: - pyro.sample("c_{}".format(i), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[:, i]) + pyro.sample( + "c_{}".format(i), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[:, i], + ) @config_enumerate @poutine.scale(scale=model_scale) @@ -2528,9 +3140,11 @@ def model_plate_iplate(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b)) for j in pyro.plate("inner", 2): - pyro.sample("c_{}".format(j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[j]) + pyro.sample( + "c_{}".format(j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j], + ) @config_enumerate @poutine.scale(scale=model_scale) @@ -2543,9 +3157,11 @@ def model_iplate_iplate(): for i in pyro.plate("outer", 2): b = pyro.sample("b_{}".format(i), dist.Categorical(probs_b)) for j in inner: - pyro.sample("c_{}_{}".format(i, j), - dist.Categorical(Vindex(probs_c)[a, b]), - obs=data[j, i]) + pyro.sample( + "c_{}_{}".format(i, j), + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j, i], + ) @config_enumerate @poutine.scale(scale=guide_scale) @@ -2565,9 +3181,12 @@ def guide_iplate(): expected_loss = elbo.differentiable_loss(model_iplate_iplate, guide_iplate) with ExitStack() as stack: if xfail: - stack.enter_context(pytest.raises( - ValueError, - match="Expected model enumeration to be no more global than guide")) + stack.enter_context( + pytest.raises( + ValueError, + match="Expected model enumeration to be no more global than guide", + ) + ) if inner_vectorized: if outer_vectorized: elbo = TraceEnum_ELBO(max_plate_nesting=2) @@ -2584,11 +3203,13 @@ def guide_iplate(): def test_elbo_scale(): # Consider a mixture model with two components, toggled by `which`. def component_model(data, which, suffix=""): - loc = pyro.param("locs", torch.tensor([-1., 1.]))[which] + loc = pyro.param("locs", torch.tensor([-1.0, 1.0]))[which] with pyro.plate("data" + suffix, len(data)): - pyro.sample("obs" + suffix, dist.Normal(loc, 1.), obs=data) + pyro.sample("obs" + suffix, dist.Normal(loc, 1.0), obs=data) - pyro.param("mixture_probs", torch.tensor([0.25, 0.75]), constraint=constraints.simplex) + pyro.param( + "mixture_probs", torch.tensor([0.25, 0.75]), constraint=constraints.simplex + ) # We can implement this in two ways. # First consider automatic enumeration in the guide. @@ -2599,8 +3220,9 @@ def auto_model(data): def auto_guide(data): mixture_probs = pyro.param("mixture_probs") - pyro.sample("which", dist.Categorical(mixture_probs), - infer={"enumerate": "parallel"}) + pyro.sample( + "which", dist.Categorical(mixture_probs), infer={"enumerate": "parallel"} + ) # Second consider explicit enumeration in the model, where we # marginalize out the `which` variable by hand. 
@@ -2613,7 +3235,7 @@ def hand_model(data): def hand_guide(data): pass - data = dist.Normal(0., 2.).sample((3,)) + data = dist.Normal(0.0, 2.0).sample((3,)) elbo = TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=False) auto_loss = elbo.differentiable_loss(auto_model, auto_guide, data) hand_loss = elbo.differentiable_loss(hand_model, hand_guide, data) @@ -2626,12 +3248,16 @@ def test_elbo_hmm_growth(): elbo = TraceEnum_ELBO(max_plate_nesting=0) def model(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - emission_probs = pyro.param("emission_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) + emission_probs = pyro.param( + "emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = None for i, y in pyro.markov(enumerate(data)): probs = init_probs if x is None else transition_probs[x] @@ -2640,15 +3266,17 @@ def model(data): @config_enumerate def guide(data): - transition_probs = pyro.param("transition_probs", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) x = None for i, y in pyro.markov(enumerate(data)): probs = init_probs if x is None else transition_probs[x] x = pyro.sample("x_{}".format(i), dist.Categorical(probs)) - sizes = range(3, 1 + int(os.environ.get('GROWTH_SIZE', 15))) + sizes = range(3, 1 + int(os.environ.get("GROWTH_SIZE", 15))) costs = [] times1 = [] times2 = [] @@ -2669,25 +3297,33 @@ def guide(data): for counts in costs: for key, cost in counts.items(): collated_costs[key].append(cost) - logger.debug('\n'.join([ - 'HMM Growth:', - 'sizes = {}'.format(repr(sizes)), - 'costs = {}'.format(repr(dict(collated_costs))), - 'times1 = {}'.format(repr(times1)), - 'times2 = {}'.format(repr(times2)), - ])) - - -@pytest.mark.skipif("CUDA_TEST" in os.environ, reason="https://github.com/pyro-ppl/pyro/issues/1380") + logger.debug( + "\n".join( + [ + "HMM Growth:", + "sizes = {}".format(repr(sizes)), + "costs = {}".format(repr(dict(collated_costs))), + "times1 = {}".format(repr(times1)), + "times2 = {}".format(repr(times2)), + ] + ) + ) + + +@pytest.mark.skipif( + "CUDA_TEST" in os.environ, reason="https://github.com/pyro-ppl/pyro/issues/1380" +) def test_elbo_dbn_growth(): pyro.clear_param_store() elbo = TraceEnum_ELBO(max_plate_nesting=0) def model(data): uniform = torch.tensor([0.5, 0.5]) - probs_z = pyro.param("probs_z", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) + probs_z = pyro.param( + "probs_z", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex, + ) for i, z in pyro.markov(enumerate(data)): pyro.sample("x_{}".format(i), dist.Categorical(uniform)) y = pyro.sample("y_{}".format(i), dist.Categorical(uniform)) @@ -2695,20 +3331,23 @@ def model(data): @config_enumerate def guide(data): - probs_x = pyro.param("probs_x", - torch.tensor([[0.75, 0.25], [0.25, 0.75]]), - constraint=constraints.simplex) - probs_y = pyro.param("probs_y", - torch.tensor([[[0.75, 0.25], [0.45, 0.55]], - [[0.55, 0.45], [0.25, 0.75]]]), - constraint=constraints.simplex) + probs_x = pyro.param( + "probs_x", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + 
constraint=constraints.simplex, + ) + probs_y = pyro.param( + "probs_y", + torch.tensor([[[0.75, 0.25], [0.45, 0.55]], [[0.55, 0.45], [0.25, 0.75]]]), + constraint=constraints.simplex, + ) x = 0 y = 0 for i in pyro.markov(range(len(data))): x = pyro.sample("x_{}".format(i), dist.Categorical(probs_x[x])) y = pyro.sample("y_{}".format(i), dist.Categorical(probs_y[x, y])) - sizes = range(3, 1 + int(os.environ.get('GROWTH_SIZE', 15))) + sizes = range(3, 1 + int(os.environ.get("GROWTH_SIZE", 15))) costs = [] times1 = [] times2 = [] @@ -2729,13 +3368,17 @@ def guide(data): for counts in costs: for key, cost in counts.items(): collated_costs[key].append(cost) - logger.debug('\n'.join([ - 'DBN Growth:', - 'sizes = {}'.format(repr(sizes)), - 'costs = {}'.format(repr(dict(collated_costs))), - 'times1 = {}'.format(repr(times1)), - 'times2 = {}'.format(repr(times2)), - ])) + logger.debug( + "\n".join( + [ + "DBN Growth:", + "sizes = {}".format(repr(sizes)), + "costs = {}".format(repr(dict(collated_costs))), + "times1 = {}".format(repr(times1)), + "times2 = {}".format(repr(times2)), + ] + ) + ) @pytest.mark.parametrize("pi_a", [0.33]) @@ -2745,7 +3388,9 @@ def guide(data): @pytest.mark.parametrize("N_c", [5, 6]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("expand", [True, False]) -def test_bernoulli_pyramid_elbo_gradient(enumerate1, N_b, N_c, pi_a, pi_b, pi_c, expand): +def test_bernoulli_pyramid_elbo_gradient( + enumerate1, N_b, N_c, pi_a, pi_b, pi_c, expand +): pyro.clear_param_store() def model(): @@ -2766,12 +3411,13 @@ def guide(): pyro.sample("c", dist.Bernoulli(qc).expand_by([N_c, N_b])) logger.info("Computing gradients using surrogate loss") - elbo = TraceEnum_ELBO(max_plate_nesting=2, - strict_enumeration_warning=True) - elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1, expand=expand)) - actual_grad_qa = pyro.param('qa').grad - actual_grad_qb = pyro.param('qb').grad - actual_grad_qc = pyro.param('qc').grad + elbo = TraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=True) + elbo.loss_and_grads( + model, config_enumerate(guide, default=enumerate1, expand=expand) + ) + actual_grad_qa = pyro.param("qa").grad + actual_grad_qb = pyro.param("qb").grad + actual_grad_qc = pyro.param("qc").grad logger.info("Computing analytic gradients") qa = torch.tensor(pi_a, requires_grad=True) @@ -2779,27 +3425,58 @@ def guide(): qc = torch.tensor(pi_c, requires_grad=True) elbo = kl_divergence(dist.Bernoulli(qa), dist.Bernoulli(0.33)) elbo = elbo + N_b * qa * kl_divergence(dist.Bernoulli(qb), dist.Bernoulli(0.75)) - elbo = elbo + N_b * (1.0 - qa) * kl_divergence(dist.Bernoulli(qb), dist.Bernoulli(0.50)) - elbo = elbo + N_c * N_b * qa * qb * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.67)) - elbo = elbo + N_c * N_b * (1.0 - qa) * qb * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.52)) - elbo = elbo + N_c * N_b * qa * (1.0 - qb) * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.47)) - elbo = elbo + N_c * N_b * (1.0 - qa) * (1.0 - qb) * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.32)) + elbo = elbo + N_b * (1.0 - qa) * kl_divergence( + dist.Bernoulli(qb), dist.Bernoulli(0.50) + ) + elbo = elbo + N_c * N_b * qa * qb * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.67) + ) + elbo = elbo + N_c * N_b * (1.0 - qa) * qb * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.52) + ) + elbo = elbo + N_c * N_b * qa * (1.0 - qb) * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.47) + ) + elbo = elbo + N_c * N_b * 
(1.0 - qa) * (1.0 - qb) * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.32) + ) expected_grad_qa, expected_grad_qb, expected_grad_qc = grad(elbo, [qa, qb, qc]) prec = 0.001 - assert_equal(actual_grad_qa, expected_grad_qa, prec=prec, msg="".join([ - "\nqa expected = {}".format(expected_grad_qa.data.cpu().numpy()), - "\nqa actual = {}".format(actual_grad_qa.data.cpu().numpy()), - ])) - assert_equal(actual_grad_qb, expected_grad_qb, prec=prec, msg="".join([ - "\nqb expected = {}".format(expected_grad_qb.data.cpu().numpy()), - "\nqb actual = {}".format(actual_grad_qb.data.cpu().numpy()), - ])) - assert_equal(actual_grad_qc, expected_grad_qc, prec=prec, msg="".join([ - "\nqc expected = {}".format(expected_grad_qc.data.cpu().numpy()), - "\nqc actual = {}".format(actual_grad_qc.data.cpu().numpy()), - ])) + assert_equal( + actual_grad_qa, + expected_grad_qa, + prec=prec, + msg="".join( + [ + "\nqa expected = {}".format(expected_grad_qa.data.cpu().numpy()), + "\nqa actual = {}".format(actual_grad_qa.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_qb, + expected_grad_qb, + prec=prec, + msg="".join( + [ + "\nqb expected = {}".format(expected_grad_qb.data.cpu().numpy()), + "\nqb actual = {}".format(actual_grad_qb.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_qc, + expected_grad_qc, + prec=prec, + msg="".join( + [ + "\nqc expected = {}".format(expected_grad_qc.data.cpu().numpy()), + "\nqc actual = {}".format(actual_grad_qc.data.cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("pi_a", [0.33]) @@ -2811,8 +3488,19 @@ def guide(): @pytest.mark.parametrize("d_offset", [0.32]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("expand", [True, False]) -def test_bernoulli_non_tree_elbo_gradient(enumerate1, b_factor, c_factor, pi_a, pi_b, pi_c, pi_d, - expand, d_offset, N_b=2, N_c=2): +def test_bernoulli_non_tree_elbo_gradient( + enumerate1, + b_factor, + c_factor, + pi_a, + pi_b, + pi_c, + pi_d, + expand, + d_offset, + N_b=2, + N_c=2, +): pyro.clear_param_store() def model(): @@ -2832,13 +3520,14 @@ def guide(): pyro.sample("d", dist.Bernoulli(qd)) logger.info("Computing gradients using surrogate loss") - elbo = TraceEnum_ELBO(max_plate_nesting=2, - strict_enumeration_warning=True) - elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1, expand=expand)) - actual_grad_qa = pyro.param('qa').grad - actual_grad_qb = pyro.param('qb').grad - actual_grad_qc = pyro.param('qc').grad - actual_grad_qd = pyro.param('qd').grad + elbo = TraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=True) + elbo.loss_and_grads( + model, config_enumerate(guide, default=enumerate1, expand=expand) + ) + actual_grad_qa = pyro.param("qa").grad + actual_grad_qb = pyro.param("qb").grad + actual_grad_qc = pyro.param("qc").grad + actual_grad_qd = pyro.param("qd").grad logger.info("Computing analytic gradients") qa = torch.tensor(pi_a, requires_grad=True) @@ -2851,39 +3540,83 @@ def guide(): elbo = elbo + (1.0 - qa) * kl_divergence(dist.Bernoulli(qb), dist.Bernoulli(0.50)) elbo = elbo + qa * qb * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.85)) - elbo = elbo + (1.0 - qa) * qb * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.60)) - elbo = elbo + qa * (1.0 - qb) * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.75)) - elbo = elbo + (1.0 - qa) * (1.0 - qb) * kl_divergence(dist.Bernoulli(qc), dist.Bernoulli(0.50)) - - elbo = elbo + qb * qc * kl_divergence(dist.Bernoulli(qd), dist.Bernoulli(b_factor + c_factor + d_offset)) 
- elbo = elbo + (1.0 - qb) * qc * kl_divergence(dist.Bernoulli(qd), dist.Bernoulli(c_factor + d_offset)) - elbo = elbo + qb * (1.0 - qc) * kl_divergence(dist.Bernoulli(qd), dist.Bernoulli(b_factor + d_offset)) - elbo = elbo + (1.0 - qb) * (1.0 - qc) * kl_divergence(dist.Bernoulli(qd), dist.Bernoulli(d_offset)) - - expected_grad_qa, expected_grad_qb, expected_grad_qc, expected_grad_qd = grad(elbo, [qa, qb, qc, qd]) + elbo = elbo + (1.0 - qa) * qb * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.60) + ) + elbo = elbo + qa * (1.0 - qb) * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.75) + ) + elbo = elbo + (1.0 - qa) * (1.0 - qb) * kl_divergence( + dist.Bernoulli(qc), dist.Bernoulli(0.50) + ) + + elbo = elbo + qb * qc * kl_divergence( + dist.Bernoulli(qd), dist.Bernoulli(b_factor + c_factor + d_offset) + ) + elbo = elbo + (1.0 - qb) * qc * kl_divergence( + dist.Bernoulli(qd), dist.Bernoulli(c_factor + d_offset) + ) + elbo = elbo + qb * (1.0 - qc) * kl_divergence( + dist.Bernoulli(qd), dist.Bernoulli(b_factor + d_offset) + ) + elbo = elbo + (1.0 - qb) * (1.0 - qc) * kl_divergence( + dist.Bernoulli(qd), dist.Bernoulli(d_offset) + ) + + expected_grad_qa, expected_grad_qb, expected_grad_qc, expected_grad_qd = grad( + elbo, [qa, qb, qc, qd] + ) prec = 0.0001 - assert_equal(actual_grad_qa, expected_grad_qa, prec=prec, msg="".join([ - "\nqa expected = {}".format(expected_grad_qa.data.cpu().numpy()), - "\nqa actual = {}".format(actual_grad_qa.data.cpu().numpy()), - ])) - assert_equal(actual_grad_qb, expected_grad_qb, prec=prec, msg="".join([ - "\nqb expected = {}".format(expected_grad_qb.data.cpu().numpy()), - "\nqb actual = {}".format(actual_grad_qb.data.cpu().numpy()), - ])) - assert_equal(actual_grad_qc, expected_grad_qc, prec=prec, msg="".join([ - "\nqc expected = {}".format(expected_grad_qc.data.cpu().numpy()), - "\nqc actual = {}".format(actual_grad_qc.data.cpu().numpy()), - ])) - assert_equal(actual_grad_qd, expected_grad_qd, prec=prec, msg="".join([ - "\nqd expected = {}".format(expected_grad_qd.data.cpu().numpy()), - "\nqd actual = {}".format(actual_grad_qd.data.cpu().numpy()), - ])) + assert_equal( + actual_grad_qa, + expected_grad_qa, + prec=prec, + msg="".join( + [ + "\nqa expected = {}".format(expected_grad_qa.data.cpu().numpy()), + "\nqa actual = {}".format(actual_grad_qa.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_qb, + expected_grad_qb, + prec=prec, + msg="".join( + [ + "\nqb expected = {}".format(expected_grad_qb.data.cpu().numpy()), + "\nqb actual = {}".format(actual_grad_qb.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_qc, + expected_grad_qc, + prec=prec, + msg="".join( + [ + "\nqc expected = {}".format(expected_grad_qc.data.cpu().numpy()), + "\nqc actual = {}".format(actual_grad_qc.data.cpu().numpy()), + ] + ), + ) + assert_equal( + actual_grad_qd, + expected_grad_qd, + prec=prec, + msg="".join( + [ + "\nqd expected = {}".format(expected_grad_qd.data.cpu().numpy()), + "\nqd actual = {}".format(actual_grad_qd.data.cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("gate", [0.1, 0.25, 0.5, 0.75, 0.9]) -@pytest.mark.parametrize("rate", [0.1, 1., 3.]) +@pytest.mark.parametrize("rate", [0.1, 1.0, 3.0]) def test_elbo_zip(gate, rate): # test for ZIP distribution def zip_model(data): @@ -2895,10 +3628,12 @@ def zip_model(data): def composite_model(data): gate = pyro.param("gate") rate = pyro.param("rate") - dist1 = dist.Delta(torch.tensor(0.)) + dist1 = dist.Delta(torch.tensor(0.0)) dist0 = dist.Poisson(rate) with pyro.plate("data", 
len(data)): - mask = pyro.sample("mask", dist.Bernoulli(gate), infer={"enumerate": "parallel"}).bool() + mask = pyro.sample( + "mask", dist.Bernoulli(gate), infer={"enumerate": "parallel"} + ).bool() pyro.sample("obs", dist.MaskedMixture(mask, dist0, dist1), obs=data) def guide(data): @@ -2907,23 +3642,26 @@ def guide(data): pyro.param("gate", torch.tensor(gate), constraint=constraints.unit_interval) pyro.param("rate", torch.tensor(rate), constraint=constraints.positive) - data = torch.tensor([0., 1., 2.]) + data = torch.tensor([0.0, 1.0, 2.0]) elbo = TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=False) zip_loss = elbo.differentiable_loss(zip_model, guide, data) composite_loss = elbo.differentiable_loss(composite_model, guide, data) _check_loss_and_grads(zip_loss, composite_loss) -@pytest.mark.parametrize("mixture,scale", [ - (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), - (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), -]) +@pytest.mark.parametrize( + "mixture,scale", + [ + (dist.MixtureOfDiagNormals, [[2.0, 1.0], [1.0, 2], [4.0, 4.0]]), + (dist.MixtureOfDiagNormalsSharedCovariance, [2.0, 1.0]), + ], +) def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 - pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) + pyro.param("locs", torch.tensor([[0.0, 0.0], [0.0, 1.0], [0.0, 10.0]])) pyro.param("coord_scale", torch.tensor(scale), constraint=constraints.positive) - pyro.param("component_logits", torch.tensor([0., -1., 2.])) - data = torch.tensor([[0., 0.], [1., 1.], [2., 3.], [1., 11.]]) + pyro.param("component_logits", torch.tensor([0.0, -1.0, 2.0])) + data = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 3.0], [1.0, 11.0]]) def auto_model(): locs = pyro.param("locs") @@ -2937,12 +3675,20 @@ def hand_model(): coord_scale = pyro.param("coord_scale") component_logits = pyro.param("component_logits") with pyro.plate("data", len(data), dim=-2): - which = pyro.sample("mask", dist.Categorical(logits=component_logits), - infer={"enumerate": "parallel"}) - with pyro.plate("components", len(component_logits), dim=-1) as component_ind: + which = pyro.sample( + "mask", + dist.Categorical(logits=component_logits), + infer={"enumerate": "parallel"}, + ) + with pyro.plate( + "components", len(component_logits), dim=-1 + ) as component_ind: with poutine.mask(mask=(which == component_ind)): - pyro.sample("obs", dist.Normal(locs, coord_scale).to_event(1), - obs=data.unsqueeze(-2)) + pyro.sample( + "obs", + dist.Normal(locs, coord_scale).to_event(1), + obs=data.unsqueeze(-2), + ) def guide(): pass @@ -2953,29 +3699,32 @@ def guide(): _check_loss_and_grads(hand_loss, auto_loss) -@pytest.mark.parametrize("Dist, prior", [ - (dist.Bernoulli, 0.2), - (dist.Categorical, [0.2, 0.8]), - (dist.Categorical, [0.2, 0.3, 0.5]), - (dist.Categorical, [0.2, 0.3, 0.3, 0.2]), - (dist.OneHotCategorical, [0.2, 0.8]), - (dist.OneHotCategorical, [0.2, 0.3, 0.5]), - (dist.OneHotCategorical, [0.2, 0.3, 0.3, 0.2]), -]) +@pytest.mark.parametrize( + "Dist, prior", + [ + (dist.Bernoulli, 0.2), + (dist.Categorical, [0.2, 0.8]), + (dist.Categorical, [0.2, 0.3, 0.5]), + (dist.Categorical, [0.2, 0.3, 0.3, 0.2]), + (dist.OneHotCategorical, [0.2, 0.8]), + (dist.OneHotCategorical, [0.2, 0.3, 0.5]), + (dist.OneHotCategorical, [0.2, 0.3, 0.3, 0.2]), + ], +) def test_compute_marginals_single(Dist, prior): prior = torch.tensor(prior) - data = torch.tensor([0., 0.1, 0.2, 0.9, 1.0, 1.1]) + data = torch.tensor([0.0, 0.1, 0.2, 0.9, 1.0, 1.1]) @config_enumerate def model(): - locs = 
torch.tensor([-1., 0., 1., 2.]) + locs = torch.tensor([-1.0, 0.0, 1.0, 2.0]) x = pyro.sample("x", Dist(prior)) if Dist is dist.Bernoulli: x = x.long() elif Dist is dist.OneHotCategorical: x = x.max(-1)[1] with pyro.plate("data", len(data)): - pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) + pyro.sample("obs", dist.Normal(locs[x], 1.0), obs=data) # First compute marginals using an empty guide. def empty_guide(): @@ -3001,22 +3750,26 @@ def exact_guide(): assert_equal(grad(loss, [pyro.param("probs")])[0], torch.zeros_like(probs)) -@pytest.mark.parametrize('ok,enumerate_guide,num_particles,vectorize_particles', [ - (True, None, 1, False), - (False, "sequential", 1, False), - (False, "parallel", 1, False), - (False, None, 2, False), - (False, None, 2, True), -]) -def test_compute_marginals_restrictions(ok, enumerate_guide, num_particles, vectorize_particles): - +@pytest.mark.parametrize( + "ok,enumerate_guide,num_particles,vectorize_particles", + [ + (True, None, 1, False), + (False, "sequential", 1, False), + (False, "parallel", 1, False), + (False, None, 2, False), + (False, None, 2, True), + ], +) +def test_compute_marginals_restrictions( + ok, enumerate_guide, num_particles, vectorize_particles +): @config_enumerate def model(): w = pyro.sample("w", dist.Bernoulli(0.1)) x = pyro.sample("x", dist.Bernoulli(0.2)) y = pyro.sample("y", dist.Bernoulli(0.3)) z = pyro.sample("z", dist.Bernoulli(0.4)) - pyro.sample("obs", dist.Normal(0., 1.), obs=w + x + y + z) + pyro.sample("obs", dist.Normal(0.0, 1.0), obs=w + x + y + z) @config_enumerate(default=enumerate_guide) def guide(): @@ -3024,9 +3777,11 @@ def guide(): pyro.sample("y", dist.Bernoulli(0.7)) # Check that the ELBO works fine. - elbo = TraceEnum_ELBO(max_plate_nesting=0, - num_particles=num_particles, - vectorize_particles=vectorize_particles) + elbo = TraceEnum_ELBO( + max_plate_nesting=0, + num_particles=num_particles, + vectorize_particles=vectorize_particles, + ) loss = elbo.loss(model, guide) assert not torch_isnan(loss) @@ -3038,9 +3793,8 @@ def guide(): elbo.compute_marginals(model, guide) -@pytest.mark.parametrize('size', [1, 2, 3, 4, 10, 20, _skip_cuda(30)]) +@pytest.mark.parametrize("size", [1, 2, 3, 4, 10, 20, _skip_cuda(30)]) def test_compute_marginals_hmm(size): - @config_enumerate def model(data): transition_probs = torch.tensor([[0.75, 0.25], [0.25, 0.75]]) @@ -3049,10 +3803,15 @@ def model(data): for i in pyro.markov(range(len(data) + 1)): if i < len(data): x = pyro.sample("x_{}".format(i), dist.Categorical(transition_probs[x])) - pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i]) + pyro.sample( + "y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i] + ) else: - pyro.sample("x_{}".format(i), dist.Categorical(transition_probs[x]), - obs=torch.tensor(1)) + pyro.sample( + "x_{}".format(i), + dist.Categorical(transition_probs[x]), + obs=torch.tensor(1), + ) def guide(data): pass @@ -3076,7 +3835,6 @@ def guide(data): @pytest.mark.parametrize("observed", ["", "a", "b", "ab"]) def test_marginals_2678(observed): - @config_enumerate def model(a=None, b=None): a = pyro.sample("a", dist.Bernoulli(0.75), obs=a) @@ -3085,39 +3843,41 @@ def model(a=None, b=None): def guide(a=None, b=None): pass - kwargs = {name: torch.tensor(1.) 
for name in observed} + kwargs = {name: torch.tensor(1.0) for name in observed} elbo = TraceEnum_ELBO(strict_enumeration_warning=False) elbo.compute_marginals(model, guide, **kwargs) -@pytest.mark.parametrize("data", [ - [None, None], - [torch.tensor(0.), None], - [None, torch.tensor(0.)], - [torch.tensor(0.), torch.tensor(0)], -]) +@pytest.mark.parametrize( + "data", + [ + [None, None], + [torch.tensor(0.0), None], + [None, torch.tensor(0.0)], + [torch.tensor(0.0), torch.tensor(0)], + ], +) def test_backwardsample_posterior_smoke(data): - @config_enumerate def model(data): xs = list(data) zs = [] for i in range(2): K = i + 2 # number of mixture components - zs.append(pyro.sample("z_{}".format(i), - dist.Categorical(torch.ones(K)))) + zs.append(pyro.sample("z_{}".format(i), dist.Categorical(torch.ones(K)))) if i == 0: loc = pyro.param("loc", torch.randn(K))[zs[i]] - xs[i] = pyro.sample("x_{}".format(i), - dist.Normal(loc, 1.), obs=data[i]) + xs[i] = pyro.sample( + "x_{}".format(i), dist.Normal(loc, 1.0), obs=data[i] + ) elif i == 1: logits = pyro.param("logits", torch.randn(K, 2))[zs[i]] - xs[i] = pyro.sample("x_{}".format(i), - dist.Categorical(logits=logits), - obs=data[i]) + xs[i] = pyro.sample( + "x_{}".format(i), dist.Categorical(logits=logits), obs=data[i] + ) z12 = zs[0] + 2 * zs[1] - pyro.sample("z_12", dist.Categorical(torch.arange(6.)), obs=z12) + pyro.sample("z_12", dist.Categorical(torch.arange(6.0)), obs=z12) return xs, zs def guide(data): @@ -3180,22 +3940,26 @@ def guide(data): assert abs(expected - actual) < 0.05 -@pytest.mark.parametrize('ok,enumerate_guide,num_particles,vectorize_particles', [ - (True, None, 1, False), - (False, "sequential", 1, False), - (False, "parallel", 1, False), - (False, None, 2, False), - (False, None, 2, True), -]) -def test_backwardsample_posterior_restrictions(ok, enumerate_guide, num_particles, vectorize_particles): - +@pytest.mark.parametrize( + "ok,enumerate_guide,num_particles,vectorize_particles", + [ + (True, None, 1, False), + (False, "sequential", 1, False), + (False, "parallel", 1, False), + (False, None, 2, False), + (False, None, 2, True), + ], +) +def test_backwardsample_posterior_restrictions( + ok, enumerate_guide, num_particles, vectorize_particles +): @config_enumerate def model(): w = pyro.sample("w", dist.Bernoulli(0.1)) x = pyro.sample("x", dist.Bernoulli(0.2)) y = pyro.sample("y", dist.Bernoulli(0.3)) z = pyro.sample("z", dist.Bernoulli(0.4)) - pyro.sample("obs", dist.Normal(0., 1.), obs=w + x + y + z) + pyro.sample("obs", dist.Normal(0.0, 1.0), obs=w + x + y + z) return w, x, y, z @config_enumerate(default=enumerate_guide) @@ -3204,9 +3968,11 @@ def guide(): pyro.sample("y", dist.Bernoulli(0.7)) # Check that the ELBO works fine. 
- elbo = TraceEnum_ELBO(max_plate_nesting=0, - num_particles=num_particles, - vectorize_particles=vectorize_particles) + elbo = TraceEnum_ELBO( + max_plate_nesting=0, + num_particles=num_particles, + vectorize_particles=vectorize_particles, + ) loss = elbo.loss(model, guide) assert not torch_isnan(loss) @@ -3224,24 +3990,27 @@ def guide(): @pytest.mark.parametrize("num_samples", [10000, 100000]) def test_vectorized_importance(num_samples): - pyro.param("model_probs_a", - torch.tensor([0.45, 0.55]), - constraint=constraints.simplex) - pyro.param("model_probs_b", - torch.tensor([0.6, 0.4]), - constraint=constraints.simplex) - pyro.param("model_probs_c", - torch.tensor([[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], - [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]]), - constraint=constraints.simplex) - - pyro.param("guide_probs_a", - torch.tensor([0.33, 0.67]), - constraint=constraints.simplex) - - pyro.param("guide_probs_b", - torch.tensor([0.8, 0.2]), - constraint=constraints.simplex) + pyro.param( + "model_probs_a", torch.tensor([0.45, 0.55]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_b", torch.tensor([0.6, 0.4]), constraint=constraints.simplex + ) + pyro.param( + "model_probs_c", + torch.tensor( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ), + constraint=constraints.simplex, + ) + + pyro.param( + "guide_probs_a", torch.tensor([0.33, 0.67]), constraint=constraints.simplex + ) + + pyro.param( + "guide_probs_b", torch.tensor([0.8, 0.2]), constraint=constraints.simplex + ) data = torch.tensor([[0, 1], [0, 2]]) @@ -3253,8 +4022,7 @@ def model(): with pyro.plate("outer", 2): b = pyro.sample("b", dist.Categorical(probs_b)) with pyro.plate("inner", 2): - pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), - obs=data) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) def guide(): probs_a = pyro.param("guide_probs_a") @@ -3263,9 +4031,13 @@ def guide(): with pyro.plate("outer", 2): pyro.sample("b", dist.Categorical(probs_b)) - vectorized_weights, _, _ = vectorized_importance_weights(model, guide, max_plate_nesting=4, num_samples=num_samples) + vectorized_weights, _, _ = vectorized_importance_weights( + model, guide, max_plate_nesting=4, num_samples=num_samples + ) - elbo = Trace_ELBO(vectorize_particles=True, num_particles=num_samples).loss(model, guide) + elbo = Trace_ELBO(vectorize_particles=True, num_particles=num_samples).loss( + model, guide + ) assert_equal(vectorized_weights.sum().item() / num_samples, -elbo, prec=0.02) @@ -3283,23 +4055,27 @@ def test_multi_dependence_enumeration(): @config_enumerate def model(N=1): - with pyro.plate('data_plate', N, dim=-2): - mixing_weights = pyro.param('pi', torch.ones(K) / K, constraint=constraints.simplex) - means = pyro.sample('mu', dist.Normal(torch.zeros(K, d), torch.ones(K, d)).to_event(2)) + with pyro.plate("data_plate", N, dim=-2): + mixing_weights = pyro.param( + "pi", torch.ones(K) / K, constraint=constraints.simplex + ) + means = pyro.sample( + "mu", dist.Normal(torch.zeros(K, d), torch.ones(K, d)).to_event(2) + ) - with pyro.plate('observations', N_obs, dim=-1): - s = pyro.sample('s', dist.Categorical(mixing_weights)) + with pyro.plate("observations", N_obs, dim=-1): + s = pyro.sample("s", dist.Categorical(mixing_weights)) - pyro.sample('x', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1)) - pyro.sample('y', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1)) + pyro.sample("x", dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1)) + pyro.sample("y", 
dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1)) - x = poutine.trace(model).get_trace(N=2).nodes['x']['value'] + x = poutine.trace(model).get_trace(N=2).nodes["x"]["value"] pyro.clear_param_store() - conditioned_model = pyro.condition(model, data={'x': x}) - guide = infer.autoguide.AutoDelta(poutine.block(conditioned_model, hide=['s'])) + conditioned_model = pyro.condition(model, data={"x": x}) + guide = infer.autoguide.AutoDelta(poutine.block(conditioned_model, hide=["s"])) elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) elbo.loss_and_grads(conditioned_model, guide, x.size(0)) - assert pyro.get_param_store()._params['pi'].grad is not None + assert pyro.get_param_store()._params["pi"].grad is not None diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index 45ca0d0932..f38bc2e214 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -35,20 +35,27 @@ def DiffTrace_ELBO(*args, **kwargs): return Trace_ELBO(*args, **kwargs).differentiable_loss -@pytest.mark.parametrize("scale", [1., 2.], ids=["unscaled", "scaled"]) -@pytest.mark.parametrize("reparameterized,has_rsample", - [(True, None), (True, False), (True, True), (False, None)], - ids=["reparam", "reparam-False", "reparam-True", "nonreparam"]) +@pytest.mark.parametrize("scale", [1.0, 2.0], ids=["unscaled", "scaled"]) +@pytest.mark.parametrize( + "reparameterized,has_rsample", + [(True, None), (True, False), (True, True), (False, None)], + ids=["reparam", "reparam-False", "reparam-True", "nonreparam"], +) @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo,local_samples", [ - (Trace_ELBO, False), - (DiffTrace_ELBO, False), - (TraceGraph_ELBO, False), - (TraceMeanField_ELBO, False), - (TraceEnum_ELBO, False), - (TraceEnum_ELBO, True), -]) -def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale): +@pytest.mark.parametrize( + "Elbo,local_samples", + [ + (Trace_ELBO, False), + (DiffTrace_ELBO, False), + (TraceGraph_ELBO, False), + (TraceMeanField_ELBO, False), + (TraceEnum_ELBO, False), + (TraceEnum_ELBO, True), + ], +) +def test_subsample_gradient( + Elbo, reparameterized, has_rsample, subsample, local_samples, scale +): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) subsample_size = 1 if subsample else len(data) @@ -80,30 +87,48 @@ def guide(subsample): num_particles = 1 optim = Adam({"lr": 0.1}) - elbo = Elbo(max_plate_nesting=1, # set this to ensure rng agrees across runs - num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=False) + elbo = Elbo( + max_plate_nesting=1, # set this to ensure rng agrees across runs + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=False, + ) inference = SVI(model, guide, optim, loss=elbo) with xfail_if_not_implemented(): if subsample_size == 1: - inference.loss_and_grads(model, guide, subsample=torch.tensor([0], dtype=torch.long)) - inference.loss_and_grads(model, guide, subsample=torch.tensor([1], dtype=torch.long)) + inference.loss_and_grads( + model, guide, subsample=torch.tensor([0], dtype=torch.long) + ) + inference.loss_and_grads( + model, guide, subsample=torch.tensor([1], dtype=torch.long) + ) else: - inference.loss_and_grads(model, guide, subsample=torch.tensor([0, 1], dtype=torch.long)) + inference.loss_and_grads( + model, guide, subsample=torch.tensor([0, 1], dtype=torch.long) + ) params = dict(pyro.get_param_store().named_parameters()) normalizer = 2 if subsample 
else 1 - actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()} - - expected_grads = {'loc': scale * np.array([0.5, -2.0]), 'scale': scale * np.array([2.0])} + actual_grads = { + name: param.grad.detach().cpu().numpy() / normalizer + for name, param in params.items() + } + + expected_grads = { + "loc": scale * np.array([0.5, -2.0]), + "scale": scale * np.array([2.0]), + } for name in sorted(params): - logger.info('expected {} = {}'.format(name, expected_grads[name])) - logger.info('actual {} = {}'.format(name, actual_grads[name])) + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=precision) -@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +@pytest.mark.parametrize( + "reparameterized", [True, False], ids=["reparam", "nonreparam"] +) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO] +) def test_plate(Elbo, reparameterized): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -125,7 +150,7 @@ def model(): def guide(): loc = pyro.param("loc", torch.zeros(len(data))) - scale = pyro.param("scale", torch.tensor([1.])) + scale = pyro.param("scale", torch.tensor([1.0])) pyro.sample("nuisance_c", Normal(4, 5)) with pyro.plate("particles", num_particles, dim=-2): @@ -139,18 +164,24 @@ def guide(): inference = SVI(model, guide, optim, loss=elbo) inference.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) - actual_grads = {name: param.grad.detach().cpu().numpy() / num_particles - for name, param in params.items()} + actual_grads = { + name: param.grad.detach().cpu().numpy() / num_particles + for name, param in params.items() + } - expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} + expected_grads = {"loc": np.array([0.5, -2.0]), "scale": np.array([2.0])} for name in sorted(params): - logger.info('expected {} = {}'.format(name, expected_grads[name])) - logger.info('actual {} = {}'.format(name, actual_grads[name])) + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=precision) -@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +@pytest.mark.parametrize( + "reparameterized", [True, False], ids=["reparam", "nonreparam"] +) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO] +) def test_plate_elbo_vectorized_particles(Elbo, reparameterized): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -171,7 +202,7 @@ def model(): def guide(): loc = pyro.param("loc", torch.zeros(len(data))) - scale = pyro.param("scale", torch.tensor([1.])) + scale = pyro.param("scale", torch.tensor([1.0])) pyro.sample("nuisance_c", Normal(4, 5)) with pyro.plate("data", len(data)): @@ -180,38 +211,54 @@ def guide(): pyro.sample("nuisance_a", Normal(0, 1)) optim = Adam({"lr": 0.1}) - loss = Elbo(num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=False) + loss = Elbo( + num_particles=num_particles, + vectorize_particles=True, + 
strict_enumeration_warning=False, + ) inference = SVI(model, guide, optim, loss=loss) inference.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) - actual_grads = {name: param.grad.detach().cpu().numpy() - for name, param in params.items()} + actual_grads = { + name: param.grad.detach().cpu().numpy() for name, param in params.items() + } - expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} + expected_grads = {"loc": np.array([0.5, -2.0]), "scale": np.array([2.0])} for name in sorted(params): - logger.info('expected {} = {}'.format(name, expected_grads[name])) - logger.info('actual {} = {}'.format(name, actual_grads[name])) + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=precision) -@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) +@pytest.mark.parametrize( + "reparameterized", [True, False], ids=["reparam", "nonreparam"] +) @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - TraceMeanField_ELBO, - xfail_param(JitTrace_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), - xfail_param(JitTraceGraph_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), - xfail_param(JitTraceEnum_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), - xfail_param(JitTraceMeanField_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + TraceMeanField_ELBO, + xfail_param( + JitTrace_ELBO, + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor", + ), + xfail_param( + JitTraceGraph_ELBO, + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor", + ), + xfail_param( + JitTraceEnum_ELBO, + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor", + ), + xfail_param( + JitTraceMeanField_ELBO, + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor", + ), + ], +) def test_subsample_gradient_sequential(Elbo, reparameterized, subsample): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -241,13 +288,15 @@ def guide(): inference.loss_and_grads(model, guide) params = dict(pyro.get_param_store().named_parameters()) - actual_grads = {name: param.grad.detach().cpu().numpy() / iters - for name, param in params.items()} + actual_grads = { + name: param.grad.detach().cpu().numpy() / iters + for name, param in params.items() + } - expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} + expected_grads = {"loc": np.array([0.5, -2.0]), "scale": np.array([2.0])} for name in sorted(params): - logger.info('expected {} = {}'.format(name, expected_grads[name])) - logger.info('actual {} = {}'.format(name, actual_grads[name])) + logger.info("expected {} = {}".format(name, expected_grads[name])) + logger.info("actual {} = {}".format(name, actual_grads[name])) assert_equal(actual_grads, expected_grads, prec=precision) @@ -257,7 +306,7 @@ def test_collapse_beta_binomial(): pytest.importorskip("funsor") total_count = 10 - data = torch.tensor(3.) 
+ data = torch.tensor(3.0) def model1(): c1 = pyro.param("c1", torch.tensor(0.5), constraint=constraints.positive) @@ -269,8 +318,7 @@ def model1(): def model2(): c1 = pyro.param("c1", torch.tensor(0.5), constraint=constraints.positive) c0 = pyro.param("c0", torch.tensor(1.5), constraint=constraints.positive) - pyro.sample("obs", dist.BetaBinomial(c1, c0, total_count), - obs=data) + pyro.sample("obs", dist.BetaBinomial(c1, c0, total_count), obs=data) trace1 = poutine.trace(model1).get_trace() trace2 = poutine.trace(model2).get_trace() diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index c87a19a0b8..a598ea9c54 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -54,23 +54,20 @@ def param_abs_error(name, target): @pytest.mark.stage("integration", "integration_batch_1") class NormalNormalTests(TestCase): - def setUp(self): # normal-normal; known covariance - self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior - self.loc0 = torch.tensor([0.0, 0.5]) # prior mean + self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior + self.loc0 = torch.tensor([0.0, 0.5]) # prior mean # known precision of observation noise self.lam = torch.tensor([6.0, 4.0]) - self.data = torch.tensor([[-0.1, 0.3], - [0.00, 0.4], - [0.20, 0.5], - [0.10, 0.7]]) + self.data = torch.tensor([[-0.1, 0.3], [0.00, 0.4], [0.20, 0.5], [0.10, 0.7]]) self.n_data = torch.tensor([float(len(self.data))]) self.data_sum = self.data.sum(0) self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n) - self.analytic_loc_n = self.data_sum * (self.lam / self.analytic_lam_n) +\ - self.loc0 * (self.lam0 / self.analytic_lam_n) + self.analytic_loc_n = self.data_sum * ( + self.lam / self.analytic_lam_n + ) + self.loc0 * (self.lam0 / self.analytic_lam_n) self.batch_size = 4 self.sample_batch_size = 2 @@ -81,16 +78,24 @@ def test_elbo_analytic_kl(self): self.do_elbo_test(True, 3000, TraceMeanField_ELBO()) def test_elbo_tail_adaptive(self): - self.do_elbo_test(True, 3000, TraceTailAdaptive_ELBO(num_particles=10, vectorize_particles=True)) + self.do_elbo_test( + True, + 3000, + TraceTailAdaptive_ELBO(num_particles=10, vectorize_particles=True), + ) def test_elbo_nonreparameterized(self): self.do_elbo_test(False, 15000, Trace_ELBO()) def test_renyi_reparameterized(self): - self.do_elbo_test(True, 2500, RenyiELBO(num_particles=3, vectorize_particles=False)) + self.do_elbo_test( + True, 2500, RenyiELBO(num_particles=3, vectorize_particles=False) + ) def test_renyi_nonreparameterized(self): - self.do_elbo_test(False, 7500, RenyiELBO(num_particles=3, vectorize_particles=True)) + self.do_elbo_test( + False, 7500, RenyiELBO(num_particles=3, vectorize_particles=True) + ) def test_rws_reparameterized(self): self.do_elbo_test(True, 2500, ReweightedWakeSleep(num_particles=3)) @@ -101,46 +106,60 @@ def test_rws_nonreparameterized(self): def test_mmd_vectorized(self): z_size = self.loc0.shape[0] self.do_fit_prior_test( - True, 1000, Trace_MMD( + True, + 1000, + Trace_MMD( kernel=kernels.RBF( z_size, - lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)) - ), vectorize_particles=True, num_particles=100 - ) + lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)), + ), + vectorize_particles=True, + num_particles=100, + ), ) def test_mmd_nonvectorized(self): z_size = self.loc0.shape[0] self.do_fit_prior_test( - True, 100, Trace_MMD( + True, + 100, + Trace_MMD( kernel=kernels.RBF( z_size, - 
lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)) - ), vectorize_particles=False, num_particles=100 - ), lr=0.0146 + lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)), + ), + vectorize_particles=False, + num_particles=100, + ), + lr=0.0146, ) def do_elbo_test(self, reparameterized, n_steps, loss): pyro.clear_param_store() def model(): - loc_latent = pyro.sample("loc_latent", - dist.Normal(self.loc0, torch.pow(self.lam0, -0.5)) - .to_event(1)) - with pyro.plate('data', self.batch_size): - pyro.sample("obs", - dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), - obs=self.data) + loc_latent = pyro.sample( + "loc_latent", + dist.Normal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1), + ) + with pyro.plate("data", self.batch_size): + pyro.sample( + "obs", + dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), + obs=self.data, + ) return loc_latent def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.detach() + 0.134) - log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.data.detach() - 0.14) + log_sig_q = pyro.param( + "log_sig_q", self.analytic_log_sig_n.data.detach() - 0.14 + ) sig_q = torch.exp(log_sig_q) Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal pyro.sample("loc_latent", Normal(loc_q, sig_q).to_event(1)) - adam = optim.Adam({"lr": .001}) + adam = optim.Adam({"lr": 0.001}) svi = SVI(model, guide, adam, loss=loss) for k in range(n_steps): @@ -156,25 +175,31 @@ def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.00 pyro.clear_param_store() def model(): - with pyro.plate('samples', self.sample_batch_size): + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "loc_latent", dist.Normal( - torch.stack([self.loc0]*self.sample_batch_size, dim=0), - torch.stack([torch.pow(self.lam0, -0.5)]*self.sample_batch_size, dim=0) - ).to_event(1) + "loc_latent", + dist.Normal( + torch.stack([self.loc0] * self.sample_batch_size, dim=0), + torch.stack( + [torch.pow(self.lam0, -0.5)] * self.sample_batch_size, dim=0 + ), + ).to_event(1), ) def guide(): loc_q = pyro.param("loc_q", self.loc0.detach() + 0.134) - log_sig_q = pyro.param("log_sig_q", -0.5*torch.log(self.lam0).data.detach() - 0.14) + log_sig_q = pyro.param( + "log_sig_q", -0.5 * torch.log(self.lam0).data.detach() - 0.14 + ) sig_q = torch.exp(log_sig_q) Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - with pyro.plate('samples', self.sample_batch_size): + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "loc_latent", Normal( - torch.stack([loc_q]*self.sample_batch_size, dim=0), - torch.stack([sig_q]*self.sample_batch_size, dim=0) - ).to_event(1) + "loc_latent", + Normal( + torch.stack([loc_q] * self.sample_batch_size, dim=0), + torch.stack([sig_q] * self.sample_batch_size, dim=0), + ).to_event(1), ) adam = optim.Adam({"lr": lr}) @@ -185,15 +210,24 @@ def guide(): svi.step() if debug: loc_error = param_mse("loc_q", self.loc0) - log_sig_error = param_mse("log_sig_q", -0.5*torch.log(self.lam0)) + log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0)) with torch.no_grad(): if k == 0: - avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(model, guide) + ( + avg_loglikelihood, + avg_penalty, + ) = loss._differentiable_loss_parts(model, guide) avg_loglikelihood = torch_item(avg_loglikelihood) avg_penalty = torch_item(avg_penalty) - loglikelihood, penalty = loss._differentiable_loss_parts(model, guide) - avg_loglikelihood = alpha * avg_loglikelihood 
+ (1-alpha) * torch_item(loglikelihood) - avg_penalty = alpha * avg_penalty + (1-alpha) * torch_item(penalty) + loglikelihood, penalty = loss._differentiable_loss_parts( + model, guide + ) + avg_loglikelihood = alpha * avg_loglikelihood + ( + 1 - alpha + ) * torch_item(loglikelihood) + avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item( + penalty + ) if k % 100 == 0: print(loc_error, log_sig_error) print(avg_loglikelihood, avg_penalty) @@ -217,29 +251,25 @@ def do_test_fixedness(self, fixed_parts): pyro.clear_param_store() def model(): - alpha_p_log = pyro.param( - "alpha_p_log", self.alpha_p_log_0.clone()) - beta_p_log = pyro.param( - "beta_p_log", self.beta_p_log_0.clone()) + alpha_p_log = pyro.param("alpha_p_log", self.alpha_p_log_0.clone()) + beta_p_log = pyro.param("beta_p_log", self.beta_p_log_0.clone()) alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log) lambda_latent = pyro.sample("lambda_latent", dist.Gamma(alpha_p, beta_p)) pyro.sample("obs", dist.Poisson(lambda_latent), obs=self.data) return lambda_latent def guide(): - alpha_q_log = pyro.param( - "alpha_q_log", self.alpha_q_log_0.clone()) - beta_q_log = pyro.param( - "beta_q_log", self.beta_q_log_0.clone()) + alpha_q_log = pyro.param("alpha_q_log", self.alpha_q_log_0.clone()) + beta_q_log = pyro.param("beta_q_log", self.beta_q_log_0.clone()) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", dist.Gamma(alpha_q, beta_q)) def per_param_args(param_name): - if 'model' in fixed_parts and 'p_' in param_name: - return {'lr': 0.0} - if 'guide' in fixed_parts and 'q_' in param_name: - return {'lr': 0.0} - return {'lr': 0.01} + if "model" in fixed_parts and "p_" in param_name: + return {"lr": 0.0} + if "guide" in fixed_parts and "q_" in param_name: + return {"lr": 0.0} + return {"lr": 0.01} adam = optim.Adam(per_param_args) svi = SVI(model, guide, adam, loss=Trace_ELBO()) @@ -247,14 +277,18 @@ def per_param_args(param_name): for _ in range(3): svi.step() - model_unchanged = (torch.equal(pyro.param("alpha_p_log").data, self.alpha_p_log_0)) and\ - (torch.equal(pyro.param("beta_p_log").data, self.beta_p_log_0)) - guide_unchanged = (torch.equal(pyro.param("alpha_q_log").data, self.alpha_q_log_0)) and\ - (torch.equal(pyro.param("beta_q_log").data, self.beta_q_log_0)) + model_unchanged = ( + torch.equal(pyro.param("alpha_p_log").data, self.alpha_p_log_0) + ) and (torch.equal(pyro.param("beta_p_log").data, self.beta_p_log_0)) + guide_unchanged = ( + torch.equal(pyro.param("alpha_q_log").data, self.alpha_q_log_0) + ) and (torch.equal(pyro.param("beta_q_log").data, self.beta_q_log_0)) model_changed = not model_unchanged guide_changed = not guide_unchanged - error = ('model' in fixed_parts and model_changed) or ('guide' in fixed_parts and guide_changed) - return (not error) + error = ("model" in fixed_parts and model_changed) or ( + "guide" in fixed_parts and guide_changed + ) + return not error def test_model_fixed(self): assert self.do_test_fixedness(fixed_parts=["model"]) @@ -305,12 +339,18 @@ def test_rws_nonreparameterized(self): def test_mmd_vectorized(self): z_size = 1 self.do_fit_prior_test( - True, 500, Trace_MMD( + True, + 500, + Trace_MMD( kernel=kernels.RBF( z_size, - lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)) - ), vectorize_particles=True, num_particles=100 - ), debug=True, lr=0.09 + lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)), + ), + vectorize_particles=True, + num_particles=100, + ), + debug=True, + lr=0.09, ) def 
do_elbo_test(self, reparameterized, n_steps, loss): @@ -324,47 +364,79 @@ def model(): return lambda_latent def guide(): - alpha_q = pyro.param("alpha_q", self.alpha_n.detach() + math.exp(0.17), - constraint=constraints.positive) - beta_q = pyro.param("beta_q", self.beta_n.detach() / math.exp(0.143), - constraint=constraints.positive) + alpha_q = pyro.param( + "alpha_q", + self.alpha_n.detach() + math.exp(0.17), + constraint=constraints.positive, + ) + beta_q = pyro.param( + "beta_q", + self.beta_n.detach() / math.exp(0.143), + constraint=constraints.positive, + ) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q)) - adam = optim.Adam({"lr": .0002, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.0002, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss) for k in range(n_steps): svi.step() - assert_equal(pyro.param("alpha_q"), self.alpha_n, prec=0.2, msg='{} vs {}'.format( - pyro.param("alpha_q").detach().cpu().numpy(), self.alpha_n.detach().cpu().numpy())) - assert_equal(pyro.param("beta_q"), self.beta_n, prec=0.15, msg='{} vs {}'.format( - pyro.param("beta_q").detach().cpu().numpy(), self.beta_n.detach().cpu().numpy())) + assert_equal( + pyro.param("alpha_q"), + self.alpha_n, + prec=0.2, + msg="{} vs {}".format( + pyro.param("alpha_q").detach().cpu().numpy(), + self.alpha_n.detach().cpu().numpy(), + ), + ) + assert_equal( + pyro.param("beta_q"), + self.beta_n, + prec=0.15, + msg="{} vs {}".format( + pyro.param("beta_q").detach().cpu().numpy(), + self.beta_n.detach().cpu().numpy(), + ), + ) def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.0002): pyro.clear_param_store() Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma def model(): - with pyro.plate('samples', self.sample_batch_size): + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "lambda_latent", Gamma( - torch.stack([torch.stack([self.alpha0])]*self.sample_batch_size), - torch.stack([torch.stack([self.beta0])]*self.sample_batch_size) - ).to_event(1) + "lambda_latent", + Gamma( + torch.stack( + [torch.stack([self.alpha0])] * self.sample_batch_size + ), + torch.stack( + [torch.stack([self.beta0])] * self.sample_batch_size + ), + ).to_event(1), ) def guide(): - alpha_q = pyro.param("alpha_q", self.alpha0.detach() + math.exp(0.17), - constraint=constraints.positive) - beta_q = pyro.param("beta_q", self.beta0.detach() / math.exp(0.143), - constraint=constraints.positive) - with pyro.plate('samples', self.sample_batch_size): + alpha_q = pyro.param( + "alpha_q", + self.alpha0.detach() + math.exp(0.17), + constraint=constraints.positive, + ) + beta_q = pyro.param( + "beta_q", + self.beta0.detach() / math.exp(0.143), + constraint=constraints.positive, + ) + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "lambda_latent", Gamma( - torch.stack([torch.stack([alpha_q])]*self.sample_batch_size), - torch.stack([torch.stack([beta_q])]*self.sample_batch_size) - ).to_event(1) + "lambda_latent", + Gamma( + torch.stack([torch.stack([alpha_q])] * self.sample_batch_size), + torch.stack([torch.stack([beta_q])] * self.sample_batch_size), + ).to_event(1), ) adam = optim.Adam({"lr": lr, "betas": (0.97, 0.999)}) @@ -378,39 +450,69 @@ def guide(): beta_error = param_mse("beta_q", self.beta0) with torch.no_grad(): if k == 0: - avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(model, guide, (), {}) + ( + avg_loglikelihood, + avg_penalty, + ) = loss._differentiable_loss_parts(model, guide, (), {}) avg_loglikelihood = torch_item(avg_loglikelihood) 
avg_penalty = torch_item(avg_penalty) - loglikelihood, penalty = loss._differentiable_loss_parts(model, guide, (), {}) - avg_loglikelihood = alpha * avg_loglikelihood + (1-alpha) * torch_item(loglikelihood) - avg_penalty = alpha * avg_penalty + (1-alpha) * torch_item(penalty) + loglikelihood, penalty = loss._differentiable_loss_parts( + model, guide, (), {} + ) + avg_loglikelihood = alpha * avg_loglikelihood + ( + 1 - alpha + ) * torch_item(loglikelihood) + avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item( + penalty + ) if k % 100 == 0: print(alpha_error, beta_error) print(avg_loglikelihood, avg_penalty) print() - assert_equal(pyro.param("alpha_q"), self.alpha0, prec=0.2, msg='{} vs {}'.format( - pyro.param("alpha_q").detach().cpu().numpy(), self.alpha0.detach().cpu().numpy())) - assert_equal(pyro.param("beta_q"), self.beta0, prec=0.15, msg='{} vs {}'.format( - pyro.param("beta_q").detach().cpu().numpy(), self.beta0.detach().cpu().numpy())) + assert_equal( + pyro.param("alpha_q"), + self.alpha0, + prec=0.2, + msg="{} vs {}".format( + pyro.param("alpha_q").detach().cpu().numpy(), + self.alpha0.detach().cpu().numpy(), + ), + ) + assert_equal( + pyro.param("beta_q"), + self.beta0, + prec=0.15, + msg="{} vs {}".format( + pyro.param("beta_q").detach().cpu().numpy(), + self.beta0.detach().cpu().numpy(), + ), + ) @pytest.mark.stage("integration", "integration_batch_1") -@pytest.mark.parametrize('elbo_impl', [ - xfail_param(JitTrace_ELBO, reason="incorrect gradients", run=False), - xfail_param(JitTraceGraph_ELBO, reason="incorrect gradients", run=False), - xfail_param(JitTraceEnum_ELBO, reason="incorrect gradients", run=False), - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - RenyiELBO, - ReweightedWakeSleep -]) -@pytest.mark.parametrize('gamma_dist,n_steps', [ - (dist.Gamma, 5000), - (fakes.NonreparameterizedGamma, 10000), - (ShapeAugmentedGamma, 5000), -], ids=['reparam', 'nonreparam', 'rsvi']) +@pytest.mark.parametrize( + "elbo_impl", + [ + xfail_param(JitTrace_ELBO, reason="incorrect gradients", run=False), + xfail_param(JitTraceGraph_ELBO, reason="incorrect gradients", run=False), + xfail_param(JitTraceEnum_ELBO, reason="incorrect gradients", run=False), + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + RenyiELBO, + ReweightedWakeSleep, + ], +) +@pytest.mark.parametrize( + "gamma_dist,n_steps", + [ + (dist.Gamma, 5000), + (fakes.NonreparameterizedGamma, 10000), + (ShapeAugmentedGamma, 5000), + ], + ids=["reparam", "nonreparam", "rsvi"], +) def test_exponential_gamma(gamma_dist, n_steps, elbo_impl): pyro.clear_param_store() @@ -431,18 +533,31 @@ def model(alpha0, beta0, alpha_n, beta_n): return lambda_latent def guide(alpha0, beta0, alpha_n, beta_n): - alpha_q = pyro.param("alpha_q", alpha_n * math.exp(0.17), constraint=constraints.positive) - beta_q = pyro.param("beta_q", beta_n / math.exp(0.143), constraint=constraints.positive) + alpha_q = pyro.param( + "alpha_q", alpha_n * math.exp(0.17), constraint=constraints.positive + ) + beta_q = pyro.param( + "beta_q", beta_n / math.exp(0.143), constraint=constraints.positive + ) pyro.sample("lambda_latent", gamma_dist(alpha_q, beta_q)) - adam = optim.Adam({"lr": .0003, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.0003, "betas": (0.97, 0.999)}) if elbo_impl is RenyiELBO: - elbo = elbo_impl(alpha=0.2, num_particles=3, max_plate_nesting=1, strict_enumeration_warning=False) + elbo = elbo_impl( + alpha=0.2, + num_particles=3, + max_plate_nesting=1, + strict_enumeration_warning=False, + ) elif elbo_impl is ReweightedWakeSleep: 
if gamma_dist is ShapeAugmentedGamma: - pytest.xfail(reason="ShapeAugmentedGamma not suported for ReweightedWakeSleep") + pytest.xfail( + reason="ShapeAugmentedGamma not suported for ReweightedWakeSleep" + ) else: - elbo = elbo_impl(num_particles=3, max_plate_nesting=1, strict_enumeration_warning=False) + elbo = elbo_impl( + num_particles=3, max_plate_nesting=1, strict_enumeration_warning=False + ) else: elbo = elbo_impl(max_plate_nesting=1, strict_enumeration_warning=False) svi = SVI(model, guide, adam, loss=elbo) @@ -451,10 +566,22 @@ def guide(alpha0, beta0, alpha_n, beta_n): for k in range(n_steps): svi.step(alpha0, beta0, alpha_n, beta_n) - assert_equal(pyro.param("alpha_q"), alpha_n, prec=prec, msg='{} vs {}'.format( - pyro.param("alpha_q").detach().cpu().numpy(), alpha_n.detach().cpu().numpy())) - assert_equal(pyro.param("beta_q"), beta_n, prec=prec, msg='{} vs {}'.format( - pyro.param("beta_q").detach().cpu().numpy(), beta_n.detach().cpu().numpy())) + assert_equal( + pyro.param("alpha_q"), + alpha_n, + prec=prec, + msg="{} vs {}".format( + pyro.param("alpha_q").detach().cpu().numpy(), alpha_n.detach().cpu().numpy() + ), + ) + assert_equal( + pyro.param("beta_q"), + beta_n, + prec=prec, + msg="{} vs {}".format( + pyro.param("beta_q").detach().cpu().numpy(), beta_n.detach().cpu().numpy() + ), + ) @pytest.mark.stage("integration", "integration_batch_2") @@ -483,13 +610,19 @@ def test_elbo_nonreparameterized(self): # this is used to detect bugs related to https://github.com/pytorch/pytorch/issues/9521 def test_elbo_reparameterized_vectorized(self): - self.do_elbo_test(True, 5000, Trace_ELBO(num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + True, + 5000, + Trace_ELBO(num_particles=2, vectorize_particles=True, max_plate_nesting=1), + ) # this is used to detect bugs related to https://github.com/pytorch/pytorch/issues/9521 def test_elbo_nonreparameterized_vectorized(self): - self.do_elbo_test(False, 5000, Trace_ELBO(num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + False, + 5000, + Trace_ELBO(num_particles=2, vectorize_particles=True, max_plate_nesting=1), + ) def test_renyi_reparameterized(self): self.do_elbo_test(True, 5000, RenyiELBO(num_particles=2)) @@ -498,12 +631,23 @@ def test_renyi_nonreparameterized(self): self.do_elbo_test(False, 5000, RenyiELBO(alpha=0.2, num_particles=2)) def test_renyi_reparameterized_vectorized(self): - self.do_elbo_test(True, 5000, RenyiELBO(num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + True, + 5000, + RenyiELBO(num_particles=2, vectorize_particles=True, max_plate_nesting=1), + ) def test_renyi_nonreparameterized_vectorized(self): - self.do_elbo_test(False, 5000, RenyiELBO(alpha=0.2, num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + False, + 5000, + RenyiELBO( + alpha=0.2, + num_particles=2, + vectorize_particles=True, + max_plate_nesting=1, + ), + ) def test_rws_reparameterized(self): self.do_elbo_test(True, 5000, ReweightedWakeSleep(num_particles=2)) @@ -512,22 +656,36 @@ def test_rws_nonreparameterized(self): self.do_elbo_test(False, 5000, ReweightedWakeSleep(num_particles=2)) def test_rws_reparameterized_vectorized(self): - self.do_elbo_test(True, 5000, ReweightedWakeSleep(num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + True, + 5000, + ReweightedWakeSleep( + num_particles=2, vectorize_particles=True, max_plate_nesting=1 + ), + ) def 
test_rws_nonreparameterized_vectorized(self): - self.do_elbo_test(False, 5000, ReweightedWakeSleep(num_particles=2, vectorize_particles=True, - max_plate_nesting=1)) + self.do_elbo_test( + False, + 5000, + ReweightedWakeSleep( + num_particles=2, vectorize_particles=True, max_plate_nesting=1 + ), + ) def test_mmd_vectorized(self): z_size = 1 self.do_fit_prior_test( - True, 2500, Trace_MMD( + True, + 2500, + Trace_MMD( kernel=kernels.RBF( z_size, - lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)) - ), vectorize_particles=True, num_particles=100 - ) + lengthscale=torch.sqrt(torch.tensor(z_size, dtype=torch.float)), + ), + vectorize_particles=True, + num_particles=100, + ), ) def do_elbo_test(self, reparameterized, n_steps, loss): @@ -541,14 +699,12 @@ def model(): return p_latent def guide(): - alpha_q_log = pyro.param("alpha_q_log", - self.log_alpha_n + 0.17) - beta_q_log = pyro.param("beta_q_log", - self.log_beta_n - 0.143) + alpha_q_log = pyro.param("alpha_q_log", self.log_alpha_n + 0.17) + beta_q_log = pyro.param("beta_q_log", self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("p_latent", Beta(alpha_q, beta_q)) - adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.001, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=loss) for k in range(n_steps): @@ -564,29 +720,33 @@ def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False): Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta def model(): - with pyro.plate('samples', self.sample_batch_size): + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "p_latent", Beta( - torch.stack([torch.stack([self.alpha0])]*self.sample_batch_size), - torch.stack([torch.stack([self.beta0])]*self.sample_batch_size) - ).to_event(1) + "p_latent", + Beta( + torch.stack( + [torch.stack([self.alpha0])] * self.sample_batch_size + ), + torch.stack( + [torch.stack([self.beta0])] * self.sample_batch_size + ), + ).to_event(1), ) def guide(): - alpha_q_log = pyro.param("alpha_q_log", - torch.log(self.alpha0) + 0.17) - beta_q_log = pyro.param("beta_q_log", - torch.log(self.beta0) - 0.143) + alpha_q_log = pyro.param("alpha_q_log", torch.log(self.alpha0) + 0.17) + beta_q_log = pyro.param("beta_q_log", torch.log(self.beta0) - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) - with pyro.plate('samples', self.sample_batch_size): + with pyro.plate("samples", self.sample_batch_size): pyro.sample( - "p_latent", Beta( - torch.stack([torch.stack([alpha_q])]*self.sample_batch_size), - torch.stack([torch.stack([beta_q])]*self.sample_batch_size) - ).to_event(1) + "p_latent", + Beta( + torch.stack([torch.stack([alpha_q])] * self.sample_batch_size), + torch.stack([torch.stack([beta_q])] * self.sample_batch_size), + ).to_event(1), ) - adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.001, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=loss) alpha = 0.99 @@ -597,12 +757,21 @@ def guide(): beta_error = param_abs_error("beta_q_log", torch.log(self.beta0)) with torch.no_grad(): if k == 0: - avg_loglikelihood, avg_penalty = loss._differentiable_loss_parts(model, guide) + ( + avg_loglikelihood, + avg_penalty, + ) = loss._differentiable_loss_parts(model, guide) avg_loglikelihood = torch_item(avg_loglikelihood) avg_penalty = torch_item(avg_penalty) - loglikelihood, penalty = loss._differentiable_loss_parts(model, guide) - avg_loglikelihood = alpha * avg_loglikelihood + 
(1-alpha) * torch_item(loglikelihood) - avg_penalty = alpha * avg_penalty + (1-alpha) * torch_item(penalty) + loglikelihood, penalty = loss._differentiable_loss_parts( + model, guide + ) + avg_loglikelihood = alpha * avg_loglikelihood + ( + 1 - alpha + ) * torch_item(loglikelihood) + avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item( + penalty + ) if k % 100 == 0: print(alpha_error, beta_error) print(avg_loglikelihood, avg_penalty) @@ -615,7 +784,6 @@ def guide(): class SafetyTests(TestCase): - def setUp(self): # normal-normal; known covariance def model_dup(): @@ -624,7 +792,9 @@ def model_dup(): def model_obs_dup(): pyro.sample("loc_q", dist.Normal(torch.zeros(1), torch.ones(1))) - pyro.sample("loc_q", dist.Normal(torch.zeros(1), torch.ones(1)), obs=torch.zeros(1)) + pyro.sample( + "loc_q", dist.Normal(torch.zeros(1), torch.ones(1)), obs=torch.zeros(1) + ) def model(): pyro.sample("loc_q", dist.Normal(torch.zeros(1), torch.ones(1))) @@ -642,7 +812,7 @@ def guide(): def test_duplicate_names(self): pyro.clear_param_store() - adam = optim.Adam({"lr": .001}) + adam = optim.Adam({"lr": 0.001}) svi = SVI(self.duplicate_model, self.guide, adam, loss=Trace_ELBO()) with pytest.raises(RuntimeError): @@ -651,7 +821,7 @@ def test_duplicate_names(self): def test_extra_samples(self): pyro.clear_param_store() - adam = optim.Adam({"lr": .001}) + adam = optim.Adam({"lr": 0.001}) svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO()) with pytest.warns(Warning): @@ -660,7 +830,7 @@ def test_extra_samples(self): def test_duplicate_obs_name(self): pyro.clear_param_store() - adam = optim.Adam({"lr": .001}) + adam = optim.Adam({"lr": 0.001}) svi = SVI(self.duplicate_obs, self.guide, adam, loss=Trace_ELBO()) with pytest.raises(RuntimeError): @@ -670,7 +840,6 @@ def test_duplicate_obs_name(self): @pytest.mark.stage("integration", "integration_batch_1") @pytest.mark.parametrize("prior_scale", [0, 1e-4]) def test_energy_distance_univariate(prior_scale): - def model(data): loc = pyro.sample("loc", dist.Normal(0, 100)) scale = pyro.sample("scale", dist.LogNormal(0, 1)) @@ -678,12 +847,14 @@ def model(data): pyro.sample("obs", dist.Normal(loc, scale), obs=data) def guide(data): - loc_loc = pyro.param("loc_loc", torch.tensor(0.)) - loc_scale = pyro.param("loc_scale", torch.tensor(1.), - constraint=constraints.positive) - log_scale_loc = pyro.param("log_scale_loc", torch.tensor(0.)) - log_scale_scale = pyro.param("log_scale_scale", torch.tensor(1.), - constraint=constraints.positive) + loc_loc = pyro.param("loc_loc", torch.tensor(0.0)) + loc_scale = pyro.param( + "loc_scale", torch.tensor(1.0), constraint=constraints.positive + ) + log_scale_loc = pyro.param("log_scale_loc", torch.tensor(0.0)) + log_scale_scale = pyro.param( + "log_scale_scale", torch.tensor(1.0), constraint=constraints.positive + ) pyro.sample("loc", dist.Normal(loc_loc, loc_scale)) pyro.sample("scale", dist.LogNormal(log_scale_loc, log_scale_scale)) @@ -694,9 +865,14 @@ def guide(data): for step in range(2001): loss = svi.step(data) if step % 20 == 0: - logger.info("step {} loss = {:0.4g}, loc = {:0.4g}, scale = {:0.4g}" - .format(step, loss, pyro.param("loc_loc").item(), - pyro.param("log_scale_loc").exp().item())) + logger.info( + "step {} loss = {:0.4g}, loc = {:0.4g}, scale = {:0.4g}".format( + step, + loss, + pyro.param("loc_loc").item(), + pyro.param("log_scale_loc").exp().item(), + ) + ) expected_loc = data.mean() expected_scale = data.std() @@ -709,7 +885,6 @@ def guide(data): @pytest.mark.stage("integration", 
"integration_batch_1") @pytest.mark.parametrize("prior_scale", [0, 1]) def test_energy_distance_multivariate(prior_scale): - def model(data): loc = torch.zeros(2) cov = pyro.sample("cov", dist.Normal(0, 100).expand([2, 2]).to_event(2)) @@ -717,8 +892,9 @@ def model(data): pyro.sample("obs", dist.MultivariateNormal(loc, cov), obs=data) def guide(data): - scale_tril = pyro.param("scale_tril", torch.eye(2), - constraint=constraints.lower_cholesky) + scale_tril = pyro.param( + "scale_tril", torch.eye(2), constraint=constraints.lower_cholesky + ) pyro.sample("cov", dist.Delta(scale_tril @ scale_tril.t(), event_dim=2)) cov = torch.tensor([[1, 0.8], [0.8, 1]]) @@ -743,9 +919,9 @@ def test_reparam_stable(): @poutine.reparam(config={"dz": LatentStableReparam(), "y": LatentStableReparam()}) def model(): - stability = pyro.sample("stability", dist.Uniform(1., 2.)) - trans_skew = pyro.sample("trans_skew", dist.Uniform(-1., 1.)) - obs_skew = pyro.sample("obs_skew", dist.Uniform(-1., 1.)) + stability = pyro.sample("stability", dist.Uniform(1.0, 2.0)) + trans_skew = pyro.sample("trans_skew", dist.Uniform(-1.0, 1.0)) + obs_skew = pyro.sample("obs_skew", dist.Uniform(-1.0, 1.0)) scale = pyro.sample("scale", dist.Gamma(3, 1)) # We use separate plates because the .cumsum() op breaks independence. @@ -769,17 +945,19 @@ def test_sequential_plating_sum(): """Example from https://github.com/pyro-ppl/pyro/issues/2361""" def model(data): - x = pyro.sample('x', dist.Bernoulli(torch.tensor(0.5))) - for i in pyro.plate('data_plate', len(data)): - pyro.sample('data_{:d}'.format(i), - dist.Normal(x, scale=torch.tensor(0.1)), - obs=data[i]) + x = pyro.sample("x", dist.Bernoulli(torch.tensor(0.5))) + for i in pyro.plate("data_plate", len(data)): + pyro.sample( + "data_{:d}".format(i), + dist.Normal(x, scale=torch.tensor(0.1)), + obs=data[i], + ) def guide(data): - p = pyro.param('p', torch.tensor(0.5)) - pyro.sample('x', pyro.distributions.Bernoulli(p)) + p = pyro.param("p", torch.tensor(0.5)) + pyro.sample("x", pyro.distributions.Bernoulli(p)) - data = torch.cat([torch.randn([5]), 1. 
+ torch.randn([5])]) + data = torch.cat([torch.randn([5]), 1.0 + torch.randn([5])]) adam = optim.Adam({"lr": 0.01}) loss_fn = RenyiELBO(alpha=0, num_particles=30, vectorize_particles=True) svi = SVI(model, guide, adam, loss_fn) @@ -800,20 +978,22 @@ def model(data, weights): scale = torch.tensor(0.1) # Sample latents (shares no dimensions with data) - with pyro.plate('x_plate', weights.shape[0]): - x = pyro.sample('x', pyro.distributions.Normal(loc, scale)) + with pyro.plate("x_plate", weights.shape[0]): + x = pyro.sample("x", pyro.distributions.Normal(loc, scale)) # Combine with weights and sample - with pyro.plate('data_plate_1', data.shape[-1]): - with pyro.plate('data_plate_2', data.shape[-2]): - pyro.sample('data', pyro.distributions.Normal(x @ weights, scale), obs=data) + with pyro.plate("data_plate_1", data.shape[-1]): + with pyro.plate("data_plate_2", data.shape[-2]): + pyro.sample( + "data", pyro.distributions.Normal(x @ weights, scale), obs=data + ) def guide(data, weights): - loc = pyro.param('x_loc', torch.tensor(0.5)) + loc = pyro.param("x_loc", torch.tensor(0.5)) scale = torch.tensor(0.1) - with pyro.plate('x_plate', weights.shape[0]): - pyro.sample('x', pyro.distributions.Normal(loc, scale)) + with pyro.plate("x_plate", weights.shape[0]): + pyro.sample("x", pyro.distributions.Normal(loc, scale)) data = torch.randn([5, 3]) weights = torch.randn([2, 3]) diff --git a/tests/infer/test_initialization.py b/tests/infer/test_initialization.py index 1590814a9a..323a6225dc 100644 --- a/tests/infer/test_initialization.py +++ b/tests/infer/test_initialization.py @@ -24,8 +24,10 @@ def __init__(self): self.counter = 0 def __call__(self): - values = {"x": torch.tensor(self.counter + 0.0), - "y": torch.tensor(self.counter + 0.5)} + values = { + "x": torch.tensor(self.counter + 0.0), + "y": torch.tensor(self.counter + 0.5), + } self.counter += 1 return init_to_value(values=values) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 5b4e444c44..b459ef6a74 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -45,104 +45,107 @@ def test_simple(): y = torch.ones(2) def f(x): - logger.debug('Inside f') + logger.debug("Inside f") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) assert x is y return y + 1.0 - logger.debug('Compiling f') + logger.debug("Compiling f") f = torch.jit.trace(f, (y,), check_trace=False) - logger.debug('Calling f(y)') - assert_equal(f(y), torch.tensor([2., 2.])) - logger.debug('Calling f(y)') - assert_equal(f(y), torch.tensor([2., 2.])) - logger.debug('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2)), torch.tensor([1., 1.])) - logger.debug('Calling f(torch.zeros(5))') - assert_equal(f(torch.ones(5)), torch.tensor([2., 2., 2., 2., 2.])) + logger.debug("Calling f(y)") + assert_equal(f(y), torch.tensor([2.0, 2.0])) + logger.debug("Calling f(y)") + assert_equal(f(y), torch.tensor([2.0, 2.0])) + logger.debug("Calling f(torch.zeros(2))") + assert_equal(f(torch.zeros(2)), torch.tensor([1.0, 1.0])) + logger.debug("Calling f(torch.zeros(5))") + assert_equal(f(torch.ones(5)), torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])) def test_multi_output(): y = torch.ones(2) def f(x): - logger.debug('Inside f') + logger.debug("Inside f") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) assert x is y return y - 1.0, y + 1.0 - logger.debug('Compiling f') + logger.debug("Compiling f") f = torch.jit.trace(f, (y,), check_trace=False) - logger.debug('Calling 
f(y)') - assert_equal(f(y)[1], torch.tensor([2., 2.])) - logger.debug('Calling f(y)') - assert_equal(f(y)[1], torch.tensor([2., 2.])) - logger.debug('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2))[1], torch.tensor([1., 1.])) - logger.debug('Calling f(torch.zeros(5))') - assert_equal(f(torch.ones(5))[1], torch.tensor([2., 2., 2., 2., 2.])) + logger.debug("Calling f(y)") + assert_equal(f(y)[1], torch.tensor([2.0, 2.0])) + logger.debug("Calling f(y)") + assert_equal(f(y)[1], torch.tensor([2.0, 2.0])) + logger.debug("Calling f(torch.zeros(2))") + assert_equal(f(torch.zeros(2))[1], torch.tensor([1.0, 1.0])) + logger.debug("Calling f(torch.zeros(5))") + assert_equal(f(torch.ones(5))[1], torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])) def test_backward(): y = torch.ones(2, requires_grad=True) def f(x): - logger.debug('Inside f') + logger.debug("Inside f") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) assert x is y return (y + 1.0).sum() - logger.debug('Compiling f') + logger.debug("Compiling f") f = torch.jit.trace(f, (y,), check_trace=False) - logger.debug('Calling f(y)') + logger.debug("Calling f(y)") f(y).backward() - logger.debug('Calling f(y)') + logger.debug("Calling f(y)") f(y) - logger.debug('Calling f(torch.zeros(2))') + logger.debug("Calling f(torch.zeros(2))") f(torch.zeros(2, requires_grad=True)) - logger.debug('Calling f(torch.zeros(5))') + logger.debug("Calling f(torch.zeros(5))") f(torch.ones(5, requires_grad=True)) @pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad(): - def f(x, y): - logger.debug('Inside f') + logger.debug("Inside f") loss = (x - y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) - logger.debug('Compiling f') - f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True))) - logger.debug('Invoking f') + logger.debug("Compiling f") + f = torch.jit.trace( + f, (torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) + ) + logger.debug("Invoking f") f(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) - logger.debug('Invoking f') + logger.debug("Invoking f") f(torch.zeros(2, requires_grad=True), torch.zeros(2, requires_grad=True)) @pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad_expand(): - def f(x, y): - logger.debug('Inside f') + logger.debug("Inside f") loss = (x - y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) - logger.debug('Compiling f') - f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True))) - logger.debug('Invoking f') + logger.debug("Compiling f") + f = torch.jit.trace( + f, (torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) + ) + logger.debug("Invoking f") f(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) - logger.debug('Invoking f') + logger.debug("Invoking f") f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) def test_scale_and_mask(): - def f(tensor, scale, mask): return scale_and_mask(tensor, scale=scale, mask=mask) + def f(tensor, scale, mask): + return scale_and_mask(tensor, scale=scale, mask=mask) - x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + x = torch.tensor([-float("inf"), -1.0, 0.0, 1.0, float("inf")]) y = x / x.unsqueeze(-1) mask = y == y scale = torch.ones(y.shape) @@ -156,11 +159,10 @@ def f(tensor, scale, mask): return scale_and_mask(tensor, scale=scale, mask=mask def 
test_masked_fill(): - def f(y, mask): - return y.clone().masked_fill_(mask, 0.) + return y.clone().masked_fill_(mask, 0.0) - x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + x = torch.tensor([-float("inf"), -1.0, 0.0, 1.0, float("inf")]) y = x / x.unsqueeze(-1) mask = ~(y == y) jit_f = torch.jit.trace(f, (y, mask)) @@ -173,7 +175,6 @@ def f(y, mask): @pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11614") def test_scatter(): - def make_one_hot(x, i): return torch.zeros_like(x).scatter(-1, i.unsqueeze(-1), 1.0) @@ -182,9 +183,8 @@ def make_one_hot(x, i): torch.jit.trace(make_one_hot, (x, i)) -@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') +@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python integer") def test_scatter_workaround(): - def make_one_hot_expected(x, i): return torch.zeros_like(x).scatter(-1, i.unsqueeze(-1), 1.0) @@ -200,9 +200,9 @@ def make_one_hot_actual(x, i): assert_equal(actual, expected) -@pytest.mark.parametrize('expand', [False, True]) -@pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) -@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python boolean') +@pytest.mark.parametrize("expand", [False, True]) +@pytest.mark.parametrize("shape", [(), (4,), (5, 4)]) +@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean") def test_bernoulli_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.full(shape, 0.25) @@ -217,8 +217,8 @@ def f(probs): assert log_prob.shape == (2,) + shape -@pytest.mark.parametrize('expand', [False, True]) -@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +@pytest.mark.parametrize("expand", [False, True]) +@pytest.mark.parametrize("shape", [(3,), (4, 3), (5, 4, 3)]) def test_categorical_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.ones(shape) @@ -234,9 +234,9 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape -@pytest.mark.parametrize('expand', [False, True]) -@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) -@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') +@pytest.mark.parametrize("expand", [False, True]) +@pytest.mark.parametrize("shape", [(3,), (4, 3), (5, 4, 3)]) +@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python integer") def test_one_hot_categorical_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.ones(shape) @@ -252,25 +252,30 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape -@pytest.mark.parametrize('num_particles', [1, 10]) -@pytest.mark.parametrize('Elbo', [ - Trace_ELBO, - JitTrace_ELBO, - TraceGraph_ELBO, - JitTraceGraph_ELBO, - TraceEnum_ELBO, - JitTraceEnum_ELBO, - TraceMeanField_ELBO, - JitTraceMeanField_ELBO, -]) +@pytest.mark.parametrize("num_particles", [1, 10]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + JitTrace_ELBO, + TraceGraph_ELBO, + JitTraceGraph_ELBO, + TraceEnum_ELBO, + JitTraceEnum_ELBO, + TraceMeanField_ELBO, + JitTraceMeanField_ELBO, + ], +) def test_svi(Elbo, num_particles): pyro.clear_param_store() - data = torch.arange(10.) 
+ data = torch.arange(10.0) def model(data): loc = pyro.param("loc", constant(0.0)) scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) - pyro.sample("x", dist.Normal(loc, scale).expand_by(data.shape).to_event(1), obs=data) + pyro.sample( + "x", dist.Normal(loc, scale).expand_by(data.shape).to_event(1), obs=data + ) def guide(data): pass @@ -299,7 +304,9 @@ def guide(): q = pyro.param("q") pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1}) for i in pyro.plate("plate", plate_dim): - pyro.sample("y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2}) + pyro.sample( + "y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2} + ) kl = (1 + plate_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) expected_loss = kl.item() @@ -307,26 +314,44 @@ def guide(): inner_particles = 2 outer_particles = num_particles // inner_particles - elbo = TraceEnum_ELBO(max_plate_nesting=0, - strict_enumeration_warning=any([enumerate1, enumerate2]), - num_particles=inner_particles, - ignore_jit_warnings=True) - actual_loss = sum(elbo.loss_and_grads(model, guide) - for i in range(outer_particles)) / outer_particles + elbo = TraceEnum_ELBO( + max_plate_nesting=0, + strict_enumeration_warning=any([enumerate1, enumerate2]), + num_particles=inner_particles, + ignore_jit_warnings=True, + ) + actual_loss = ( + sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles)) + / outer_particles + ) actual_grad = q.unconstrained().grad / outer_particles - assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) - assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) - - -@pytest.mark.parametrize('vectorized', [False, True]) -@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) + assert_equal( + actual_loss, + expected_loss, + prec=0.3, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) + assert_equal( + actual_grad, + expected_grad, + prec=0.5, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) + + +@pytest.mark.parametrize("vectorized", [False, True]) +@pytest.mark.parametrize("Elbo", [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_beta_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) @@ -342,35 +367,39 @@ def model2(data): alpha0 = constant(10.0) beta0 = constant(10.0) f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) - pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), - obs=data) + pyro.sample( + "obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data + ) model = model2 if vectorized else model1 def guide(data): - alpha_q = pyro.param("alpha_q", constant(15.0), - constraint=constraints.positive) - beta_q = pyro.param("beta_q", constant(15.0), - constraint=constraints.positive) + alpha_q = pyro.param("alpha_q", constant(15.0), constraint=constraints.positive) + beta_q = pyro.param("beta_q", constant(15.0), constraint=constraints.positive) pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True) + 
elbo = Elbo( + num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True + ) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): svi.step(data) -@pytest.mark.parametrize('Elbo', [ - Trace_ELBO, - JitTrace_ELBO, - TraceGraph_ELBO, - JitTraceGraph_ELBO, - TraceEnum_ELBO, - JitTraceEnum_ELBO, - TraceMeanField_ELBO, - JitTraceMeanField_ELBO, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + JitTrace_ELBO, + TraceGraph_ELBO, + JitTraceGraph_ELBO, + TraceEnum_ELBO, + JitTraceEnum_ELBO, + TraceMeanField_ELBO, + JitTraceMeanField_ELBO, + ], +) def test_svi_irregular_batch_size(Elbo): pyro.clear_param_store() @@ -379,9 +408,7 @@ def model(data): loc = pyro.param("loc", constant(0.0)) scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) with pyro.plate("data", data.shape[0]): - pyro.sample("x", - dist.Normal(loc, scale).expand([data.shape[0]]), - obs=data) + pyro.sample("x", dist.Normal(loc, scale).expand([data.shape[0]]), obs=data) def guide(data): pass @@ -393,8 +420,8 @@ def guide(data): inference.step(torch.ones(3)) -@pytest.mark.parametrize('vectorized', [False, True]) -@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) +@pytest.mark.parametrize("vectorized", [False, True]) +@pytest.mark.parametrize("Elbo", [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_dirichlet_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) @@ -408,29 +435,35 @@ def model1(data): def model2(data): concentration0 = constant([10.0, 10.0]) f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1] - pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), - obs=data) + pyro.sample( + "obs", dist.Bernoulli(f).expand_by(data.shape).to_event(1), obs=data + ) model = model2 if vectorized else model1 def guide(data): - concentration_q = pyro.param("concentration_q", constant([15.0, 15.0]), - constraint=constraints.positive) + concentration_q = pyro.param( + "concentration_q", constant([15.0, 15.0]), constraint=constraints.positive + ) pyro.sample("latent_fairness", dist.Dirichlet(concentration_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True) + elbo = Elbo( + num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True + ) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): svi.step(data) -@pytest.mark.parametrize('length', [1, 2, 10]) +@pytest.mark.parametrize("length", [1, 2, 10]) def test_traceenum_elbo(length): hidden_dim = 10 - transition = pyro.param("transition", - 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim), - constraint=constraints.positive) + transition = pyro.param( + "transition", + 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim), + constraint=constraints.positive, + ) means = pyro.param("means", torch.arange(float(hidden_dim))) data = 1 + 2 * torch.randn(length) @@ -440,19 +473,27 @@ def model(data): means = pyro.param("means") states = [torch.tensor(0)] for t in pyro.markov(range(len(data))): - states.append(pyro.sample("states_{}".format(t), - dist.Categorical(transition[states[-1]]), - infer={"enumerate": "parallel"})) - pyro.sample("obs_{}".format(t), - dist.Normal(means[states[-1]], 1.), - obs=data[t]) + states.append( + pyro.sample( + "states_{}".format(t), + dist.Categorical(transition[states[-1]]), + infer={"enumerate": "parallel"}, + ) + ) + pyro.sample( + "obs_{}".format(t), 
dist.Normal(means[states[-1]], 1.0), obs=data[t] + ) return tuple(states) def guide(data): pass - expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data) - actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data) + expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss( + model, guide, data + ) + actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss( + model, guide, data + ) assert_equal(expected_loss, actual_loss) expected_grads = grad(expected_loss, [transition, means], allow_unused=True) @@ -461,20 +502,23 @@ def guide(data): assert_equal(e, a, msg="bad gradient for {}".format(name)) -@pytest.mark.parametrize('length', [1, 2, 10]) -@pytest.mark.parametrize('temperature', [0, 1], ids=['map', 'sample']) +@pytest.mark.parametrize("length", [1, 2, 10]) +@pytest.mark.parametrize("temperature", [0, 1], ids=["map", "sample"]) def test_infer_discrete(temperature, length): - @ignore_jit_warnings() def hmm(transition, means, data): states = [torch.tensor(0)] for t in pyro.markov(range(len(data))): - states.append(pyro.sample("states_{}".format(t), - dist.Categorical(transition[states[-1]]), - infer={"enumerate": "parallel"})) - pyro.sample("obs_{}".format(t), - dist.Normal(means[states[-1]], 1.), - obs=data[t]) + states.append( + pyro.sample( + "states_{}".format(t), + dist.Categorical(transition[states[-1]]), + infer={"enumerate": "parallel"}, + ) + ) + pyro.sample( + "obs_{}".format(t), dist.Normal(means[states[-1]], 1.0), obs=data[t] + ) return tuple(states) hidden_dim = 10 @@ -494,10 +538,19 @@ def hmm(transition, means, data): assert_equal(state, jit_state) -@pytest.mark.parametrize("x,y", [ - (CondIndepStackFrame("a", -1, torch.tensor(2000), 2), CondIndepStackFrame("a", -1, 2000, 2)), - (CondIndepStackFrame("a", -1, 1, 2), CondIndepStackFrame("a", -1, torch.tensor(1), 2)), -]) +@pytest.mark.parametrize( + "x,y", + [ + ( + CondIndepStackFrame("a", -1, torch.tensor(2000), 2), + CondIndepStackFrame("a", -1, 2000, 2), + ), + ( + CondIndepStackFrame("a", -1, 1, 2), + CondIndepStackFrame("a", -1, torch.tensor(1), 2), + ), + ], +) def test_cond_indep_equality(x, y): assert x == y assert not x != y diff --git a/tests/infer/test_multi_sample_elbos.py b/tests/infer/test_multi_sample_elbos.py index 1fcfd9978d..5b56e68267 100644 --- a/tests/infer/test_multi_sample_elbos.py +++ b/tests/infer/test_multi_sample_elbos.py @@ -32,8 +32,7 @@ def model(): with pyro.plate("outer", 3, dim=-1): x = pyro.sample("x", dist.Normal(0, 1)) with pyro.plate("inner", 2, dim=-2): - pyro.sample("y", dist.Normal(x, 1), - obs=data) + pyro.sample("y", dist.Normal(x, 1), obs=data) def guide(): with pyro.plate("outer", 3, dim=-1): @@ -50,8 +49,7 @@ def model(): with pyro.plate("outer", 2, dim=-2): x = pyro.sample("x", dist.Normal(0, 1)) with pyro.plate("inner", 3, dim=-1): - pyro.sample("y", dist.Normal(x, 1), - obs=data) + pyro.sample("y", dist.Normal(x, 1), obs=data) def guide(): with pyro.plate("outer", 2, dim=-2): diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index cf7d342e36..f3f6b4553d 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -15,7 +15,7 @@ def model(num_trials): with pyro.plate("data", num_trials.size(0)): - phi_prior = dist.Uniform(num_trials.new_tensor(0.), num_trials.new_tensor(1.)) + phi_prior = dist.Uniform(num_trials.new_tensor(0.0), num_trials.new_tensor(1.0)) success_prob = pyro.sample("phi", phi_prior) return pyro.sample("obs", 
dist.Binomial(num_trials, success_prob)) @@ -28,8 +28,12 @@ def one_hot_model(pseudocounts, classes=None): def beta_guide(num_trials): - phi_c0 = pyro.param("phi_c0", num_trials.new_tensor(5.0).expand([num_trials.size(0)])) - phi_c1 = pyro.param("phi_c1", num_trials.new_tensor(5.0).expand([num_trials.size(0)])) + phi_c0 = pyro.param( + "phi_c0", num_trials.new_tensor(5.0).expand([num_trials.size(0)]) + ) + phi_c1 = pyro.param( + "phi_c1", num_trials.new_tensor(5.0).expand([num_trials.size(0)]) + ) with pyro.plate("data", num_trials.size(0)): phi_posterior = dist.Beta(concentration0=phi_c0, concentration1=phi_c1) pyro.sample("phi", phi_posterior) @@ -45,8 +49,13 @@ def test_posterior_predictive_svi_manual_guide(parallel): svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=1.0)), elbo) for i in range(1000): svi.step(num_trials) - posterior_predictive = Predictive(model, guide=beta_guide, num_samples=10000, - parallel=parallel, return_sites=["_RETURN"]) + posterior_predictive = Predictive( + model, + guide=beta_guide, + num_samples=10000, + parallel=parallel, + return_sites=["_RETURN"], + ) marginal_return_vals = posterior_predictive(num_trials)["_RETURN"] assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05) @@ -61,7 +70,9 @@ def test_posterior_predictive_svi_auto_delta_guide(parallel): svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO()) for i in range(1000): svi.step(num_trials) - posterior_predictive = Predictive(model, guide=guide, num_samples=10000, parallel=parallel) + posterior_predictive = Predictive( + model, guide=guide, num_samples=10000, parallel=parallel + ) marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"] assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05) @@ -76,9 +87,13 @@ def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace): svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO()) for i in range(1000): svi.step(num_trials) - posterior_predictive = Predictive(model, guide=guide, num_samples=10000, parallel=True) + posterior_predictive = Predictive( + model, guide=guide, num_samples=10000, parallel=True + ) if return_trace: - marginal_return_vals = posterior_predictive.get_vectorized_trace(num_trials).nodes["obs"]["value"] + marginal_return_vals = posterior_predictive.get_vectorized_trace( + num_trials + ).nodes["obs"]["value"] else: marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"] assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05) @@ -117,8 +132,13 @@ def model(): expected = poutine.replay(vectorize(model), trace)() # Use Predictive. - predictive = Predictive(model, guide=guide, return_sites=["x", "y"], - num_samples=num_samples, parallel=parallel) + predictive = Predictive( + model, + guide=guide, + return_sites=["x", "y"], + num_samples=num_samples, + parallel=parallel, + ) actual = predictive.get_samples() assert set(actual) == set(expected) assert actual["x"].shape == expected["x"].shape @@ -136,13 +156,15 @@ def model(y=None): pyro.deterministic("x3", x2) return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y) - y = torch.tensor(4.) 
+ y = torch.tensor(4.0) guide = AutoDiagonalNormal(model) svi = SVI(model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO()) for i in range(100): svi.step(y) - actual = Predictive(model, guide=guide, return_sites=["x2", "x3"], num_samples=1000)() + actual = Predictive( + model, guide=guide, return_sites=["x2", "x3"], num_samples=1000 + )() x2_batch_shape = (3,) if with_plate else () assert actual["x2"].shape == (1000,) + x2_batch_shape + event_shape # x3 shape is prepended 1 to match Pyro shape semantics @@ -153,10 +175,9 @@ def model(y=None): def test_get_mask_optimization(): - def model(): x = pyro.sample("x", dist.Normal(0, 1)) - pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.)) + pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0)) called.add("model-always") if poutine.get_mask() is not False: called.add("model-sometimes") diff --git a/tests/infer/test_sampling.py b/tests/infer/test_sampling.py index 8c2e9662aa..d28dcd14c4 100644 --- a/tests/infer/test_sampling.py +++ b/tests/infer/test_sampling.py @@ -14,7 +14,6 @@ class HMMSamplingTestCase(TestCase): - def setUp(self): # simple Gaussian-emission HMM @@ -27,13 +26,23 @@ def model(): for t in range(self.model_steps): latents.append( - pyro.sample("latent_{}".format(str(t)), - Bernoulli(torch.index_select(p_latent, 0, latents[-1].view(-1).long())))) + pyro.sample( + "latent_{}".format(str(t)), + Bernoulli( + torch.index_select(p_latent, 0, latents[-1].view(-1).long()) + ), + ) + ) observes.append( - pyro.sample("observe_{}".format(str(t)), - Bernoulli(torch.index_select(p_obs, 0, latents[-1].view(-1).long())), - obs=self.data[t])) + pyro.sample( + "observe_{}".format(str(t)), + Bernoulli( + torch.index_select(p_obs, 0, latents[-1].view(-1).long()) + ), + obs=self.data[t], + ) + ) return torch.sum(torch.cat(latents)) self.model_steps = 3 @@ -42,21 +51,18 @@ def model(): class NormalNormalSamplingTestCase(TestCase): - def setUp(self): pyro.clear_param_store() def model(): - loc = pyro.sample("loc", Normal(torch.zeros(1), - torch.ones(1))) + loc = pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1))) xd = Normal(loc, torch.ones(1)) pyro.sample("xs", xd, obs=self.data) return loc def guide(): - return pyro.sample("loc", Normal(torch.zeros(1), - torch.ones(1))) + return pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1))) # data self.data = torch.zeros(50, 1) @@ -69,17 +75,24 @@ def guide(): class ImportanceTest(NormalNormalSamplingTestCase): - @pytest.mark.init(rng_seed=0) def test_importance_guide(self): - posterior = pyro.infer.Importance(self.model, guide=self.guide, num_samples=5000).run() + posterior = pyro.infer.Importance( + self.model, guide=self.guide, num_samples=5000 + ).run() marginal = EmpiricalMarginal(posterior) assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01) - assert_equal(0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1) + assert_equal( + 0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1 + ) @pytest.mark.init(rng_seed=0) def test_importance_prior(self): - posterior = pyro.infer.Importance(self.model, guide=None, num_samples=10000).run() + posterior = pyro.infer.Importance( + self.model, guide=None, num_samples=10000 + ).run() marginal = EmpiricalMarginal(posterior) assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01) - assert_equal(0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1) + assert_equal( + 0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1 + ) 
diff --git a/tests/infer/test_smcfilter.py b/tests/infer/test_smcfilter.py index 4073496f10..0001f72d71 100644 --- a/tests/infer/test_smcfilter.py +++ b/tests/infer/test_smcfilter.py @@ -21,8 +21,9 @@ def test_systematic_sample(size): num_samples = 20000 index = _systematic_sample(probs.expand(num_samples, size)) histogram = torch.zeros_like(probs) - histogram.scatter_add_(-1, index.reshape(-1), - probs.new_ones(1).expand(num_samples * size)) + histogram.scatter_add_( + -1, index.reshape(-1), probs.new_ones(1).expand(num_samples * size) + ) expected = probs * size actual = histogram / num_samples @@ -30,48 +31,57 @@ def test_systematic_sample(size): class SmokeModel: - def __init__(self, state_size, plate_size): self.state_size = state_size self.plate_size = plate_size def init(self, state): self.t = 0 - state["x_mean"] = pyro.sample("x_mean", dist.Normal(0., 1.)) - state["y_mean"] = pyro.sample("y_mean", - dist.MultivariateNormal(torch.zeros(self.state_size), - torch.eye(self.state_size))) + state["x_mean"] = pyro.sample("x_mean", dist.Normal(0.0, 1.0)) + state["y_mean"] = pyro.sample( + "y_mean", + dist.MultivariateNormal( + torch.zeros(self.state_size), torch.eye(self.state_size) + ), + ) def step(self, state, x=None, y=None): - v = pyro.sample("v_{}".format(self.t), dist.Normal(0., 1.)) + v = pyro.sample("v_{}".format(self.t), dist.Normal(0.0, 1.0)) with pyro.plate("plate", self.plate_size): - w = pyro.sample("w_{}".format(self.t), dist.Normal(v, 1.)) - x = pyro.sample("x_{}".format(self.t), - dist.Normal(state["x_mean"] + w, 1), obs=x) - y = pyro.sample("y_{}".format(self.t), - dist.MultivariateNormal(state["y_mean"] + w.unsqueeze(-1), torch.eye(self.state_size)), - obs=y) + w = pyro.sample("w_{}".format(self.t), dist.Normal(v, 1.0)) + x = pyro.sample( + "x_{}".format(self.t), dist.Normal(state["x_mean"] + w, 1), obs=x + ) + y = pyro.sample( + "y_{}".format(self.t), + dist.MultivariateNormal( + state["y_mean"] + w.unsqueeze(-1), torch.eye(self.state_size) + ), + obs=y, + ) self.t += 1 return x, y class SmokeGuide: - def __init__(self, state_size, plate_size): self.state_size = state_size self.plate_size = plate_size def init(self, state): self.t = 0 - pyro.sample("x_mean", dist.Normal(0., 2.)) - pyro.sample("y_mean", - dist.MultivariateNormal(torch.zeros(self.state_size), - 2.*torch.eye(self.state_size))) + pyro.sample("x_mean", dist.Normal(0.0, 2.0)) + pyro.sample( + "y_mean", + dist.MultivariateNormal( + torch.zeros(self.state_size), 2.0 * torch.eye(self.state_size) + ), + ) def step(self, state, x=None, y=None): - v = pyro.sample("v_{}".format(self.t), dist.Normal(0., 2.)) + v = pyro.sample("v_{}".format(self.t), dist.Normal(0.0, 2.0)) with pyro.plate("plate", self.plate_size): - pyro.sample("w_{}".format(self.t), dist.Normal(v, 2.)) + pyro.sample("w_{}".format(self.t), dist.Normal(v, 2.0)) self.t += 1 @@ -83,7 +93,9 @@ def test_smoke(max_plate_nesting, state_size, plate_size, num_steps): model = SmokeModel(state_size, plate_size) guide = SmokeGuide(state_size, plate_size) - smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=max_plate_nesting) + smc = SMCFilter( + model, guide, num_particles=100, max_plate_nesting=max_plate_nesting + ) true_model = SmokeModel(state_size, plate_size) @@ -100,27 +112,27 @@ def test_smoke(max_plate_nesting, state_size, plate_size, num_steps): class HarmonicModel: - def __init__(self): - self.A = torch.tensor([[0., 1.], - [-1., 0.]]) - self.B = torch.tensor([3., 3.]) - self.sigma_z = torch.tensor(1.) - self.sigma_y = torch.tensor(1.) 
+ self.A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) + self.B = torch.tensor([3.0, 3.0]) + self.sigma_z = torch.tensor(1.0) + self.sigma_y = torch.tensor(1.0) def init(self, state): self.t = 0 - state["z"] = pyro.sample("z_init", - dist.Delta(torch.tensor([1., 0.]), event_dim=1)) + state["z"] = pyro.sample( + "z_init", dist.Delta(torch.tensor([1.0, 0.0]), event_dim=1) + ) def step(self, state, y=None): self.t += 1 - state["z"] = pyro.sample("z_{}".format(self.t), - dist.Normal(state["z"].matmul(self.A), - self.B*self.sigma_z).to_event(1)) - y = pyro.sample("y_{}".format(self.t), - dist.Normal(state["z"][..., 0], self.sigma_y), - obs=y) + state["z"] = pyro.sample( + "z_{}".format(self.t), + dist.Normal(state["z"].matmul(self.A), self.B * self.sigma_z).to_event(1), + ) + y = pyro.sample( + "y_{}".format(self.t), dist.Normal(state["z"][..., 0], self.sigma_y), obs=y + ) state["z_{}".format(self.t)] = state["z"] # saved for testing @@ -128,21 +140,23 @@ def step(self, state, y=None): class HarmonicGuide: - def __init__(self): self.model = HarmonicModel() def init(self, state): self.t = 0 - pyro.sample("z_init", dist.Delta(torch.tensor([1., 0.]), event_dim=1)) + pyro.sample("z_init", dist.Delta(torch.tensor([1.0, 0.0]), event_dim=1)) def step(self, state, y=None): self.t += 1 # Proposal distribution - pyro.sample("z_{}".format(self.t), - dist.Normal(state["z"].matmul(self.model.A), - torch.tensor([2., 2.])).to_event(1)) + pyro.sample( + "z_{}".format(self.t), + dist.Normal( + state["z"].matmul(self.model.A), torch.tensor([2.0, 2.0]) + ).to_event(1), + ) def generate_data(): @@ -150,7 +164,7 @@ def generate_data(): state = {} model.init(state) - zs = [torch.tensor([1., 0.])] + zs = [torch.tensor([1.0, 0.0])] ys = [None] for t in range(50): z, y = model.step(state) @@ -187,17 +201,19 @@ def test_likelihood_ratio(): i = smc.state._log_weights.max(0)[1] values = {k: v[i] for k, v in smc.state.items()} - zs_pred = [torch.tensor([1., 0.])] + zs_pred = [torch.tensor([1.0, 0.0])] zs_pred += [values["z_{}".format(t)] for t in range(1, 51)] - assert(score_latent(zs_true, ys_true) > score_latent(zs, ys_true)) - assert(score_latent(zs_pred, ys_true) > score_latent(zs_pred, ys)) - assert(score_latent(zs_pred, ys_true) > score_latent(zs, ys_true)) + assert score_latent(zs_true, ys_true) > score_latent(zs, ys_true) + assert score_latent(zs_pred, ys_true) > score_latent(zs_pred, ys) + assert score_latent(zs_pred, ys_true) > score_latent(zs, ys_true) def test_gaussian_filter(): dim = 4 - init_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=torch.eye(dim) * 10) + init_dist = dist.MultivariateNormal( + torch.zeros(dim), scale_tril=torch.eye(dim) * 10 + ) trans_mat = torch.eye(dim) trans_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=torch.eye(dim)) obs_mat = torch.eye(dim) @@ -210,11 +226,15 @@ def init(self, state): self.t = 0 def step(self, state, datum=None): - state["z"] = pyro.sample("z_{}".format(self.t), - dist.MultivariateNormal(state["z"], scale_tril=trans_dist.scale_tril)) - datum = pyro.sample("obs_{}".format(self.t), - dist.MultivariateNormal(state["z"], scale_tril=obs_dist.scale_tril), - obs=datum) + state["z"] = pyro.sample( + "z_{}".format(self.t), + dist.MultivariateNormal(state["z"], scale_tril=trans_dist.scale_tril), + ) + datum = pyro.sample( + "obs_{}".format(self.t), + dist.MultivariateNormal(state["z"], scale_tril=obs_dist.scale_tril), + obs=datum, + ) self.t += 1 return datum @@ -224,8 +244,12 @@ def init(self, state): self.t = 0 def step(self, state, datum): - 
pyro.sample("z_{}".format(self.t), - dist.MultivariateNormal(state["z"], scale_tril=trans_dist.scale_tril * 2)) + pyro.sample( + "z_{}".format(self.t), + dist.MultivariateNormal( + state["z"], scale_tril=trans_dist.scale_tril * 2 + ), + ) self.t += 1 # Generate data. @@ -242,8 +266,10 @@ def step(self, state, datum): smc.init() for t, datum in enumerate(data): smc.step(datum) - expected = hmm.filter(data[:1+t]) + expected = hmm.filter(data[: 1 + t]) actual = smc.get_empirical()["z"] - assert_close(actual.variance ** 0.5, expected.variance ** 0.5, atol=0.1, rtol=0.5) + assert_close( + actual.variance ** 0.5, expected.variance ** 0.5, atol=0.1, rtol=0.5 + ) sigma = actual.variance.max().item() ** 0.5 assert_close(actual.mean, expected.mean, atol=3 * sigma) diff --git a/tests/infer/test_svgd.py b/tests/infer/test_svgd.py index c6944dedc2..d90efb0c8c 100644 --- a/tests/infer/test_svgd.py +++ b/tests/infer/test_svgd.py @@ -12,10 +12,15 @@ from tests.common import assert_equal -@pytest.mark.parametrize("latent_dist", [dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1), - dist.LogNormal(torch.tensor([-1.0]), torch.tensor([0.7])).to_event(1), - dist.LogNormal(torch.tensor(-1.0), torch.tensor(0.7)), - dist.Beta(torch.tensor([0.3]), torch.tensor([0.7])).to_event(1)]) +@pytest.mark.parametrize( + "latent_dist", + [ + dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1), + dist.LogNormal(torch.tensor([-1.0]), torch.tensor([0.7])).to_event(1), + dist.LogNormal(torch.tensor(-1.0), torch.tensor(0.7)), + dist.Beta(torch.tensor([0.3]), torch.tensor([0.7])).to_event(1), + ], +) @pytest.mark.parametrize("mode", ["univariate", "multivariate"]) @pytest.mark.parametrize("stein_kernel", [RBFSteinKernel, IMQSteinKernel]) def test_mean_variance(latent_dist, mode, stein_kernel, verbose=True): @@ -34,25 +39,33 @@ def model(): # scramble initial particles svgd.step() - pyro.param('svgd_particles').unconstrained().data *= 1.3 - pyro.param('svgd_particles').unconstrained().data += 0.7 + pyro.param("svgd_particles").unconstrained().data *= 1.3 + pyro.param("svgd_particles").unconstrained().data += 0.7 for step in range(n_steps): - kernel.bandwidth_factor = bandwidth_start + (step / n_steps) * (bandwidth_end - bandwidth_start) + kernel.bandwidth_factor = bandwidth_start + (step / n_steps) * ( + bandwidth_end - bandwidth_start + ) squared_gradients = svgd.step() if step % 125 == 0: print("[step %03d] " % step, squared_gradients) - final_particles = svgd.get_named_particles()['z'] + final_particles = svgd.get_named_particles()["z"] if verbose: - print("[mean]: actual, expected = ", final_particles.mean(0).data.numpy(), - latent_dist.mean.data.numpy()) - print("[var]: actual, expected = ", final_particles.var(0).data.numpy(), - latent_dist.variance.data.numpy()) + print( + "[mean]: actual, expected = ", + final_particles.mean(0).data.numpy(), + latent_dist.mean.data.numpy(), + ) + print( + "[var]: actual, expected = ", + final_particles.var(0).data.numpy(), + latent_dist.variance.data.numpy(), + ) assert_equal(final_particles.mean(0), latent_dist.mean, prec=0.01) - prec = 0.05 if mode == 'multivariate' else 0.02 + prec = 0.05 if mode == "multivariate" else 0.02 assert_equal(final_particles.var(0), latent_dist.variance, prec=prec) @@ -77,12 +90,12 @@ def model(): svgd.step() particles = svgd.get_named_particles() - assert particles['z1'].shape == (num_particles,) + shape1 - assert particles['z2'].shape == (num_particles,) + shape2 + assert particles["z1"].shape == (num_particles,) + shape1 + assert 
particles["z2"].shape == (num_particles,) + shape2 for particle in range(num_particles): - assert_equal(particles['z1'][particle, ...], mean_init1.exp(), prec=1.0e-6) - assert_equal(particles['z2'][particle, ...], mean_init2, prec=1.0e-6) + assert_equal(particles["z1"][particle, ...], mean_init1.exp(), prec=1.0e-6) + assert_equal(particles["z2"][particle, ...], mean_init2, prec=1.0e-6) @pytest.mark.parametrize("mode", ["univariate", "multivariate"]) @@ -92,12 +105,11 @@ def test_conjugate(mode, stein_kernel, verbose=False): alpha0 = torch.tensor([1.0, 1.8, 2.3]) beta0 = torch.tensor([2.3, 1.5, 1.2]) alpha_n = alpha0 + data.sum(0) # posterior alpha - beta_n = beta0 + data.size(0) # posterior beta + beta_n = beta0 + data.size(0) # posterior beta def model(): with pyro.plate("rates", alpha0.size(0)): - latent = pyro.sample("latent", - dist.Gamma(alpha0, beta0)) + latent = pyro.sample("latent", dist.Gamma(alpha0, beta0)) with pyro.plate("data", data.size(0)): pyro.sample("obs", dist.Poisson(latent), obs=data) @@ -110,20 +122,28 @@ def model(): n_steps = 451 for step in range(n_steps): - kernel.bandwidth_factor = bandwidth_start + (step / n_steps) * (bandwidth_end - bandwidth_start) + kernel.bandwidth_factor = bandwidth_start + (step / n_steps) * ( + bandwidth_end - bandwidth_start + ) squared_gradients = svgd.step() if step % 150 == 0: print("[step %03d] " % step, squared_gradients) - final_particles = svgd.get_named_particles()['latent'] + final_particles = svgd.get_named_particles()["latent"] posterior_dist = dist.Gamma(alpha_n, beta_n) if verbose: - print("[mean]: actual, expected = ", final_particles.mean(0).data.numpy(), - posterior_dist.mean.data.numpy()) - print("[var]: actual, expected = ", final_particles.var(0).data.numpy(), - posterior_dist.variance.data.numpy()) + print( + "[mean]: actual, expected = ", + final_particles.mean(0).data.numpy(), + posterior_dist.mean.data.numpy(), + ) + print( + "[var]: actual, expected = ", + final_particles.var(0).data.numpy(), + posterior_dist.variance.data.numpy(), + ) assert_equal(final_particles.mean(0)[0], posterior_dist.mean, prec=0.02) - prec = 0.05 if mode == 'multivariate' else 0.02 + prec = 0.05 if mode == "multivariate" else 0.02 assert_equal(final_particles.var(0)[0], posterior_dist.variance, prec=prec) diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py index 35667e56ef..0c29792704 100644 --- a/tests/infer/test_tmc.py +++ b/tests/infer/test_tmc.py @@ -29,11 +29,13 @@ def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy): qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))] for i in range(1, depth): - qs.append(pyro.param( - "q{}".format(i), - torch.randn(2, 2).abs().detach().requires_grad_(), - constraint=constraints.simplex - )) + qs.append( + pyro.param( + "q{}".format(i), + torch.randn(2, 2).abs().detach().requires_grad_(), + constraint=constraints.simplex, + ) + ) qs.append(pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True))) qs = [q.unconstrained() for q in qs] @@ -44,33 +46,56 @@ def model(): x = pyro.sample("x0", dist.Categorical(pyro.param("q0"))) with pyro.plate("local", 3): for i in range(1, depth): - x = pyro.sample("x{}".format(i), - dist.Categorical(pyro.param("q{}".format(i))[..., x, :])) + x = pyro.sample( + "x{}".format(i), + dist.Categorical(pyro.param("q{}".format(i))[..., x, :]), + ) with pyro.plate("data", 4): - pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), - obs=data) + pyro.sample("y", dist.Bernoulli(pyro.param("qy")[..., x]), obs=data) 
elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting) - enum_model = config_enumerate(model, default="parallel", expand=False, num_samples=None, tmc=tmc_strategy) + enum_model = config_enumerate( + model, default="parallel", expand=False, num_samples=None, tmc=tmc_strategy + ) expected_loss = (-elbo.differentiable_loss(enum_model, lambda: None)).exp() expected_grads = grad(expected_loss, qs) tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) - tmc_model = config_enumerate(model, default="parallel", expand=False, num_samples=num_samples, tmc=tmc_strategy) + tmc_model = config_enumerate( + model, + default="parallel", + expand=False, + num_samples=num_samples, + tmc=tmc_strategy, + ) actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp() actual_grads = grad(actual_loss, qs) prec = 0.05 - assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) + assert_equal( + actual_loss, + expected_loss, + prec=prec, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=prec, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("depth", [1, 2]) @@ -79,62 +104,106 @@ def model(): @pytest.mark.parametrize("reparameterized", [True, False]) @pytest.mark.parametrize("guide_type", ["prior", "factorized", "nonfactorized"]) @pytest.mark.parametrize("tmc_strategy", ["diagonal", "mixture"]) -def test_tmc_normals_chain_iwae(depth, num_samples, max_plate_nesting, - reparameterized, guide_type, expand, tmc_strategy): +def test_tmc_normals_chain_iwae( + depth, + num_samples, + max_plate_nesting, + reparameterized, + guide_type, + expand, + tmc_strategy, +): # compare iwae and tmc q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True)) qs = (q2.unconstrained(),) def model(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth))) - pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1))) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) + pyro.sample("y", Normal(x, 1.0), obs=torch.tensor(float(1))) def factorized_guide(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - pyro.sample("x{}".format(i), Normal(0., math.sqrt(float(i+1) / depth))) + pyro.sample("x{}".format(i), Normal(0.0, math.sqrt(float(i + 1) / depth))) def nonfactorized_guide(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. 
/ depth))) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth))) - - guide = factorized_guide if guide_type == "factorized" else \ - nonfactorized_guide if guide_type == "nonfactorized" else \ - poutine.block(model, hide_fn=lambda msg: msg["type"] == "sample" and msg["is_observed"]) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) + + guide = ( + factorized_guide + if guide_type == "factorized" + else nonfactorized_guide + if guide_type == "nonfactorized" + else poutine.block( + model, hide_fn=lambda msg: msg["type"] == "sample" and msg["is_observed"] + ) + ) flat_num_samples = num_samples ** min(depth, 2) # don't use too many, expensive vectorized_log_weights, _, _ = vectorized_importance_weights( - model, guide, True, + model, + guide, + True, max_plate_nesting=max_plate_nesting, - num_samples=flat_num_samples) + num_samples=flat_num_samples, + ) assert vectorized_log_weights.shape == (flat_num_samples,) - expected_loss = (vectorized_log_weights.logsumexp(dim=-1) - math.log(float(flat_num_samples))).exp() + expected_loss = ( + vectorized_log_weights.logsumexp(dim=-1) - math.log(float(flat_num_samples)) + ).exp() expected_grads = grad(expected_loss, qs) tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) tmc_model = config_enumerate( - model, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) + model, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) tmc_guide = config_enumerate( - guide, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) - actual_loss = (-tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp() + guide, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) + actual_loss = ( + -tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized) + ).exp() actual_grads = grad(actual_loss, qs) - assert_equal(actual_loss, expected_loss, prec=0.05, msg="".join([ - "\nexpected loss = {}".format(expected_loss), - "\n actual loss = {}".format(actual_loss), - ])) + assert_equal( + actual_loss, + expected_loss, + prec=0.05, + msg="".join( + [ + "\nexpected loss = {}".format(expected_loss), + "\n actual loss = {}".format(actual_loss), + ] + ), + ) grad_prec = 0.05 if reparameterized else 0.1 for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_equal(actual_grad, expected_grad, prec=grad_prec, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=grad_prec, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) @pytest.mark.parametrize("depth", [1, 2, 3, 4]) @@ -143,54 +212,84 @@ def nonfactorized_guide(reparameterized): @pytest.mark.parametrize("guide_type", ["prior", "factorized", "nonfactorized"]) @pytest.mark.parametrize("reparameterized", [False, True]) @pytest.mark.parametrize("tmc_strategy", ["diagonal", "mixture"]) -def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting, expand, - guide_type, reparameterized, tmc_strategy): +def test_tmc_normals_chain_gradient( + depth, + num_samples, + max_plate_nesting, + expand, + guide_type, + reparameterized, + tmc_strategy, +): # compare 
reparameterized and nonreparameterized gradient estimates q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True)) qs = (q2.unconstrained(),) def model(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth))) - pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1))) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) + pyro.sample("y", Normal(x, 1.0), obs=torch.tensor(float(1))) def factorized_guide(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - pyro.sample("x{}".format(i), Normal(0., math.sqrt(float(i+1) / depth))) + pyro.sample("x{}".format(i), Normal(0.0, math.sqrt(float(i + 1) / depth))) def nonfactorized_guide(reparameterized): Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal - x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth))) + x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1.0 / depth))) for i in range(1, depth): - x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth))) + x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1.0 / depth))) tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting) tmc_model = config_enumerate( - model, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) - guide = factorized_guide if guide_type == "factorized" else \ - nonfactorized_guide if guide_type == "nonfactorized" else \ - lambda *args: None + model, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) + guide = ( + factorized_guide + if guide_type == "factorized" + else nonfactorized_guide + if guide_type == "nonfactorized" + else lambda *args: None + ) tmc_guide = config_enumerate( - guide, default="parallel", expand=expand, num_samples=num_samples, tmc=tmc_strategy) + guide, + default="parallel", + expand=expand, + num_samples=num_samples, + tmc=tmc_strategy, + ) # gold values from Funsor - expected_grads = (torch.tensor( - {1: 0.0999, 2: 0.0860, 3: 0.0802, 4: 0.0771}[depth] - ),) + expected_grads = ( + torch.tensor({1: 0.0999, 2: 0.0860, 3: 0.0802, 4: 0.0771}[depth]), + ) # convert to linear space for unbiasedness - actual_loss = (-tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp() + actual_loss = ( + -tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized) + ).exp() actual_grads = grad(actual_loss, qs) grad_prec = 0.05 if reparameterized else 0.1 for actual_grad, expected_grad in zip(actual_grads, expected_grads): print(actual_loss) - assert_equal(actual_grad, expected_grad, prec=grad_prec, msg="".join([ - "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), - "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), - ])) + assert_equal( + actual_grad, + expected_grad, + prec=grad_prec, + msg="".join( + [ + "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()), + "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()), + ] + ), + ) diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py index 7fda7a399f..45f5807295 100644 --- a/tests/infer/test_util.py +++ 
b/tests/infer/test_util.py @@ -16,15 +16,15 @@ def xy_model(): d = dist.Bernoulli(0.5) - x_axis = pyro.plate('x_axis', 2, dim=-1) - y_axis = pyro.plate('y_axis', 3, dim=-2) - pyro.sample('b', d) + x_axis = pyro.plate("x_axis", 2, dim=-1) + y_axis = pyro.plate("y_axis", 3, dim=-2) + pyro.sample("b", d) with x_axis: - pyro.sample('bx', d.expand_by([2])) + pyro.sample("bx", d.expand_by([2])) with y_axis: - pyro.sample('by', d.expand_by([3, 1])) + pyro.sample("by", d.expand_by([3, 1])) with x_axis, y_axis: - pyro.sample('bxy', d.expand_by([3, 2])) + pyro.sample("bxy", d.expand_by([3, 2])) def test_multi_frame_tensor(): @@ -41,22 +41,20 @@ def test_multi_frame_tensor(): logp = math.log(0.5) expected = { - 'b': torch.ones(torch.Size()) * logp * (1 + 2 + 3 + 6), - 'bx': torch.ones(torch.Size((2,))) * logp * (1 + 1 + 3 + 3), - 'by': torch.ones(torch.Size((3, 1))) * logp * (1 + 2 + 1 + 2), - 'bxy': torch.ones(torch.Size((3, 2))) * logp * (1 + 1 + 1 + 1), + "b": torch.ones(torch.Size()) * logp * (1 + 2 + 3 + 6), + "bx": torch.ones(torch.Size((2,))) * logp * (1 + 1 + 3 + 3), + "by": torch.ones(torch.Size((3, 1))) * logp * (1 + 2 + 1 + 2), + "bxy": torch.ones(torch.Size((3, 2))) * logp * (1 + 1 + 1 + 1), } for name, expected_sum in expected.items(): actual_sum = actual.sum_to(stacks[name]) assert_equal(actual_sum, expected_sum, msg=name) -@pytest.mark.parametrize('max_particles', [250 * 1000, 500 * 1000]) -@pytest.mark.parametrize('scale,krange', [(0.5, (0.7, 0.9)), - (0.95, (0.05, 0.2))]) -@pytest.mark.parametrize('zdim', [1, 5]) +@pytest.mark.parametrize("max_particles", [250 * 1000, 500 * 1000]) +@pytest.mark.parametrize("scale,krange", [(0.5, (0.7, 0.9)), (0.95, (0.05, 0.2))]) +@pytest.mark.parametrize("zdim", [1, 5]) def test_psis_diagnostic(scale, krange, zdim, max_particles, num_particles=500 * 1000): - def model(zdim=1, scale=1.0): with pyro.plate("x_axis", zdim, dim=-1): pyro.sample("z", dist.Normal(0.0, 1.0).expand([zdim])) @@ -65,6 +63,12 @@ def guide(zdim=1, scale=1.0): with pyro.plate("x_axis", zdim, dim=-1): pyro.sample("z", dist.Normal(0.0, scale).expand([zdim])) - k = psis_diagnostic(model, guide, num_particles=num_particles, max_simultaneous_particles=max_particles, - zdim=zdim, scale=scale) + k = psis_diagnostic( + model, + guide, + num_particles=num_particles, + max_simultaneous_particles=max_particles, + zdim=zdim, + scale=scale, + ) assert k > krange[0] and k < krange[1] diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py index 4c54015b3d..0caf0f01a2 100644 --- a/tests/infer/test_valid_models.py +++ b/tests/infer/test_valid_models.py @@ -61,7 +61,9 @@ def assert_ok(model, guide, elbo, **kwargs): if hasattr(elbo, "differentiable_loss"): try: pyro.set_rng_seed(0) - differentiable_loss = torch_item(elbo.differentiable_loss(model, guide, **kwargs)) + differentiable_loss = torch_item( + elbo.differentiable_loss(model, guide, **kwargs) + ) except ValueError: pass # Ignore cases where elbo cannot be differentiated else: @@ -79,9 +81,11 @@ def assert_error(model, guide, elbo, match=None): Assert that inference fails with an error. 
""" pyro.clear_param_store() - inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) - with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError), - match=match): + inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) + with pytest.raises( + (NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError), + match=match, + ): inference.step() @@ -90,26 +94,28 @@ def assert_warning(model, guide, elbo): Assert that inference works but with a warning. """ pyro.clear_param_store() - inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) + inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") inference.step() - assert len(w), 'No warnings were raised' + assert len(w), "No warnings were raised" for warning in w: logger.info(warning) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - TraceTMC_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + TraceTMC_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) @pytest.mark.parametrize("strict_enumeration_warning", [True, False]) def test_nonempty_model_empty_guide_ok(Elbo, strict_enumeration_warning): - def model(): loc = torch.tensor([0.0, 0.0]) scale = torch.tensor([1.0, 1.0]) @@ -125,17 +131,19 @@ def guide(): assert_ok(model, guide, elbo) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - TraceTMC_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + TraceTMC_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) @pytest.mark.parametrize("strict_enumeration_warning", [True, False]) def test_nonempty_model_empty_guide_error(Elbo, strict_enumeration_warning): - def model(): pyro.sample("x", dist.Normal(0, 1)) @@ -146,10 +154,11 @@ def guide(): assert_error(model, guide, elbo) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) @pytest.mark.parametrize("strict_enumeration_warning", [True, False]) def test_empty_model_empty_guide_ok(Elbo, strict_enumeration_warning): - def model(): pass @@ -163,9 +172,10 @@ def guide(): assert_ok(model, guide, elbo) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_variable_clash_in_model_error(Elbo): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -175,12 +185,13 @@ def guide(): p = pyro.param("p", torch.tensor(0.5, requires_grad=True)) pyro.sample("x", dist.Bernoulli(p)) - assert_error(model, guide, Elbo(), match='Multiple sample sites named') + assert_error(model, guide, Elbo(), match="Multiple sample sites named") -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_model_guide_dim_mismatch_error(Elbo): - def model(): loc = torch.zeros(2) scale = torch.ones(2) @@ -191,13 +202,18 @@ def guide(): scale = pyro.param("scale", torch.ones(2, 1, requires_grad=True)) pyro.sample("x", dist.Normal(loc, scale).to_event(2)) - 
assert_error(model, guide, Elbo(strict_enumeration_warning=False), - match='invalid log_prob shape|Model and guide event_dims disagree') + assert_error( + model, + guide, + Elbo(strict_enumeration_warning=False), + match="invalid log_prob shape|Model and guide event_dims disagree", + ) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_model_guide_shape_mismatch_error(Elbo): - def model(): loc = torch.zeros(1, 2) scale = torch.ones(1, 2) @@ -208,13 +224,16 @@ def guide(): scale = pyro.param("scale", torch.ones(2, 1, requires_grad=True)) pyro.sample("x", dist.Normal(loc, scale).to_event(2)) - assert_error(model, guide, Elbo(strict_enumeration_warning=False), - match='Model and guide shapes disagree') + assert_error( + model, + guide, + Elbo(strict_enumeration_warning=False), + match="Model and guide shapes disagree", + ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_variable_clash_in_guide_error(Elbo): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -224,11 +243,13 @@ def guide(): pyro.sample("x", dist.Bernoulli(p)) pyro.sample("x", dist.Bernoulli(p)) # Should error here. - assert_error(model, guide, Elbo(), match='Multiple sample sites named') + assert_error(model, guide, Elbo(), match="Multiple sample sites named") @pytest.mark.parametrize("has_rsample", [False, True]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_set_has_rsample_ok(has_rsample, Elbo): # This model has sparse gradients, so users may want to disable @@ -237,10 +258,10 @@ def test_set_has_rsample_ok(has_rsample, Elbo): def model(): z = pyro.sample("z", dist.Normal(0, 1)) loc = (z * 100).clamp(min=0, max=1) # sparse gradients - pyro.sample("x", dist.Normal(loc, 1), obs=torch.tensor(0.)) + pyro.sample("x", dist.Normal(loc, 1), obs=torch.tensor(0.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) + loc = pyro.param("loc", torch.tensor(0.0)) pyro.sample("z", dist.Normal(loc, 1).has_rsample_(has_rsample)) if Elbo is TraceEnum_ELBO: @@ -251,16 +272,17 @@ def guide(): assert_ok(model, guide, Elbo(strict_enumeration_warning=False)) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_not_has_rsample_ok(Elbo): - def model(): z = pyro.sample("z", dist.Normal(0, 1)) p = z.round().clamp(min=0.2, max=0.8) # discontinuous - pyro.sample("x", dist.Bernoulli(p), obs=torch.tensor(0.)) + pyro.sample("x", dist.Bernoulli(p), obs=torch.tensor(0.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) + loc = pyro.param("loc", torch.tensor(0.0)) pyro.sample("z", dist.Normal(loc, 1).has_rsample_(False)) if Elbo is TraceEnum_ELBO: @@ -272,9 +294,10 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_ok(subsample_size, Elbo): - def model(): p = torch.tensor(0.5) for i in pyro.plate("plate", 4, subsample_size): @@ -293,9 +316,10 @@ def guide(): 
assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_variable_clash_error(Elbo): - def model(): p = torch.tensor(0.5) for i in pyro.plate("plate", 2): @@ -313,13 +337,14 @@ def guide(): elif Elbo is TraceTMC_ELBO: guide = config_enumerate(guide, num_samples=2) - assert_error(model, guide, Elbo(), match='Multiple sample sites named') + assert_error(model, guide, Elbo(), match="Multiple sample sites named") @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_ok(subsample_size, Elbo): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, subsample_size) as ind: @@ -339,9 +364,10 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_subsample_param_ok(subsample_size, Elbo): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, subsample_size): @@ -349,7 +375,7 @@ def model(): def guide(): with pyro.plate("plate", 10, subsample_size) as ind: - p0 = pyro.param("p0", torch.tensor(0.), event_dim=0) + p0 = pyro.param("p0", torch.tensor(0.0), event_dim=0) assert p0.shape == () p = pyro.param("p", 0.5 * torch.ones(10), event_dim=0) assert len(p) == len(ind) @@ -364,9 +390,10 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_subsample_primitive_ok(subsample_size, Elbo): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, subsample_size): @@ -374,7 +401,7 @@ def model(): def guide(): with pyro.plate("plate", 10, subsample_size) as ind: - p0 = torch.tensor(0.) 
+ p0 = torch.tensor(0.0) p0 = pyro.subsample(p0, event_dim=0) assert p0.shape == () p = 0.5 * torch.ones(10) @@ -391,18 +418,22 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) -@pytest.mark.parametrize("shape,ok", [ - ((), True), - ((1,), True), - ((10,), True), - ((3, 1), True), - ((3, 10), True), - ((5), False), - ((3, 5), False), -]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) +@pytest.mark.parametrize( + "shape,ok", + [ + ((), True), + ((1,), True), + ((10,), True), + ((3, 1), True), + ((3, 10), True), + ((5), False), + ((3, 5), False), + ], +) def test_plate_param_size_mismatch_error(subsample_size, Elbo, shape, ok): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, subsample_size): @@ -425,9 +456,10 @@ def guide(): assert_error(model, guide, Elbo(), match="invalid shape of pyro.param") -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_no_size_ok(Elbo): - def model(): p = torch.tensor(0.5) with pyro.plate("plate"): @@ -446,11 +478,12 @@ def guide(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("max_plate_nesting", [0, float('inf')]) +@pytest.mark.parametrize("max_plate_nesting", [0, float("inf")]) @pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_iplate_ok(subsample_size, Elbo, max_plate_nesting): - def model(): p = torch.tensor(0.5) outer_iplate = pyro.plate("plate_0", 3, subsample_size) @@ -475,11 +508,12 @@ def guide(): assert_ok(model, guide, Elbo(max_plate_nesting=max_plate_nesting)) -@pytest.mark.parametrize("max_plate_nesting", [0, float('inf')]) +@pytest.mark.parametrize("max_plate_nesting", [0, float("inf")]) @pytest.mark.parametrize("subsample_size", [None, 2], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_iplate_swap_ok(subsample_size, Elbo, max_plate_nesting): - def model(): p = torch.tensor(0.5) outer_iplate = pyro.plate("plate_0", 3, subsample_size) @@ -505,9 +539,10 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_in_model_not_guide_ok(subsample_size, Elbo): - def model(): p = torch.tensor(0.5) for i in pyro.plate("plate", 10, subsample_size): @@ -527,10 +562,11 @@ def guide(): @pytest.mark.parametrize("subsample_size", [None, 5], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) @pytest.mark.parametrize("is_validate", [True, False]) def test_iplate_in_guide_not_model_error(subsample_size, Elbo, is_validate): - def model(): p = 
torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -543,26 +579,30 @@ def guide(): with pyro.validation_enabled(is_validate): if is_validate: - assert_error(model, guide, Elbo(), - match='Found plate statements in guide but not model') + assert_error( + model, + guide, + Elbo(), + match="Found plate statements in guide but not model", + ) else: assert_ok(model, guide, Elbo()) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_plate_broadcast_error(Elbo): - def model(): p = torch.tensor(0.5, requires_grad=True) with pyro.plate("plate", 10, 5): pyro.sample("x", dist.Bernoulli(p).expand_by([2])) - assert_error(model, model, Elbo(), match='Shape mismatch inside plate') + assert_error(model, model, Elbo(), match="Shape mismatch inside plate") -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_iplate_ok(Elbo): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 3, 2) as ind: @@ -583,9 +623,10 @@ def guide(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_iplate_plate_ok(Elbo): - def model(): p = torch.tensor(0.5) inner_plate = pyro.plate("plate", 3, 2) @@ -608,10 +649,11 @@ def guide(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) @pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)]) def test_plate_stack_ok(Elbo, sizes): - def model(): p = torch.tensor(0.5) with pyro.plate_stack("plate_stack", sizes): @@ -630,10 +672,11 @@ def guide(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) @pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)]) def test_plate_stack_and_plate_ok(Elbo, sizes): - def model(): p = torch.tensor(0.5) with pyro.plate_stack("plate_stack", sizes): @@ -656,7 +699,6 @@ def guide(): @pytest.mark.parametrize("sizes", [(3,), (3, 4), (3, 4, 5)]) def test_plate_stack_sizes(sizes): - def model(): p = 0.5 * torch.ones(3) with pyro.plate_stack("plate_stack", sizes): @@ -666,15 +708,18 @@ def model(): model() -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_nested_plate_plate_ok(Elbo): - def model(): p = torch.tensor(0.5, requires_grad=True) with pyro.plate("plate_outer", 10, 5) as ind_outer: pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer)])) with pyro.plate("plate_inner", 11, 6) as ind_inner: - pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "y", dist.Bernoulli(p).expand_by([len(ind_inner), len(ind_outer)]) + ) if Elbo is TraceEnum_ELBO: guide = config_enumerate(model) @@ -686,9 +731,10 @@ def model(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, 
TraceEnum_ELBO, TraceTMC_ELBO] +) def test_plate_reuse_ok(Elbo): - def model(): p = torch.tensor(0.5, requires_grad=True) plate_outer = pyro.plate("plate_outer", 10, 5, dim=-1) @@ -698,7 +744,9 @@ def model(): with plate_inner as ind_inner: pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner), 1])) with plate_outer as ind_outer, plate_inner as ind_inner: - pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_inner), len(ind_outer)])) + pyro.sample( + "z", dist.Bernoulli(p).expand_by([len(ind_inner), len(ind_outer)]) + ) if Elbo is TraceEnum_ELBO: guide = config_enumerate(model) @@ -710,16 +758,21 @@ def model(): assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO]) +@pytest.mark.parametrize( + "Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO, TraceTMC_ELBO] +) def test_nested_plate_plate_dim_error_1(Elbo): - def model(): p = torch.tensor([0.5], requires_grad=True) with pyro.plate("plate_outer", 10, 5) as ind_outer: - pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer)])) # error here + pyro.sample( + "x", dist.Bernoulli(p).expand_by([len(ind_outer)]) + ) # error here with pyro.plate("plate_inner", 11, 6) as ind_inner: pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner)])) - pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_inner)])) + pyro.sample( + "z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_inner)]) + ) if Elbo is TraceEnum_ELBO: guide = config_enumerate(model) @@ -728,57 +781,61 @@ def model(): else: guide = model - assert_error(model, guide, Elbo(), match='invalid log_prob shape') + assert_error(model, guide, Elbo(), match="invalid log_prob shape") @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_nested_plate_plate_dim_error_2(Elbo): - def model(): p = torch.tensor([0.5], requires_grad=True) with pyro.plate("plate_outer", 10, 5) as ind_outer: pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer), 1])) with pyro.plate("plate_inner", 11, 6) as ind_inner: - pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_outer)])) # error here - pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_inner)])) + pyro.sample( + "y", dist.Bernoulli(p).expand_by([len(ind_outer)]) + ) # error here + pyro.sample( + "z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_inner)]) + ) guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model - assert_error(model, guide, Elbo(), match='Shape mismatch inside plate') + assert_error(model, guide, Elbo(), match="Shape mismatch inside plate") @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_nested_plate_plate_dim_error_3(Elbo): - def model(): p = torch.tensor([0.5], requires_grad=True) with pyro.plate("plate_outer", 10, 5) as ind_outer: pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer), 1])) with pyro.plate("plate_inner", 11, 6) as ind_inner: pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner)])) - pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_inner), 1])) # error here + pyro.sample( + "z", dist.Bernoulli(p).expand_by([len(ind_inner), 1]) + ) # error here guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model - assert_error(model, guide, Elbo(), match='invalid log_prob shape|shape mismatch') + assert_error(model, guide, Elbo(), match="invalid log_prob shape|shape mismatch") @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def 
test_nested_plate_plate_dim_error_4(Elbo): - def model(): p = torch.tensor([0.5], requires_grad=True) with pyro.plate("plate_outer", 10, 5) as ind_outer: pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer), 1])) with pyro.plate("plate_inner", 11, 6) as ind_inner: pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner)])) - pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_outer)])) # error here + pyro.sample( + "z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_outer)]) + ) # error here guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model - assert_error(model, guide, Elbo(), match='hape mismatch inside plate') + assert_error(model, guide, Elbo(), match="hape mismatch inside plate") @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_nested_plate_plate_subsample_param_ok(Elbo): - def model(): with pyro.plate("plate_outer", 10, 5): pyro.sample("x", dist.Bernoulli(0.2)) @@ -806,7 +863,6 @@ def guide(): @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_nonnested_plate_plate_ok(Elbo): - def model(): p = torch.tensor(0.5, requires_grad=True) with pyro.plate("plate_0", 10, 5) as ind1: @@ -824,6 +880,7 @@ def test_three_indep_plate_at_different_depths_ok(): /\ ia ia ia """ + def model(): p = torch.tensor(0.5) inner_plate = pyro.plate("plate2", 10, 5) @@ -854,7 +911,6 @@ def guide(): def test_plate_wrong_size_error(): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, 5) as ind: @@ -865,11 +921,10 @@ def guide(): with pyro.plate("plate", 10, 5) as ind: pyro.sample("x", dist.Bernoulli(p).expand_by([1 + len(ind)])) - assert_error(model, guide, TraceGraph_ELBO(), match='Shape mismatch inside plate') + assert_error(model, guide, TraceGraph_ELBO(), match="Shape mismatch inside plate") def test_block_plate_name_ok(): - def model(): a = pyro.sample("a", dist.Normal(0, 1)) assert a.shape == () @@ -894,7 +949,6 @@ def guide(): def test_block_plate_dim_ok(): - def model(): a = pyro.sample("a", dist.Normal(0, 1)) assert a.shape == () @@ -919,7 +973,6 @@ def guide(): def test_block_plate_missing_error(): - def model(): with block_plate("plate"): pyro.sample("a", dist.Normal(0, 1)) @@ -927,14 +980,12 @@ def model(): def guide(): pyro.sample("a", dist.Normal(0, 1)) - assert_error(model, guide, Trace_ELBO(), - match="block_plate matched 0 messengers") + assert_error(model, guide, Trace_ELBO(), match="block_plate matched 0 messengers") @pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_enum_discrete_misuse_warning(Elbo, enumerate_): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -950,7 +1001,6 @@ def guide(): def test_enum_discrete_single_ok(): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -964,7 +1014,6 @@ def guide(): @pytest.mark.parametrize("strict_enumeration_warning", [False, True]) def test_enum_discrete_missing_config_warning(strict_enumeration_warning): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -981,7 +1030,6 @@ def guide(): def test_enum_discrete_single_single_ok(): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p)) @@ -996,7 +1044,6 @@ def guide(): def test_enum_discrete_iplate_single_ok(): - def model(): p = torch.tensor(0.5) for i in pyro.plate("plate", 10, 5): @@ -1011,7 +1058,6 @@ def guide(): def 
test_plate_enum_discrete_batch_ok(): - def model(): p = torch.tensor(0.5) with pyro.plate("plate", 10, 5) as ind: @@ -1027,7 +1073,6 @@ def guide(): @pytest.mark.parametrize("strict_enumeration_warning", [False, True]) def test_plate_enum_discrete_no_discrete_vars_warning(strict_enumeration_warning): - def model(): loc = torch.tensor(0.0) scale = torch.tensor(1.0) @@ -1049,7 +1094,6 @@ def guide(): def test_no_plate_enum_discrete_batch_error(): - def model(): p = torch.tensor(0.5) pyro.sample("x", dist.Bernoulli(p).expand_by([5])) @@ -1058,34 +1102,38 @@ def guide(): p = pyro.param("p", torch.tensor(0.5, requires_grad=True)) pyro.sample("x", dist.Bernoulli(p).expand_by([5])) - assert_error(model, config_enumerate(guide), TraceEnum_ELBO(), - match='invalid log_prob shape') + assert_error( + model, config_enumerate(guide), TraceEnum_ELBO(), match="invalid log_prob shape" + ) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2, float('inf')]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2, float("inf")]) def test_enum_discrete_parallel_ok(max_plate_nesting): - guessed_nesting = 0 if max_plate_nesting == float('inf') else max_plate_nesting + guessed_nesting = 0 if max_plate_nesting == float("inf") else max_plate_nesting plate_shape = torch.Size([1] * guessed_nesting) def model(): p = torch.tensor(0.5) x = pyro.sample("x", dist.Bernoulli(p)) - if max_plate_nesting != float('inf'): + if max_plate_nesting != float("inf"): assert x.shape == torch.Size([2]) + plate_shape def guide(): p = pyro.param("p", torch.tensor(0.5, requires_grad=True)) x = pyro.sample("x", dist.Bernoulli(p)) - if max_plate_nesting != float('inf'): + if max_plate_nesting != float("inf"): assert x.shape == torch.Size([2]) + plate_shape - assert_ok(model, config_enumerate(guide, "parallel"), - TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)) + assert_ok( + model, + config_enumerate(guide, "parallel"), + TraceEnum_ELBO(max_plate_nesting=max_plate_nesting), + ) -@pytest.mark.parametrize('max_plate_nesting', [0, 1, 2, float('inf')]) +@pytest.mark.parametrize("max_plate_nesting", [0, 1, 2, float("inf")]) def test_enum_discrete_parallel_nested_ok(max_plate_nesting): - guessed_nesting = 0 if max_plate_nesting == float('inf') else max_plate_nesting + guessed_nesting = 0 if max_plate_nesting == float("inf") else max_plate_nesting plate_shape = torch.Size([1] * guessed_nesting) def model(): @@ -1093,24 +1141,29 @@ def model(): p3 = torch.ones(3) / 3 x2 = pyro.sample("x2", dist.OneHotCategorical(p2)) x3 = pyro.sample("x3", dist.OneHotCategorical(p3)) - if max_plate_nesting != float('inf'): + if max_plate_nesting != float("inf"): assert x2.shape == torch.Size([2]) + plate_shape + p2.shape assert x3.shape == torch.Size([3, 1]) + plate_shape + p3.shape - assert_ok(model, config_enumerate(model, "parallel"), - TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)) - - -@pytest.mark.parametrize('enumerate_,expand,num_samples', [ - (None, False, None), - ("sequential", False, None), - ("sequential", True, None), - ("parallel", False, None), - ("parallel", True, None), - ("parallel", True, 3), -]) + assert_ok( + model, + config_enumerate(model, "parallel"), + TraceEnum_ELBO(max_plate_nesting=max_plate_nesting), + ) + + +@pytest.mark.parametrize( + "enumerate_,expand,num_samples", + [ + (None, False, None), + ("sequential", False, None), + ("sequential", True, None), + ("parallel", False, None), + ("parallel", True, None), + ("parallel", True, 3), + ], +) def test_enumerate_parallel_plate_ok(enumerate_, expand, num_samples): - def 
model(): p2 = torch.ones(2) / 2 p34 = torch.ones(3, 4) / 4 @@ -1127,18 +1180,18 @@ def model(): if num_samples: n = num_samples # Meaning of dimensions: [ enum dims | plate dims ] - assert x2.shape == torch.Size([ n, 1, 1]) # noqa: E201 - assert x34.shape == torch.Size([ n, 1, 1, 3]) # noqa: E201 + assert x2.shape == torch.Size([n, 1, 1]) # noqa: E201 + assert x34.shape == torch.Size([n, 1, 1, 3]) # noqa: E201 assert x536.shape == torch.Size([n, 1, 1, 5, 3]) # noqa: E201 elif expand: # Meaning of dimensions: [ enum dims | plate dims ] - assert x2.shape == torch.Size([ 2, 1, 1]) # noqa: E201 - assert x34.shape == torch.Size([ 4, 1, 1, 3]) # noqa: E201 + assert x2.shape == torch.Size([2, 1, 1]) # noqa: E201 + assert x34.shape == torch.Size([4, 1, 1, 3]) # noqa: E201 assert x536.shape == torch.Size([6, 1, 1, 5, 3]) # noqa: E201 else: # Meaning of dimensions: [ enum dims | plate placeholders ] - assert x2.shape == torch.Size([ 2, 1, 1]) # noqa: E201 - assert x34.shape == torch.Size([ 4, 1, 1, 1]) # noqa: E201 + assert x2.shape == torch.Size([2, 1, 1]) # noqa: E201 + assert x34.shape == torch.Size([4, 1, 1, 1]) # noqa: E201 assert x536.shape == torch.Size([6, 1, 1, 1, 1]) # noqa: E201 elif enumerate_ == "sequential": if expand: @@ -1162,16 +1215,18 @@ def model(): assert_ok(model, guide, elbo) -@pytest.mark.parametrize('max_plate_nesting', [1, float('inf')]) -@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"]) -@pytest.mark.parametrize('is_validate', [True, False]) -def test_enum_discrete_plate_dependency_warning(enumerate_, is_validate, max_plate_nesting): - +@pytest.mark.parametrize("max_plate_nesting", [1, float("inf")]) +@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) +@pytest.mark.parametrize("is_validate", [True, False]) +def test_enum_discrete_plate_dependency_warning( + enumerate_, is_validate, max_plate_nesting +): def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) with pyro.plate("plate", 10, 5): - x = pyro.sample("x", dist.Bernoulli(0.5).expand_by([5]), - infer={'enumerate': enumerate_}) + x = pyro.sample( + "x", dist.Bernoulli(0.5).expand_by([5]), infer={"enumerate": enumerate_} + ) pyro.sample("y", dist.Bernoulli(x.mean())) # user should move this line up with pyro.validation_enabled(is_validate): @@ -1182,35 +1237,41 @@ def model(): assert_ok(model, model, elbo) -@pytest.mark.parametrize('max_plate_nesting', [1, float('inf')]) -@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"]) +@pytest.mark.parametrize("max_plate_nesting", [1, float("inf")]) +@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) def test_enum_discrete_iplate_plate_dependency_ok(enumerate_, max_plate_nesting): - def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) inner_plate = pyro.plate("plate", 10, 5) for i in pyro.plate("iplate", 3): pyro.sample("y_{}".format(i), dist.Bernoulli(0.5)) with inner_plate: - pyro.sample("x_{}".format(i), dist.Bernoulli(0.5).expand_by([5]), - infer={'enumerate': enumerate_}) + pyro.sample( + "x_{}".format(i), + dist.Bernoulli(0.5).expand_by([5]), + infer={"enumerate": enumerate_}, + ) assert_ok(model, model, TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)) -@pytest.mark.parametrize('max_plate_nesting', [1, float('inf')]) -@pytest.mark.parametrize('enumerate_', [None, "sequential", 
"parallel"]) -@pytest.mark.parametrize('is_validate', [True, False]) -def test_enum_discrete_iplates_plate_dependency_warning(enumerate_, is_validate, max_plate_nesting): - +@pytest.mark.parametrize("max_plate_nesting", [1, float("inf")]) +@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) +@pytest.mark.parametrize("is_validate", [True, False]) +def test_enum_discrete_iplates_plate_dependency_warning( + enumerate_, is_validate, max_plate_nesting +): def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) inner_plate = pyro.plate("plate", 10, 5) for i in pyro.plate("iplate1", 2): with inner_plate: - pyro.sample("x_{}".format(i), dist.Bernoulli(0.5).expand_by([5]), - infer={'enumerate': enumerate_}) + pyro.sample( + "x_{}".format(i), + dist.Bernoulli(0.5).expand_by([5]), + infer={"enumerate": enumerate_}, + ) for i in pyro.plate("iplate2", 2): pyro.sample("y_{}".format(i), dist.Bernoulli(0.5)) @@ -1223,11 +1284,10 @@ def model(): assert_ok(model, model, elbo) -@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"]) +@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) def test_enum_discrete_plates_dependency_ok(enumerate_): - def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) x_plate = pyro.plate("x_plate", 10, 5, dim=-1) y_plate = pyro.plate("y_plate", 11, 6, dim=-2) pyro.sample("a", dist.Bernoulli(0.5)) @@ -1242,21 +1302,22 @@ def model(): assert_ok(model, model, TraceEnum_ELBO(max_plate_nesting=2)) -@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"]) +@pytest.mark.parametrize("enumerate_", [None, "sequential", "parallel"]) def test_enum_discrete_non_enumerated_plate_ok(enumerate_): - def model(): - pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'}) + pyro.sample("w", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}) with pyro.plate("non_enum", 2): - a = pyro.sample("a", dist.Bernoulli(0.5).expand_by([2]), - infer={'enumerate': None}) + a = pyro.sample( + "a", dist.Bernoulli(0.5).expand_by([2]), infer={"enumerate": None} + ) p = (1.0 + a.sum(-1)) / (2.0 + a.size(0)) # introduce dependency of b on a with pyro.plate("enum_1", 3): - pyro.sample("b", dist.Bernoulli(p).expand_by([3]), - infer={'enumerate': enumerate_}) + pyro.sample( + "b", dist.Bernoulli(p).expand_by([3]), infer={"enumerate": enumerate_} + ) with pyro.validation_enabled(): assert_ok(model, model, TraceEnum_ELBO(max_plate_nesting=1)) @@ -1281,16 +1342,18 @@ def guide(): assert_ok(model, guide, Trace_ELBO()) -@pytest.mark.parametrize('enumerate_,expand,num_samples', [ - (None, True, None), - ("sequential", True, None), - ("sequential", False, None), - ("parallel", True, None), - ("parallel", False, None), - ("parallel", True, 3), -]) +@pytest.mark.parametrize( + "enumerate_,expand,num_samples", + [ + (None, True, None), + ("sequential", True, None), + ("sequential", False, None), + ("parallel", True, None), + ("parallel", False, None), + ("parallel", True, 3), + ], +) def test_enum_discrete_plate_shape_broadcasting_ok(enumerate_, expand, num_samples): - def model(): x_plate = pyro.plate("x_plate", 10, 5, dim=-1) y_plate = pyro.plate("y_plate", 11, 6, dim=-2) @@ -1334,20 +1397,26 @@ def model(): assert c.shape == (50, 6, 1) assert d.shape == (50, 6, 5) - guide = config_enumerate(model, default=enumerate_, expand=expand, 
num_samples=num_samples) - elbo = TraceEnum_ELBO(max_plate_nesting=3, - strict_enumeration_warning=(enumerate_ == "parallel")) + guide = config_enumerate( + model, default=enumerate_, expand=expand, num_samples=num_samples + ) + elbo = TraceEnum_ELBO( + max_plate_nesting=3, strict_enumeration_warning=(enumerate_ == "parallel") + ) assert_ok(model, guide, elbo) -@pytest.mark.parametrize("Elbo,expand", [ - (Trace_ELBO, False), - (TraceGraph_ELBO, False), - (TraceEnum_ELBO, False), - (TraceEnum_ELBO, True), -]) +@pytest.mark.parametrize( + "Elbo,expand", + [ + (Trace_ELBO, False), + (TraceGraph_ELBO, False), + (TraceEnum_ELBO, False), + (TraceEnum_ELBO, True), + ], +) def test_dim_allocation_ok(Elbo, expand): - enumerate_ = (Elbo is TraceEnum_ELBO) + enumerate_ = Elbo is TraceEnum_ELBO def model(): p = torch.tensor(0.5, requires_grad=True) @@ -1375,18 +1444,23 @@ def model(): assert z.shape == (5, 7, 6) assert q.shape == (8, 5, 7, 6) - guide = config_enumerate(model, "sequential", expand=expand) if enumerate_ else model + guide = ( + config_enumerate(model, "sequential", expand=expand) if enumerate_ else model + ) assert_ok(model, guide, Elbo(max_plate_nesting=4)) -@pytest.mark.parametrize("Elbo,expand", [ - (Trace_ELBO, False), - (TraceGraph_ELBO, False), - (TraceEnum_ELBO, False), - (TraceEnum_ELBO, True), -]) +@pytest.mark.parametrize( + "Elbo,expand", + [ + (Trace_ELBO, False), + (TraceGraph_ELBO, False), + (TraceEnum_ELBO, False), + (TraceEnum_ELBO, True), + ], +) def test_dim_allocation_error(Elbo, expand): - enumerate_ = (Elbo is TraceEnum_ELBO) + enumerate_ = Elbo is TraceEnum_ELBO def model(): p = torch.tensor(0.5, requires_grad=True) @@ -1408,21 +1482,21 @@ def model(): assert y.shape == (5, 6) guide = config_enumerate(model, expand=expand) if Elbo is TraceEnum_ELBO else model - assert_error(model, guide, Elbo(), match='collide at dim=') + assert_error(model, guide, Elbo(), match="collide at dim=") def test_enum_in_model_ok(): - infer = {'enumerate': 'parallel'} + infer = {"enumerate": "parallel"} def model(): - p = pyro.param('p', torch.tensor(0.25)) - a = pyro.sample('a', dist.Bernoulli(p)) - b = pyro.sample('b', dist.Bernoulli(p + a / 2)) - c = pyro.sample('c', dist.Bernoulli(p + b / 2), infer=infer) - d = pyro.sample('d', dist.Bernoulli(p + c / 2)) - e = pyro.sample('e', dist.Bernoulli(p + d / 2)) - f = pyro.sample('f', dist.Bernoulli(p + e / 2), infer=infer) - g = pyro.sample('g', dist.Bernoulli(p + f / 2), obs=torch.tensor(0.)) + p = pyro.param("p", torch.tensor(0.25)) + a = pyro.sample("a", dist.Bernoulli(p)) + b = pyro.sample("b", dist.Bernoulli(p + a / 2)) + c = pyro.sample("c", dist.Bernoulli(p + b / 2), infer=infer) + d = pyro.sample("d", dist.Bernoulli(p + c / 2)) + e = pyro.sample("e", dist.Bernoulli(p + d / 2)) + f = pyro.sample("f", dist.Bernoulli(p + e / 2), infer=infer) + g = pyro.sample("g", dist.Bernoulli(p + f / 2), obs=torch.tensor(0.0)) # check shapes assert a.shape == () @@ -1434,11 +1508,11 @@ def model(): assert g.shape == () def guide(): - p = pyro.param('p', torch.tensor(0.25)) - a = pyro.sample('a', dist.Bernoulli(p)) - b = pyro.sample('b', dist.Bernoulli(p + a / 2), infer=infer) - d = pyro.sample('d', dist.Bernoulli(p + b / 2)) - e = pyro.sample('e', dist.Bernoulli(p + d / 2), infer=infer) + p = pyro.param("p", torch.tensor(0.25)) + a = pyro.sample("a", dist.Bernoulli(p)) + b = pyro.sample("b", dist.Bernoulli(p + a / 2), infer=infer) + d = pyro.sample("d", dist.Bernoulli(p + b / 2)) + e = pyro.sample("e", dist.Bernoulli(p + d / 2), infer=infer) # check 
shapes assert a.shape == () @@ -1450,18 +1524,18 @@ def guide(): def test_enum_in_model_plate_ok(): - infer = {'enumerate': 'parallel'} + infer = {"enumerate": "parallel"} def model(): - p = pyro.param('p', torch.tensor(0.25)) - a = pyro.sample('a', dist.Bernoulli(p)) - b = pyro.sample('b', dist.Bernoulli(p + a / 2)) - with pyro.plate('data', 3): - c = pyro.sample('c', dist.Bernoulli(p + b / 2), infer=infer) - d = pyro.sample('d', dist.Bernoulli(p + c / 2)) - e = pyro.sample('e', dist.Bernoulli(p + d / 2)) - f = pyro.sample('f', dist.Bernoulli(p + e / 2), infer=infer) - g = pyro.sample('g', dist.Bernoulli(p + f / 2), obs=torch.zeros(3)) + p = pyro.param("p", torch.tensor(0.25)) + a = pyro.sample("a", dist.Bernoulli(p)) + b = pyro.sample("b", dist.Bernoulli(p + a / 2)) + with pyro.plate("data", 3): + c = pyro.sample("c", dist.Bernoulli(p + b / 2), infer=infer) + d = pyro.sample("d", dist.Bernoulli(p + c / 2)) + e = pyro.sample("e", dist.Bernoulli(p + d / 2)) + f = pyro.sample("f", dist.Bernoulli(p + e / 2), infer=infer) + g = pyro.sample("g", dist.Bernoulli(p + f / 2), obs=torch.zeros(3)) # check shapes assert a.shape == () @@ -1473,12 +1547,12 @@ def model(): assert g.shape == (3,) def guide(): - p = pyro.param('p', torch.tensor(0.25)) - a = pyro.sample('a', dist.Bernoulli(p)) - b = pyro.sample('b', dist.Bernoulli(p + a / 2), infer=infer) - with pyro.plate('data', 3): - d = pyro.sample('d', dist.Bernoulli(p + b / 2)) - e = pyro.sample('e', dist.Bernoulli(p + d / 2), infer=infer) + p = pyro.param("p", torch.tensor(0.25)) + a = pyro.sample("a", dist.Bernoulli(p)) + b = pyro.sample("b", dist.Bernoulli(p + a / 2), infer=infer) + with pyro.plate("data", 3): + d = pyro.sample("d", dist.Bernoulli(p + b / 2)) + e = pyro.sample("e", dist.Bernoulli(p + d / 2), infer=infer) # check shapes assert a.shape == () @@ -1490,29 +1564,31 @@ def guide(): def test_enum_sequential_in_model_error(): - def model(): - p = pyro.param('p', torch.tensor(0.25)) - pyro.sample('a', dist.Bernoulli(p), infer={'enumerate': 'sequential'}) + p = pyro.param("p", torch.tensor(0.25)) + pyro.sample("a", dist.Bernoulli(p), infer={"enumerate": "sequential"}) def guide(): pass - assert_error(model, guide, TraceEnum_ELBO(max_plate_nesting=0), - match='Found vars in model but not guide') + assert_error( + model, + guide, + TraceEnum_ELBO(max_plate_nesting=0), + match="Found vars in model but not guide", + ) def test_enum_in_model_plate_reuse_ok(): - @config_enumerate def model(): p = pyro.param("p", torch.tensor([0.2, 0.8])) a = pyro.sample("a", dist.Bernoulli(0.3)).long() with pyro.plate("b_axis", 2): - pyro.sample("b", dist.Bernoulli(p[a]), obs=torch.tensor([0., 1.])) + pyro.sample("b", dist.Bernoulli(p[a]), obs=torch.tensor([0.0, 1.0])) c = pyro.sample("c", dist.Bernoulli(0.3)).long() with pyro.plate("c_axis", 2): - pyro.sample("d", dist.Bernoulli(p[c]), obs=torch.tensor([0., 0.])) + pyro.sample("d", dist.Bernoulli(p[c]), obs=torch.tensor([0.0, 0.0])) def guide(): pass @@ -1521,22 +1597,25 @@ def guide(): def test_enum_in_model_multi_scale_error(): - @config_enumerate def model(): p = pyro.param("p", torch.tensor([0.2, 0.8])) x = pyro.sample("x", dist.Bernoulli(0.3)).long() - with poutine.scale(scale=2.): - pyro.sample("y", dist.Bernoulli(p[x]), obs=torch.tensor(0.)) + with poutine.scale(scale=2.0): + pyro.sample("y", dist.Bernoulli(p[x]), obs=torch.tensor(0.0)) def guide(): pass - assert_error(model, guide, TraceEnum_ELBO(max_plate_nesting=0), - match='Expected all enumerated sample sites to share a common poutine.scale') + 
assert_error( + model, + guide, + TraceEnum_ELBO(max_plate_nesting=0), + match="Expected all enumerated sample sites to share a common poutine.scale", + ) -@pytest.mark.parametrize('use_vindex', [False, True]) +@pytest.mark.parametrize("use_vindex", [False, True]) def test_enum_in_model_diamond_error(use_vindex): data = torch.tensor([[0, 1], [0, 0]]) @@ -1545,8 +1624,10 @@ def model(): pyro.param("probs_a", torch.tensor([0.45, 0.55])) pyro.param("probs_b", torch.tensor([[0.6, 0.4], [0.4, 0.6]])) pyro.param("probs_c", torch.tensor([[0.75, 0.25], [0.55, 0.45]])) - pyro.param("probs_d", torch.tensor([[[0.4, 0.6], [0.3, 0.7]], - [[0.3, 0.7], [0.2, 0.8]]])) + pyro.param( + "probs_d", + torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]), + ) probs_a = pyro.param("probs_a") probs_b = pyro.param("probs_b") probs_c = pyro.param("probs_c") @@ -1569,8 +1650,12 @@ def model(): def guide(): pass - assert_error(model, guide, TraceEnum_ELBO(max_plate_nesting=2), - match='Expected tree-structured plate nesting') + assert_error( + model, + guide, + TraceEnum_ELBO(max_plate_nesting=2), + match="Expected tree-structured plate nesting", + ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -1590,23 +1675,33 @@ def guide(): pyro.clear_param_store() guide = config_enumerate(guide) if Elbo is TraceEnum_ELBO else guide - assert_ok(model, guide, Elbo(num_particles=10, - vectorize_particles=True, - max_plate_nesting=2, - strict_enumeration_warning=False)) - - -@pytest.mark.parametrize('enumerate_,expand,num_samples', [ - (None, False, None), - ("sequential", False, None), - ("sequential", True, None), - ("parallel", False, None), - ("parallel", True, None), - ("parallel", True, 3), -]) -@pytest.mark.parametrize('num_particles', [1, 50]) -def test_enum_discrete_vectorized_num_particles(enumerate_, expand, num_samples, num_particles): - + assert_ok( + model, + guide, + Elbo( + num_particles=10, + vectorize_particles=True, + max_plate_nesting=2, + strict_enumeration_warning=False, + ), + ) + + +@pytest.mark.parametrize( + "enumerate_,expand,num_samples", + [ + (None, False, None), + ("sequential", False, None), + ("sequential", True, None), + ("parallel", False, None), + ("parallel", True, None), + ("parallel", True, 3), + ], +) +@pytest.mark.parametrize("num_particles", [1, 50]) +def test_enum_discrete_vectorized_num_particles( + enumerate_, expand, num_samples, num_particles +): @config_enumerate(default=enumerate_, expand=expand, num_samples=num_samples) def model(): x_plate = pyro.plate("x_plate", 10, 5, dim=-1) @@ -1653,11 +1748,19 @@ def model(): else: if enumerate_ == "parallel": if num_samples and expand: - assert b.shape == (num_samples, 1, 5,) + assert b.shape == ( + num_samples, + 1, + 5, + ) assert c.shape == (num_samples, 1, 6, 1) assert d.shape == (num_samples, 1, num_samples, 6, 5) elif num_samples and not expand: - assert b.shape == (num_samples, 1, 5,) + assert b.shape == ( + num_samples, + 1, + 5, + ) assert c.shape == (num_samples, 1, 6, 1) assert d.shape == (num_samples, 1, 1, 6, 5) elif expand: @@ -1682,14 +1785,19 @@ def model(): assert c.shape == (6, 1) assert d.shape == (6, 5) - assert_ok(model, model, TraceEnum_ELBO(max_plate_nesting=2, - num_particles=num_particles, - vectorize_particles=True, - strict_enumeration_warning=(enumerate_ == "parallel"))) + assert_ok( + model, + model, + TraceEnum_ELBO( + max_plate_nesting=2, + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=(enumerate_ == "parallel"), + ), + ) 
def test_enum_recycling_chain(): - @config_enumerate def model(): p = pyro.param("p", torch.tensor([[0.2, 0.8], [0.1, 0.9]])) @@ -1705,8 +1813,8 @@ def guide(): assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0)) -@pytest.mark.parametrize('use_vindex', [False, True]) -@pytest.mark.parametrize('markov', [False, True]) +@pytest.mark.parametrize("use_vindex", [False, True]) +@pytest.mark.parametrize("markov", [False, True]) def test_enum_recycling_dbn(markov, use_vindex): # x --> x --> x enum "state" # y | y | y | enum "occlusion" @@ -1729,8 +1837,9 @@ def model(): else: z_ind = torch.arange(4, dtype=torch.long) probs = r[x.unsqueeze(-1), y.unsqueeze(-1), z_ind] - pyro.sample("z_{}".format(t), dist.Categorical(probs), - obs=torch.tensor(0.)) + pyro.sample( + "z_{}".format(t), dist.Categorical(probs), obs=torch.tensor(0.0) + ) def guide(): pass @@ -1782,7 +1891,7 @@ def guide(): assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0)) -@pytest.mark.parametrize('use_vindex', [False, True]) +@pytest.mark.parametrize("use_vindex", [False, True]) def test_enum_recycling_grid(use_vindex): # x---x---x---x -----> i # | | | | | @@ -1803,10 +1912,8 @@ def model(): probs = Vindex(p)[x[i - 1, j], x[i, j - 1]] else: ind = torch.arange(2, dtype=torch.long) - probs = p[x[i - 1, j].unsqueeze(-1), - x[i, j - 1].unsqueeze(-1), ind] - x[i, j] = pyro.sample("x_{}_{}".format(i, j), - dist.Categorical(probs)) + probs = p[x[i - 1, j].unsqueeze(-1), x[i, j - 1].unsqueeze(-1), ind] + x[i, j] = pyro.sample("x_{}_{}".format(i, j), dist.Categorical(probs)) def guide(): pass @@ -1823,15 +1930,19 @@ def test_enum_recycling_reentrant(): def model(data, state=0, address=""): if isinstance(data, bool): p = pyro.param("p_leaf", torch.ones(10)) - pyro.sample("leaf_{}".format(address), - dist.Bernoulli(p[state]), - obs=torch.tensor(1. if data else 0.)) + pyro.sample( + "leaf_{}".format(address), + dist.Bernoulli(p[state]), + obs=torch.tensor(1.0 if data else 0.0), + ) else: p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model(branch, next_state, address + letter) def guide(data): @@ -1840,7 +1951,7 @@ def guide(data): assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0), data=data) -@pytest.mark.parametrize('history', [1, 2]) +@pytest.mark.parametrize("history", [1, 2]) def test_enum_recycling_reentrant_history(history): data = (True, False) for i in range(5): @@ -1850,16 +1961,20 @@ def test_enum_recycling_reentrant_history(history): def model(data, state=0, address=""): if isinstance(data, bool): p = pyro.param("p_leaf", torch.ones(10)) - pyro.sample("leaf_{}".format(address), - dist.Bernoulli(p[state]), - obs=torch.tensor(1. 
if data else 0.)) + pyro.sample( + "leaf_{}".format(address), + dist.Bernoulli(p[state]), + obs=torch.tensor(1.0 if data else 0.0), + ) else: assert isinstance(data, tuple) p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model(branch, next_state, address + letter) def guide(data): @@ -1875,9 +1990,11 @@ def test_enum_recycling_mutual_recursion(): def model_leaf(data, state=0, address=""): p = pyro.param("p_leaf", torch.ones(10)) - pyro.sample("leaf_{}".format(address), - dist.Bernoulli(p[state]), - obs=torch.tensor(1. if data else 0.)) + pyro.sample( + "leaf_{}".format(address), + dist.Bernoulli(p[state]), + obs=torch.tensor(1.0 if data else 0.0), + ) @pyro.markov def model1(data, state=0, address=""): @@ -1886,9 +2003,11 @@ def model1(data, state=0, address=""): else: p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model2(branch, next_state, address + letter) @pyro.markov @@ -1898,9 +2017,11 @@ def model2(data, state=0, address=""): else: p = pyro.param("p_branch", torch.ones(10, 10)) for branch, letter in zip(data, "abcdefg"): - next_state = pyro.sample("branch_{}".format(address + letter), - dist.Categorical(p[state]), - infer={"enumerate": "parallel"}) + next_state = pyro.sample( + "branch_{}".format(address + letter), + dist.Categorical(p[state]), + infer={"enumerate": "parallel"}, + ) model1(branch, next_state, address + letter) def guide(data): @@ -1910,22 +2031,27 @@ def guide(data): def test_enum_recycling_interleave(): - def model(): with pyro.markov() as m: with pyro.markov(): with m: # error here - pyro.sample("x", dist.Categorical(torch.ones(4)), - infer={"enumerate": "parallel"}) + pyro.sample( + "x", + dist.Categorical(torch.ones(4)), + infer={"enumerate": "parallel"}, + ) def guide(): pass - assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0, strict_enumeration_warning=False)) + assert_ok( + model, + guide, + TraceEnum_ELBO(max_plate_nesting=0, strict_enumeration_warning=False), + ) def test_enum_recycling_plate(): - @config_enumerate def model(): p = pyro.param("p", torch.ones(3, 3)) @@ -1971,16 +2097,18 @@ def guide(): assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=2)) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - TraceTMC_ELBO, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + TraceTMC_ELBO, + ], +) def test_factor_in_model_ok(Elbo): - def model(): - pyro.factor("f", torch.tensor(0.)) + pyro.factor("f", torch.tensor(0.0)) def guide(): pass @@ -1989,27 +2117,28 @@ def guide(): assert_ok(model, guide, elbo) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceGraph_ELBO, - TraceEnum_ELBO, - TraceTMC_ELBO, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + TraceTMC_ELBO, + ], +) def test_factor_in_guide_ok(Elbo): - def model(): pass def guide(): - pyro.factor("f", torch.tensor(0.)) + pyro.factor("f", 
torch.tensor(0.0)) elbo = Elbo(strict_enumeration_warning=False) assert_ok(model, guide, elbo) -@pytest.mark.parametrize('history', [0, 1, 2, 3]) +@pytest.mark.parametrize("history", [0, 1, 2, 3]) def test_markov_history(history): - @config_enumerate def model(): p = pyro.param("p", 0.25 * torch.ones(2, 2)) @@ -2018,259 +2147,309 @@ def model(): x_curr = torch.tensor(0) for t in pyro.markov(range(10), history=history): probs = p[x_prev, x_curr] - x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t), dist.Bernoulli(probs)).long() - pyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), - obs=torch.tensor(0.)) + x_prev, x_curr = ( + x_curr, + pyro.sample("x_{}".format(t), dist.Bernoulli(probs)).long(), + ) + pyro.sample( + "y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=torch.tensor(0.0) + ) def guide(): pass if history < 2: - assert_error(model, guide, TraceEnum_ELBO(max_plate_nesting=0), - match="Enumeration dim conflict") + assert_error( + model, + guide, + TraceEnum_ELBO(max_plate_nesting=0), + match="Enumeration dim conflict", + ) else: assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0)) def test_mean_field_ok(): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) - pyro.sample("y", dist.Normal(x, 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) + pyro.sample("y", dist.Normal(x, 1.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - x = pyro.sample("x", dist.Normal(loc, 1.)) - pyro.sample("y", dist.Normal(x, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + x = pyro.sample("x", dist.Normal(loc, 1.0)) + pyro.sample("y", dist.Normal(x, 1.0)) assert_ok(model, guide, TraceMeanField_ELBO()) -@pytest.mark.parametrize('mask', [True, False]) +@pytest.mark.parametrize("mask", [True, False]) def test_mean_field_mask_ok(mask): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.).mask(mask)) - pyro.sample("y", dist.Normal(x, 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0).mask(mask)) + pyro.sample("y", dist.Normal(x, 1.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - x = pyro.sample("x", dist.Normal(loc, 1.).mask(mask)) - pyro.sample("y", dist.Normal(x, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + x = pyro.sample("x", dist.Normal(loc, 1.0).mask(mask)) + pyro.sample("y", dist.Normal(x, 1.0)) assert_ok(model, guide, TraceMeanField_ELBO()) def test_mean_field_warn(): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) - pyro.sample("y", dist.Normal(x, 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) + pyro.sample("y", dist.Normal(x, 1.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - y = pyro.sample("y", dist.Normal(loc, 1.)) - pyro.sample("x", dist.Normal(y, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + y = pyro.sample("y", dist.Normal(loc, 1.0)) + pyro.sample("x", dist.Normal(y, 1.0)) assert_warning(model, guide, TraceMeanField_ELBO()) def test_tail_adaptive_ok(): - def plateless_model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) def plate_model(): - x = pyro.sample("x", dist.Normal(0., 1.)) - with pyro.plate('observe_data'): - pyro.sample('obs', dist.Normal(x, 1.0), obs=torch.arange(5).type_as(x)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) + with pyro.plate("observe_data"): + pyro.sample("obs", dist.Normal(x, 1.0), obs=torch.arange(5).type_as(x)) def rep_guide(): - pyro.sample("x", dist.Normal(0., 2.)) + pyro.sample("x", dist.Normal(0.0, 2.0)) - assert_ok(plateless_model, rep_guide, TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2)) - 
assert_ok(plate_model, rep_guide, TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2)) + assert_ok( + plateless_model, + rep_guide, + TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2), + ) + assert_ok( + plate_model, + rep_guide, + TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2), + ) def test_tail_adaptive_error(): - def plateless_model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) def rep_guide(): - pyro.sample("x", dist.Normal(0., 2.)) + pyro.sample("x", dist.Normal(0.0, 2.0)) def nonrep_guide(): - pyro.sample("x", fakes.NonreparameterizedNormal(0., 2.)) + pyro.sample("x", fakes.NonreparameterizedNormal(0.0, 2.0)) - assert_error(plateless_model, rep_guide, TraceTailAdaptive_ELBO(vectorize_particles=False, num_particles=2)) - assert_error(plateless_model, nonrep_guide, TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2)) + assert_error( + plateless_model, + rep_guide, + TraceTailAdaptive_ELBO(vectorize_particles=False, num_particles=2), + ) + assert_error( + plateless_model, + nonrep_guide, + TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=2), + ) def test_tail_adaptive_warning(): - def plateless_model(): - pyro.sample("x", dist.Normal(0., 1.)) + pyro.sample("x", dist.Normal(0.0, 1.0)) def rep_guide(): - pyro.sample("x", dist.Normal(0., 2.)) - - assert_warning(plateless_model, rep_guide, TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=1)) - - -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceMeanField_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) + pyro.sample("x", dist.Normal(0.0, 2.0)) + + assert_warning( + plateless_model, + rep_guide, + TraceTailAdaptive_ELBO(vectorize_particles=True, num_particles=1), + ) + + +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceMeanField_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_reparam_ok(Elbo): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) - pyro.sample("y", dist.Normal(x, 1.), obs=torch.tensor(0.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) + pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - pyro.sample("x", dist.Normal(loc, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + pyro.sample("x", dist.Normal(loc, 1.0)) assert_ok(model, guide, Elbo()) @pytest.mark.parametrize("mask", [True, False, torch.tensor(True), torch.tensor(False)]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceMeanField_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceMeanField_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_reparam_mask_ok(Elbo, mask): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with poutine.mask(mask=mask): - pyro.sample("y", dist.Normal(x, 1.), obs=torch.tensor(0.)) + pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - pyro.sample("x", dist.Normal(loc, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + pyro.sample("x", dist.Normal(loc, 1.0)) assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("mask", [ - True, - False, - torch.tensor(True), - torch.tensor(False), - torch.tensor([False, True]), -]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceMeanField_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) 
+@pytest.mark.parametrize( + "mask", + [ + True, + False, + torch.tensor(True), + torch.tensor(False), + torch.tensor([False, True]), + ], +) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceMeanField_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_reparam_mask_plate_ok(Elbo, mask): data = torch.randn(2, 3).exp() data /= data.sum(-1, keepdim=True) def model(): - c = pyro.sample("c", dist.LogNormal(0., 1.).expand([3]).to_event(1)) + c = pyro.sample("c", dist.LogNormal(0.0, 1.0).expand([3]).to_event(1)) with pyro.plate("data", len(data)), poutine.mask(mask=mask): pyro.sample("obs", dist.Dirichlet(c), obs=data) def guide(): loc = pyro.param("loc", torch.zeros(3)) - scale = pyro.param("scale", torch.ones(3), - constraint=constraints.positive) + scale = pyro.param("scale", torch.ones(3), constraint=constraints.positive) pyro.sample("c", dist.LogNormal(loc, scale).to_event(1)) assert_ok(model, guide, Elbo()) @pytest.mark.parametrize("num_particles", [1, 2]) -@pytest.mark.parametrize("mask", [ - torch.tensor(True), - torch.tensor(False), - torch.tensor([True]), - torch.tensor([False]), - torch.tensor([False, True, False]), -]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceEnum_ELBO, - TraceGraph_ELBO, - TraceMeanField_ELBO, -]) +@pytest.mark.parametrize( + "mask", + [ + torch.tensor(True), + torch.tensor(False), + torch.tensor([True]), + torch.tensor([False]), + torch.tensor([False, True, False]), + ], +) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, + ], +) def test_obs_mask_ok(Elbo, mask, num_particles): - data = torch.tensor([7., 7., 7.]) + data = torch.tensor([7.0, 7.0, 7.0]) def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with pyro.plate("plate", len(data)): - y = pyro.sample("y", dist.Normal(x, 1.), - obs=data, obs_mask=mask) + y = pyro.sample("y", dist.Normal(x, 1.0), obs=data, obs_mask=mask) assert ((y == data) == mask).all() def guide(): loc = pyro.param("loc", torch.zeros(())) - scale = pyro.param("scale", torch.ones(()), - constraint=constraints.positive) + scale = pyro.param("scale", torch.ones(()), constraint=constraints.positive) x = pyro.sample("x", dist.Normal(loc, scale)) with pyro.plate("plate", len(data)): with poutine.mask(mask=~mask): - pyro.sample("y_unobserved", dist.Normal(x, 1.)) + pyro.sample("y_unobserved", dist.Normal(x, 1.0)) - elbo = Elbo(num_particles=num_particles, vectorize_particles=True, - strict_enumeration_warning=False) + elbo = Elbo( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=False, + ) assert_ok(model, guide, elbo) @pytest.mark.parametrize("num_particles", [1, 2]) -@pytest.mark.parametrize("mask", [ - torch.tensor(True), - torch.tensor(False), - torch.tensor([True]), - torch.tensor([False]), - torch.tensor([False, True, True, False]), -]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceEnum_ELBO, - TraceGraph_ELBO, - TraceMeanField_ELBO, -]) +@pytest.mark.parametrize( + "mask", + [ + torch.tensor(True), + torch.tensor(False), + torch.tensor([True]), + torch.tensor([False]), + torch.tensor([False, True, True, False]), + ], +) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, + ], +) def test_obs_mask_multivariate_ok(Elbo, mask, num_particles): data = torch.full((4, 3), 7.0) def model(): x = pyro.sample("x", dist.MultivariateNormal(torch.zeros(3), torch.eye(3))) with 
pyro.plate("plate", len(data)): - y = pyro.sample("y", dist.MultivariateNormal(x, torch.eye(3)), - obs=data, obs_mask=mask) + y = pyro.sample( + "y", dist.MultivariateNormal(x, torch.eye(3)), obs=data, obs_mask=mask + ) assert ((y == data).all(-1) == mask).all() def guide(): loc = pyro.param("loc", torch.zeros(3)) - cov = pyro.param("cov", torch.eye(3), - constraint=constraints.positive_definite) + cov = pyro.param("cov", torch.eye(3), constraint=constraints.positive_definite) x = pyro.sample("x", dist.MultivariateNormal(loc, cov)) with pyro.plate("plate", len(data)): with poutine.mask(mask=~mask): pyro.sample("y_unobserved", dist.MultivariateNormal(x, torch.eye(3))) - elbo = Elbo(num_particles=num_particles, vectorize_particles=True, - strict_enumeration_warning=False) + elbo = Elbo( + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=False, + ) assert_ok(model, guide, elbo) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceEnum_ELBO, - TraceGraph_ELBO, - TraceMeanField_ELBO, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, + ], +) def test_obs_mask_multivariate_error(Elbo): data = torch.full((3, 2), 7.0) # This mask is invalid because it includes event shape. @@ -2279,8 +2458,9 @@ def test_obs_mask_multivariate_error(Elbo): def model(): x = pyro.sample("x", dist.MultivariateNormal(torch.zeros(2), torch.eye(2))) with pyro.plate("plate", len(data)): - pyro.sample("y", dist.MultivariateNormal(x, torch.eye(2)), - obs=data, obs_mask=mask) + pyro.sample( + "y", dist.MultivariateNormal(x, torch.eye(2)), obs=data, obs_mask=mask + ) def guide(): loc = pyro.param("loc", torch.zeros(2)) @@ -2294,73 +2474,82 @@ def guide(): @pytest.mark.parametrize("scale", [1, 0.1, torch.tensor(0.5)]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceMeanField_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceMeanField_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_reparam_scale_ok(Elbo, scale): - def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with poutine.scale(scale=scale): - pyro.sample("y", dist.Normal(x, 1.), obs=torch.tensor(0.)) + pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.0)) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - pyro.sample("x", dist.Normal(loc, 1.)) + loc = pyro.param("loc", torch.tensor(0.0)) + pyro.sample("x", dist.Normal(loc, 1.0)) assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("scale", [ - 1, - 0.1, - torch.tensor(0.5), - torch.tensor([0.1, 0.9]), -]) -@pytest.mark.parametrize("Elbo", [ - Trace_ELBO, - TraceMeanField_ELBO, - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "scale", + [ + 1, + 0.1, + torch.tensor(0.5), + torch.tensor([0.1, 0.9]), + ], +) +@pytest.mark.parametrize( + "Elbo", + [ + Trace_ELBO, + TraceMeanField_ELBO, + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_reparam_scale_plate_ok(Elbo, scale): data = torch.randn(2, 3).exp() data /= data.sum(-1, keepdim=True) def model(): - c = pyro.sample("c", dist.LogNormal(0., 1.).expand([3]).to_event(1)) + c = pyro.sample("c", dist.LogNormal(0.0, 1.0).expand([3]).to_event(1)) with pyro.plate("data", len(data)), poutine.scale(scale=scale): pyro.sample("obs", dist.Dirichlet(c), obs=data) def guide(): loc = pyro.param("loc", torch.zeros(3)) - scale = pyro.param("scale", torch.ones(3), 
- constraint=constraints.positive) + scale = pyro.param("scale", torch.ones(3), constraint=constraints.positive) pyro.sample("c", dist.LogNormal(loc, scale).to_event(1)) assert_ok(model, guide, Elbo()) -@pytest.mark.parametrize("Elbo", [ - EnergyDistance_prior, - EnergyDistance_noprior, -]) +@pytest.mark.parametrize( + "Elbo", + [ + EnergyDistance_prior, + EnergyDistance_noprior, + ], +) def test_no_log_prob_ok(Elbo): - def model(data): loc = pyro.sample("loc", dist.Normal(0, 1)) scale = pyro.sample("scale", dist.LogNormal(0, 1)) with pyro.plate("data", len(data)): - pyro.sample("obs", dist.Stable(1.5, 0.5, scale, loc), - obs=data) + pyro.sample("obs", dist.Stable(1.5, 0.5, scale, loc), obs=data) def guide(data): - map_loc = pyro.param("map_loc", torch.tensor(0.)) - map_scale = pyro.param("map_scale", torch.tensor(1.), - constraint=constraints.positive) + map_loc = pyro.param("map_loc", torch.tensor(0.0)) + map_scale = pyro.param( + "map_scale", torch.tensor(1.0), constraint=constraints.positive + ) pyro.sample("loc", dist.Delta(map_loc)) pyro.sample("scale", dist.Delta(map_scale)) @@ -2369,19 +2558,18 @@ def guide(data): def test_reparam_stable(): - @poutine.reparam(config={"z": LatentStableReparam()}) def model(): - stability = pyro.sample("stability", dist.Uniform(0., 2.)) - skew = pyro.sample("skew", dist.Uniform(-1., 1.)) + stability = pyro.sample("stability", dist.Uniform(0.0, 2.0)) + skew = pyro.sample("skew", dist.Uniform(-1.0, 1.0)) y = pyro.sample("z", dist.Stable(stability, skew)) - pyro.sample("x", dist.Poisson(y.abs()), obs=torch.tensor(1.)) + pyro.sample("x", dist.Poisson(y.abs()), obs=torch.tensor(1.0)) def guide(): pyro.sample("stability", dist.Delta(torch.tensor(1.5))) - pyro.sample("skew", dist.Delta(torch.tensor(0.))) + pyro.sample("skew", dist.Delta(torch.tensor(0.0))) pyro.sample("z_uniform", dist.Delta(torch.tensor(0.1))) - pyro.sample("z_exponential", dist.Delta(torch.tensor(1.))) + pyro.sample("z_exponential", dist.Delta(torch.tensor(1.0))) assert_ok(model, guide, Trace_ELBO()) @@ -2390,18 +2578,17 @@ def guide(): @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_normal_normal(num_particles): pytest.importorskip("funsor") - data = torch.tensor(0.) 
+ data = torch.tensor(0.0) def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with poutine.collapse(): - y = pyro.sample("y", dist.Normal(x, 1.)) - pyro.sample("z", dist.Normal(y, 1.), obs=data) + y = pyro.sample("y", dist.Normal(x, 1.0)) + pyro.sample("z", dist.Normal(y, 1.0), obs=data) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - scale = pyro.param("scale", torch.tensor(1.), - constraint=constraints.positive) + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("x", dist.Normal(loc, scale)) elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True) @@ -2415,20 +2602,20 @@ def test_collapse_normal_normal_plate(num_particles): data = torch.randn(5) def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with poutine.collapse(): - y = pyro.sample("y", dist.Normal(x, 1.)) + y = pyro.sample("y", dist.Normal(x, 1.0)) with pyro.plate("data", len(data), dim=-1): - pyro.sample("z", dist.Normal(y, 1.), obs=data) + pyro.sample("z", dist.Normal(y, 1.0), obs=data) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - scale = pyro.param("scale", torch.tensor(1.), - constraint=constraints.positive) + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("x", dist.Normal(loc, scale)) - elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True, - max_plate_nesting=1) + elbo = Trace_ELBO( + num_particles=num_particles, vectorize_particles=True, max_plate_nesting=1 + ) assert_ok(model, guide, elbo) @@ -2439,20 +2626,20 @@ def test_collapse_normal_plate_normal(num_particles): data = torch.randn(5) def model(): - x = pyro.sample("x", dist.Normal(0., 1.)) + x = pyro.sample("x", dist.Normal(0.0, 1.0)) with poutine.collapse(): with pyro.plate("data", len(data), dim=-1): - y = pyro.sample("y", dist.Normal(x, 1.)) - pyro.sample("z", dist.Normal(y, 1.), obs=data) + y = pyro.sample("y", dist.Normal(x, 1.0)) + pyro.sample("z", dist.Normal(y, 1.0), obs=data) def guide(): - loc = pyro.param("loc", torch.tensor(0.)) - scale = pyro.param("scale", torch.tensor(1.), - constraint=constraints.positive) + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("x", dist.Normal(loc, scale)) - elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True, - max_plate_nesting=1) + elbo = Trace_ELBO( + num_particles=num_particles, vectorize_particles=True, max_plate_nesting=1 + ) assert_ok(model, guide, elbo) @@ -2461,7 +2648,7 @@ def guide(): @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_beta_bernoulli(num_particles): pytest.importorskip("funsor") - data = torch.tensor(0.) 
+ data = torch.tensor(0.0) def model(): c = pyro.sample("c", dist.Gamma(1, 1)) @@ -2470,8 +2657,8 @@ def model(): pyro.sample("obs", dist.Bernoulli(probs), obs=data) def guide(): - a = pyro.param("a", torch.tensor(1.), constraint=constraints.positive) - b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive) + a = pyro.param("a", torch.tensor(1.0), constraint=constraints.positive) + b = pyro.param("b", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("c", dist.Gamma(a, b)) elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True) @@ -2483,7 +2670,7 @@ def guide(): @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_beta_binomial(num_particles): pytest.importorskip("funsor") - data = torch.tensor(5.) + data = torch.tensor(5.0) def model(): c = pyro.sample("c", dist.Gamma(1, 1)) @@ -2492,8 +2679,8 @@ def model(): pyro.sample("obs", dist.Binomial(10, probs), obs=data) def guide(): - a = pyro.param("a", torch.tensor(1.), constraint=constraints.positive) - b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive) + a = pyro.param("a", torch.tensor(1.0), constraint=constraints.positive) + b = pyro.param("b", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("c", dist.Gamma(a, b)) elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True) @@ -2505,23 +2692,23 @@ def guide(): @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_beta_binomial_plate(num_particles): pytest.importorskip("funsor") - data = torch.tensor([0., 1., 5., 5.]) + data = torch.tensor([0.0, 1.0, 5.0, 5.0]) def model(): c = pyro.sample("c", dist.Gamma(1, 1)) with poutine.collapse(): probs = pyro.sample("probs", dist.Beta(c, 2)) with pyro.plate("plate", len(data)): - pyro.sample("obs", dist.Binomial(10, probs), - obs=data) + pyro.sample("obs", dist.Binomial(10, probs), obs=data) def guide(): - a = pyro.param("a", torch.tensor(1.), constraint=constraints.positive) - b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive) + a = pyro.param("a", torch.tensor(1.0), constraint=constraints.positive) + b = pyro.param("b", torch.tensor(1.0), constraint=constraints.positive) pyro.sample("c", dist.Gamma(a, b)) - elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True, - max_plate_nesting=1) + elbo = Trace_ELBO( + num_particles=num_particles, vectorize_particles=True, max_plate_nesting=1 + ) assert_ok(model, guide, elbo) @@ -2529,7 +2716,7 @@ def guide(): @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_barrier(num_particles): pytest.importorskip("funsor") - data = torch.tensor([0., 1., 5., 5.]) + data = torch.tensor([0.0, 1.0, 5.0, 5.0]) def model(): with poutine.collapse(): @@ -2551,11 +2738,15 @@ def guide(): def test_ordered_logistic_plate(): N = 5 # num data points/batch size K = 4 # num categories - data = (K*torch.rand(N)).long().float() + data = (K * torch.rand(N)).long().float() def model(): - predictor = pyro.sample("predictor", dist.Normal(0., 1.).expand([N]).to_event(1)) - cutpoints = pyro.sample("cutpoints", dist.Normal(0., 1.).expand([K-1]).to_event(1)) + predictor = pyro.sample( + "predictor", dist.Normal(0.0, 1.0).expand([N]).to_event(1) + ) + cutpoints = pyro.sample( + "cutpoints", dist.Normal(0.0, 1.0).expand([K - 1]).to_event(1) + ) # would have identifiability issues, but this isn't a real model... 
cutpoints = torch.sort(cutpoints, dim=-1).values with pyro.plate("obs_plate", N): @@ -2565,8 +2756,8 @@ def guide(): # parameters pred_mu = pyro.param("pred_mu", torch.zeros(N)) pred_std = pyro.param("pred_std", torch.ones(N)) - cp_mu = pyro.param("cp_mu", torch.zeros(K-1)) - cp_std = pyro.param("cp_std", torch.ones(K-1)) + cp_mu = pyro.param("cp_mu", torch.zeros(K - 1)) + cp_std = pyro.param("cp_std", torch.ones(K - 1)) # sample pyro.sample("predictor", dist.Normal(pred_mu, pred_std).to_event(1)) pyro.sample("cutpoints", dist.Normal(cp_mu, cp_std).to_event(1)) diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py index a14bb86089..d57659b25d 100644 --- a/tests/integration_tests/test_conjugate_gaussian_models.py +++ b/tests/integration_tests/test_conjugate_gaussian_models.py @@ -40,23 +40,35 @@ def setup_chain(self, N): self.lambdas = list(map(lambda x: torch.tensor([x]), lambdas)) self.lambda_tilde_posts = [self.lambdas[0]] for k in range(1, self.N): - lambda_tilde_k = (self.lambdas[k] * self.lambda_tilde_posts[k - 1]) /\ - (self.lambdas[k] + self.lambda_tilde_posts[k - 1]) + lambda_tilde_k = (self.lambdas[k] * self.lambda_tilde_posts[k - 1]) / ( + self.lambdas[k] + self.lambda_tilde_posts[k - 1] + ) self.lambda_tilde_posts.append(lambda_tilde_k) - self.lambda_posts = [None] # this is never used (just a way of shifting the indexing by 1) + self.lambda_posts = [ + None + ] # this is never used (just a way of shifting the indexing by 1) for k in range(1, self.N): lambda_k = self.lambdas[k] + self.lambda_tilde_posts[k - 1] self.lambda_posts.append(lambda_k) - lambda_N_post = (self.n_data * torch.tensor(1.0).expand_as(self.lambdas[N]) * self.lambdas[N]) +\ - self.lambda_tilde_posts[N - 1] + lambda_N_post = ( + self.n_data * torch.tensor(1.0).expand_as(self.lambdas[N]) * self.lambdas[N] + ) + self.lambda_tilde_posts[N - 1] self.lambda_posts.append(lambda_N_post) self.target_kappas = [None] - self.target_kappas.extend([self.lambdas[k] / self.lambda_posts[k] for k in range(1, self.N)]) + self.target_kappas.extend( + [self.lambdas[k] / self.lambda_posts[k] for k in range(1, self.N)] + ) self.target_mus = [None] - self.target_mus.extend([self.loc0 * self.lambda_tilde_posts[k - 1] / self.lambda_posts[k] - for k in range(1, self.N)]) - target_loc_N = self.sum_data * self.lambdas[N] / lambda_N_post +\ - self.loc0 * self.lambda_tilde_posts[N - 1] / lambda_N_post + self.target_mus.extend( + [ + self.loc0 * self.lambda_tilde_posts[k - 1] / self.lambda_posts[k] + for k in range(1, self.N) + ] + ) + target_loc_N = ( + self.sum_data * self.lambdas[N] / lambda_N_post + + self.loc0 * self.lambda_tilde_posts[N - 1] / lambda_N_post + ) self.target_mus.append(target_loc_N) self.which_nodes_reparam = self.setup_reparam_mask(N) @@ -76,26 +88,46 @@ def model(self, reparameterized, difficulty=0.0): loc_N = next_mean with pyro.plate("data", self.data.size(0)): - pyro.sample("obs", dist.Normal(loc_N, - torch.pow(self.lambdas[self.N], -0.5)), obs=self.data) + pyro.sample( + "obs", + dist.Normal(loc_N, torch.pow(self.lambdas[self.N], -0.5)), + obs=self.data, + ) return loc_N def guide(self, reparameterized, difficulty=0.0): previous_sample = None for k in reversed(range(1, self.N + 1)): - loc_q = pyro.param("loc_q_%d" % k, self.target_mus[k].detach() + difficulty * (0.1 * torch.randn(1) - 0.53)) - log_sig_q = pyro.param("log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k]).data + - difficulty * (0.1 * torch.randn(1) - 0.53)) + loc_q = 
pyro.param( + "loc_q_%d" % k, + self.target_mus[k].detach() + + difficulty * (0.1 * torch.randn(1) - 0.53), + ) + log_sig_q = pyro.param( + "log_sig_q_%d" % k, + -0.5 * torch.log(self.lambda_posts[k]).data + + difficulty * (0.1 * torch.randn(1) - 0.53), + ) sig_q = torch.exp(log_sig_q) kappa_q = None if k != self.N: - kappa_q = pyro.param("kappa_q_%d" % k, self.target_kappas[k].data + - difficulty * (0.1 * torch.randn(1) - 0.53)) + kappa_q = pyro.param( + "kappa_q_%d" % k, + self.target_kappas[k].data + + difficulty * (0.1 * torch.randn(1) - 0.53), + ) mean_function = loc_q if k == self.N else kappa_q * previous_sample + loc_q node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False - Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal - loc_latent = pyro.sample("loc_latent_%d" % k, Normal(mean_function, sig_q), - infer=dict(baseline=dict(use_decaying_avg_baseline=True))) + Normal = ( + dist.Normal + if reparameterized or node_flagged + else fakes.NonreparameterizedNormal + ) + loc_latent = pyro.sample( + "loc_latent_%d" % k, + Normal(mean_function, sig_q), + infer=dict(baseline=dict(use_decaying_avg_baseline=True)), + ) previous_sample = loc_latent return previous_sample @@ -103,7 +135,6 @@ def guide(self, reparameterized, difficulty=0.0): @pytest.mark.stage("integration", "integration_batch_1") @pytest.mark.init(rng_seed=0) class GaussianChainTests(GaussianChain): - def test_elbo_reparameterized_N_is_3(self): self.setup_chain(3) self.do_elbo_test(True, 1100, 0.0058, 0.03, difficulty=1.0) @@ -112,8 +143,10 @@ def test_elbo_reparameterized_N_is_8(self): self.setup_chain(8) self.do_elbo_test(True, 1100, 0.0059, 0.03, difficulty=1.0) - @pytest.mark.skipif("CI" in os.environ and os.environ["CI"] == "true", - reason="Skip slow test in travis.") + @pytest.mark.skipif( + "CI" in os.environ and os.environ["CI"] == "true", + reason="Skip slow test in travis.", + ) def test_elbo_reparameterized_N_is_17(self): self.setup_chain(17) self.do_elbo_test(True, 2700, 0.0044, 0.03, difficulty=1.0) @@ -126,17 +159,24 @@ def test_elbo_nonreparameterized_N_is_5(self): self.setup_chain(5) self.do_elbo_test(False, 1000, 0.0061, 0.06, difficulty=0.6) - @pytest.mark.skipif("CI" in os.environ and os.environ["CI"] == "true", - reason="Skip slow test in travis.") + @pytest.mark.skipif( + "CI" in os.environ and os.environ["CI"] == "true", + reason="Skip slow test in travis.", + ) def test_elbo_nonreparameterized_N_is_7(self): self.setup_chain(7) self.do_elbo_test(False, 1800, 0.0035, 0.05, difficulty=0.6) def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0): - n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized else self.N - logger.info(" - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST [reparameterized = %s; %d/%d] - - - - - " % - (self.N, reparameterized, n_repa_nodes, self.N)) + n_repa_nodes = ( + torch.sum(self.which_nodes_reparam) if not reparameterized else self.N + ) + logger.info( + " - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST [reparameterized = %s; %d/%d] - - - - - " + % (self.N, reparameterized, n_repa_nodes, self.N) + ) if self.N < 0: + def array_to_string(y): return str(map(lambda x: "%.3f" % x.detach().cpu().numpy()[0], y)) @@ -144,14 +184,18 @@ def array_to_string(y): logger.debug("target_mus: " + array_to_string(self.target_mus[1:])) logger.debug("target_kappas: " + array_to_string(self.target_kappas[1:])) logger.debug("lambda_posts: " + array_to_string(self.lambda_posts[1:])) - logger.debug("lambda_tilde_posts: " 
+ array_to_string(self.lambda_tilde_posts)) + logger.debug( + "lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts) + ) pyro.clear_param_store() adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)}) elbo = TraceGraph_ELBO() loss_and_grads = elbo.loss_and_grads # loss_and_grads = elbo.jit_loss_and_grads # This fails. - svi = SVI(self.model, self.guide, adam, loss=elbo.loss, loss_and_grads=loss_and_grads) + svi = SVI( + self.model, self.guide, adam, loss=elbo.loss, loss_and_grads=loss_and_grads + ) for step in range(n_steps): t0 = time.time() @@ -165,16 +209,42 @@ def array_to_string(y): kappa_errors.append(kappa_error) loc_errors.append(param_mse("loc_q_%d" % k, self.target_mus[k])) - log_sig_error = param_mse("log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k])) + log_sig_error = param_mse( + "log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k]) + ) log_sig_errors.append(log_sig_error) - max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors)) - min_errors = (np.min(loc_errors), np.min(log_sig_errors), np.min(kappa_errors)) - mean_errors = (np.mean(loc_errors), np.mean(log_sig_errors), np.mean(kappa_errors)) - logger.debug("[max errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % max_errors) - logger.debug("[min errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % min_errors) - logger.debug("[mean errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % mean_errors) - logger.debug("[step time = %.3f; N = %d; step = %d]\n" % (time.time() - t0, self.N, step)) + max_errors = ( + np.max(loc_errors), + np.max(log_sig_errors), + np.max(kappa_errors), + ) + min_errors = ( + np.min(loc_errors), + np.min(log_sig_errors), + np.min(kappa_errors), + ) + mean_errors = ( + np.mean(loc_errors), + np.mean(log_sig_errors), + np.mean(kappa_errors), + ) + logger.debug( + "[max errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" + % max_errors + ) + logger.debug( + "[min errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" + % min_errors + ) + logger.debug( + "[mean errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" + % mean_errors + ) + logger.debug( + "[step time = %.3f; N = %d; step = %d]\n" + % (time.time() - t0, self.N, step) + ) assert_equal(0.0, max_errors[0], prec=prec) assert_equal(0.0, max_errors[1], prec=prec) @@ -184,13 +254,12 @@ def array_to_string(y): @pytest.mark.stage("integration", "integration_batch_2") @pytest.mark.init(rng_seed=0) class GaussianPyramidTests(TestCase): - def setUp(self): self.loc0 = torch.tensor([0.52]) def setup_pyramid(self, N): # pyramid of normals with known covariances and latent means - assert(N > 1) + assert N > 1 self.N = N # number of layers in the pyramid lambdas = [1.1 * (k + 1) / N for k in range(N + 2)] self.lambdas = list(map(lambda x: torch.tensor([x]), lambdas)) @@ -201,8 +270,10 @@ def setup_pyramid(self, N): for i in range(bottom_layer_size): data_i = [] for k in range(self.N_data): - data_i.append(torch.tensor([0.25]) + - (0.1 + 0.4 * (i + 1) / bottom_layer_size) * torch.randn(1)) + data_i.append( + torch.tensor([0.25]) + + (0.1 + 0.4 * (i + 1) / bottom_layer_size) * torch.randn(1) + ) self.data.append(data_i) self.data_sums = [sum(self.data[i]) for i in range(bottom_layer_size)] self.N_data = torch.tensor([float(self.N_data)]) @@ -238,33 +309,45 @@ def set_model_permutations(self): def test_elbo_reparameterized_three_layers(self): self.setup_pyramid(3) - self.do_elbo_test(True, 1700, 0.01, 0.04, 0.92, - difficulty=0.8, model_permutation=False) + self.do_elbo_test( + True, 1700, 0.01, 0.04, 0.92, 
difficulty=0.8, model_permutation=False + ) @pytest.mark.skipif("CI" in os.environ, reason="slow test") def test_elbo_reparameterized_four_layers(self): self.setup_pyramid(4) - self.do_elbo_test(True, 20000, 0.0015, 0.04, 0.92, - difficulty=0.8, model_permutation=False) + self.do_elbo_test( + True, 20000, 0.0015, 0.04, 0.92, difficulty=0.8, model_permutation=False + ) @pytest.mark.stage("integration", "integration_batch_1") def test_elbo_nonreparameterized_two_layers(self): self.setup_pyramid(2) - self.do_elbo_test(False, 500, 0.012, 0.04, 0.95, difficulty=0.5, model_permutation=False) + self.do_elbo_test( + False, 500, 0.012, 0.04, 0.95, difficulty=0.5, model_permutation=False + ) def test_elbo_nonreparameterized_three_layers(self): self.setup_pyramid(3) - self.do_elbo_test(False, 9100, 0.00506, 0.04, 0.95, difficulty=0.5, model_permutation=False) + self.do_elbo_test( + False, 9100, 0.00506, 0.04, 0.95, difficulty=0.5, model_permutation=False + ) def test_elbo_nonreparameterized_two_layers_model_permuted(self): self.setup_pyramid(2) - self.do_elbo_test(False, 700, 0.018, 0.05, 0.96, difficulty=0.5, model_permutation=True) - - @pytest.mark.skipif("CI" in os.environ and os.environ["CI"] == "true", - reason="Skip slow test in travis.") + self.do_elbo_test( + False, 700, 0.018, 0.05, 0.96, difficulty=0.5, model_permutation=True + ) + + @pytest.mark.skipif( + "CI" in os.environ and os.environ["CI"] == "true", + reason="Skip slow test in travis.", + ) def test_elbo_nonreparameterized_three_layers_model_permuted(self): self.setup_pyramid(3) - self.do_elbo_test(False, 6500, 0.0071, 0.05, 0.96, difficulty=0.4, model_permutation=True) + self.do_elbo_test( + False, 6500, 0.0071, 0.05, 0.96, difficulty=0.4, model_permutation=True + ) def calculate_variational_targets(self): # calculate (some of the) variational parameters corresponding to exact posterior @@ -283,7 +366,7 @@ def calc_lambda_C(lA, lB, lC): for n in range(2, self.N + 1): new_names = [] for prev_name in previous_names: - for LR in ['L', 'R']: + for LR in ["L", "R"]: new_names.append(prev_name + LR) self.target_lambdas[new_names[-1]] = self.lambdas[n - 1] previous_names = new_names @@ -295,7 +378,7 @@ def calc_lambda_C(lA, lB, lC): new_names = [] for prev_name in previous_names: BC_names = [] - for LR in ['L', 'R']: + for LR in ["L", "R"]: new_names.append(prev_name + LR) BC_names.append(new_names[-1]) lambda_A0 = self.target_lambdas[prev_name] @@ -317,15 +400,24 @@ def calc_lambda_C(lA, lB, lC): leftmost_node_suffix = self.q_topo_sort[0][11:] leftmost_lambda = self.target_lambdas[leftmost_node_suffix] - self.target_leftmost_constant = self.data_sums[0] * self.lambdas[-1] / leftmost_lambda - self.target_leftmost_constant += self.loc0 * (leftmost_lambda - self.N_data * self.lambdas[-1]) /\ - leftmost_lambda - - almost_leftmost_node_suffix = leftmost_node_suffix[:-1] + 'R' + self.target_leftmost_constant = ( + self.data_sums[0] * self.lambdas[-1] / leftmost_lambda + ) + self.target_leftmost_constant += ( + self.loc0 + * (leftmost_lambda - self.N_data * self.lambdas[-1]) + / leftmost_lambda + ) + + almost_leftmost_node_suffix = leftmost_node_suffix[:-1] + "R" almost_leftmost_lambda = self.target_lambdas[almost_leftmost_node_suffix] result = self.lambdas[-1] * self.data_sums[1] - result += (almost_leftmost_lambda - self.N_data * self.lambdas[-1]) \ - * self.loc0 * old_left_pivot_lambda / (old_left_pivot_lambda + self.lambdas[-2]) + result += ( + (almost_leftmost_lambda - self.N_data * self.lambdas[-1]) + * self.loc0 + * old_left_pivot_lambda 
+ / (old_left_pivot_lambda + self.lambdas[-2]) + ) self.target_almost_leftmost_constant = result / almost_leftmost_lambda # construct dependency structure for the guide @@ -337,14 +429,14 @@ def add_edge(s): if s == "1": deps.extend(["1L", "1R"]) else: - if s[-1] == 'R': - deps.append(s[0:-1] + 'L') + if s[-1] == "R": + deps.append(s[0:-1] + "L") if len(s) < self.N: - deps.extend([s + 'L', s + 'R']) + deps.extend([s + "L", s + "R"]) for k in range(len(s) - 2): - base = s[1:-1 - k] - if base[-1] == 'R': - deps.append('1' + base[:-1] + 'L') + base = s[1 : -1 - k] + if base[-1] == "R": + deps.append("1" + base[:-1] + "L") for dep in deps: g.add_edge("loc_latent_" + dep, "loc_latent_" + s) @@ -353,7 +445,7 @@ def add_edge(s): for n in range(2, self.N + 1): new_names = [] for prev_name in previous_names: - for LR in ['L', 'R']: + for LR in ["L", "R"]: new_name = prev_name + LR new_names.append(new_name) add_edge(new_name) @@ -381,9 +473,11 @@ def unpermute(x, n): for n in range(2, self.N + 1): new_latents_and_names = [] for prev_latent, prev_name in permute(previous_latents_and_names, n - 1): - latent_dist = dist.Normal(prev_latent, torch.pow(self.lambdas[n - 1], -0.5)) + latent_dist = dist.Normal( + prev_latent, torch.pow(self.lambdas[n - 1], -0.5) + ) couple = [] - for LR in ['L', 'R']: + for LR in ["L", "R"]: new_name = prev_name + LR loc_latent_LR = pyro.sample(new_name, latent_dist) couple.append([loc_latent_LR, new_name]) @@ -396,9 +490,14 @@ def unpermute(x, n): for i, data_i in enumerate(self.data): for k, x in enumerate(data_i): - pyro.sample("obs_%s_%d" % (previous_latents_and_names[i][1], k), - dist.Normal(previous_latents_and_names[i][0], torch.pow(self.lambdas[-1], -0.5)), - obs=x) + pyro.sample( + "obs_%s_%d" % (previous_latents_and_names[i][1], k), + dist.Normal( + previous_latents_and_names[i][0], + torch.pow(self.lambdas[-1], -0.5), + ), + obs=x, + ) return top_latent def guide(self, reparameterized, model_permutation, difficulty=0.0): @@ -408,46 +507,101 @@ def guide(self, reparameterized, model_permutation, difficulty=0.0): for i, node in enumerate(self.q_topo_sort): deps = self.q_dag.predecessors(node) node_suffix = node[11:] - log_sig_node = pyro.param("log_sig_" + node_suffix, - -0.5 * torch.log(self.target_lambdas[node_suffix]) + - difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2))) - mean_function_node = pyro.param("constant_term_" + node, - self.loc0 + torch.Tensor([difficulty * i / n_nodes])) + log_sig_node = pyro.param( + "log_sig_" + node_suffix, + -0.5 * torch.log(self.target_lambdas[node_suffix]) + + difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2)), + ) + mean_function_node = pyro.param( + "constant_term_" + node, + self.loc0 + torch.Tensor([difficulty * i / n_nodes]), + ) for dep in deps: - kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[11:], - torch.tensor([0.5 + difficulty * i / n_nodes])) + kappa_dep = pyro.param( + "kappa_" + node_suffix + "_" + dep[11:], + torch.tensor([0.5 + difficulty * i / n_nodes]), + ) mean_function_node = mean_function_node + kappa_dep * latents_dict[dep] node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False - Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal - latent_node = pyro.sample(node, Normal(mean_function_node, torch.exp(log_sig_node)), - infer=dict(baseline=dict(use_decaying_avg_baseline=True, - baseline_beta=0.96))) + Normal = ( + dist.Normal + if reparameterized or node_flagged + else fakes.NonreparameterizedNormal + ) + 
latent_node = pyro.sample( + node, + Normal(mean_function_node, torch.exp(log_sig_node)), + infer=dict( + baseline=dict(use_decaying_avg_baseline=True, baseline_beta=0.96) + ), + ) latents_dict[node] = latent_node - return latents_dict['loc_latent_1'] - - def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1, - difficulty=1.0, model_permutation=False): - n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \ + return latents_dict["loc_latent_1"] + + def do_elbo_test( + self, + reparameterized, + n_steps, + lr, + prec, + beta1, + difficulty=1.0, + model_permutation=False, + ): + n_repa_nodes = ( + torch.sum(self.which_nodes_reparam) + if not reparameterized else len(self.q_topo_sort) - logger.info((" - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " + - "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -") % - (self.N, (2 ** self.N) - 1, reparameterized, n_repa_nodes, - len(self.q_topo_sort), model_permutation)) + ) + logger.info( + ( + " - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " + + "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -" + ) + % ( + self.N, + (2 ** self.N) - 1, + reparameterized, + n_repa_nodes, + len(self.q_topo_sort), + model_permutation, + ) + ) pyro.clear_param_store() # check graph structure is as expected but only for N=2 if self.N == 2: - guide_trace = pyro.poutine.trace(self.guide, - graph_type="dense").get_trace(reparameterized=reparameterized, - model_permutation=model_permutation, - difficulty=difficulty) - expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_loc_latent_1R', '_RETURN', - 'loc_latent_1R', 'loc_latent_1', 'constant_term_loc_latent_1', 'loc_latent_1L', - 'constant_term_loc_latent_1L', 'log_sig_1L', 'kappa_1_1R', 'kappa_1R_1L', - 'log_sig_1']) - expected_edges = set([('loc_latent_1R', 'loc_latent_1'), ('loc_latent_1L', 'loc_latent_1R'), - ('loc_latent_1L', 'loc_latent_1')]) + guide_trace = pyro.poutine.trace(self.guide, graph_type="dense").get_trace( + reparameterized=reparameterized, + model_permutation=model_permutation, + difficulty=difficulty, + ) + expected_nodes = set( + [ + "log_sig_1R", + "kappa_1_1L", + "_INPUT", + "constant_term_loc_latent_1R", + "_RETURN", + "loc_latent_1R", + "loc_latent_1", + "constant_term_loc_latent_1", + "loc_latent_1L", + "constant_term_loc_latent_1L", + "log_sig_1L", + "kappa_1_1R", + "kappa_1R_1L", + "log_sig_1", + ] + ) + expected_edges = set( + [ + ("loc_latent_1R", "loc_latent_1"), + ("loc_latent_1L", "loc_latent_1R"), + ("loc_latent_1L", "loc_latent_1"), + ] + ) assert expected_nodes == set(guide_trace.nodes) assert expected_edges == set(guide_trace.edges) @@ -456,28 +610,42 @@ def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1, for step in range(n_steps): t0 = time.time() - svi.step(reparameterized=reparameterized, model_permutation=model_permutation, difficulty=difficulty) + svi.step( + reparameterized=reparameterized, + model_permutation=model_permutation, + difficulty=difficulty, + ) if step % 5000 == 0 or step == n_steps - 1: log_sig_errors = [] for node in self.target_lambdas: target_log_sig = -0.5 * torch.log(self.target_lambdas[node]) - log_sig_error = param_mse('log_sig_' + node, target_log_sig) + log_sig_error = param_mse("log_sig_" + node, target_log_sig) log_sig_errors.append(log_sig_error) max_log_sig_error = np.max(log_sig_errors) min_log_sig_error = np.min(log_sig_errors) mean_log_sig_error = np.mean(log_sig_errors) leftmost_node = self.q_topo_sort[0] - leftmost_constant_error = 
param_mse('constant_term_' + leftmost_node, - self.target_leftmost_constant) - almost_leftmost_constant_error = param_mse('constant_term_' + leftmost_node[:-1] + 'R', - self.target_almost_leftmost_constant) - - logger.debug("[mean function constant errors (partial)] %.4f %.4f" % - (leftmost_constant_error, almost_leftmost_constant_error)) - logger.debug("[min/mean/max log(scale) errors] %.4f %.4f %.4f" % - (min_log_sig_error, mean_log_sig_error, max_log_sig_error)) - logger.debug("[step time = %.3f; N = %d; step = %d]\n" % (time.time() - t0, self.N, step)) + leftmost_constant_error = param_mse( + "constant_term_" + leftmost_node, self.target_leftmost_constant + ) + almost_leftmost_constant_error = param_mse( + "constant_term_" + leftmost_node[:-1] + "R", + self.target_almost_leftmost_constant, + ) + + logger.debug( + "[mean function constant errors (partial)] %.4f %.4f" + % (leftmost_constant_error, almost_leftmost_constant_error) + ) + logger.debug( + "[min/mean/max log(scale) errors] %.4f %.4f %.4f" + % (min_log_sig_error, mean_log_sig_error, max_log_sig_error) + ) + logger.debug( + "[step time = %.3f; N = %d; step = %d]\n" + % (time.time() - t0, self.N, step) + ) assert_equal(0.0, max_log_sig_error, prec=prec) assert_equal(0.0, leftmost_constant_error, prec=prec) diff --git a/tests/integration_tests/test_tracegraph_elbo.py b/tests/integration_tests/test_tracegraph_elbo.py index ee6c4ae614..04b37651e4 100644 --- a/tests/integration_tests/test_tracegraph_elbo.py +++ b/tests/integration_tests/test_tracegraph_elbo.py @@ -29,11 +29,10 @@ def param_abs_error(name, target): class NormalNormalTests(TestCase): - def setUp(self): # normal-normal; known covariance - self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior - self.loc0 = torch.tensor([0.0, 0.5]) # prior mean + self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior + self.loc0 = torch.tensor([0.0, 0.5]) # prior mean # known precision of observation noise self.lam = torch.tensor([6.0, 4.0]) self.data = [] @@ -42,13 +41,12 @@ def setUp(self): self.data.append(torch.tensor([0.20, 0.5])) self.data.append(torch.tensor([0.10, 0.7])) self.n_data = torch.tensor(float(len(self.data))) - self.sum_data = self.data[0] + \ - self.data[1] + self.data[2] + self.data[3] - self.analytic_lam_n = self.lam0 + \ - self.n_data.expand_as(self.lam) * self.lam + self.sum_data = self.data[0] + self.data[1] + self.data[2] + self.data[3] + self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n) - self.analytic_loc_n = self.sum_data * (self.lam / self.analytic_lam_n) +\ - self.loc0 * (self.lam0 / self.analytic_lam_n) + self.analytic_loc_n = self.sum_data * ( + self.lam / self.analytic_lam_n + ) + self.loc0 * (self.lam0 / self.analytic_lam_n) def test_elbo_reparameterized(self): self.do_elbo_test(True, 1500, 0.02) @@ -58,29 +56,37 @@ def test_elbo_nonreparameterized(self): self.do_elbo_test(False, 5000, 0.05) def do_elbo_test(self, reparameterized, n_steps, prec): - logger.info(" - - - - - DO NORMALNORMAL ELBO TEST [reparameterized = %s] - - - - - " % reparameterized) + logger.info( + " - - - - - DO NORMALNORMAL ELBO TEST [reparameterized = %s] - - - - - " + % reparameterized + ) pyro.clear_param_store() Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal def model(): with pyro.plate("plate", 2): - loc_latent = pyro.sample("loc_latent", Normal(self.loc0, torch.pow(self.lam0, -0.5))) + loc_latent = pyro.sample( + "loc_latent", Normal(self.loc0, 
torch.pow(self.lam0, -0.5)) + ) for i, x in enumerate(self.data): - pyro.sample("obs_%d" % i, - dist.Normal(loc_latent, torch.pow(self.lam, -0.5)), - obs=x) + pyro.sample( + "obs_%d" % i, + dist.Normal(loc_latent, torch.pow(self.lam, -0.5)), + obs=x, + ) return loc_latent def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334) - log_sig_q = pyro.param("log_sig_q", - self.analytic_log_sig_n.expand(2) - 0.29) + log_sig_q = pyro.param( + "log_sig_q", self.analytic_log_sig_n.expand(2) - 0.29 + ) sig_q = torch.exp(log_sig_q) with pyro.plate("plate", 2): loc_latent = pyro.sample("loc_latent", Normal(loc_q, sig_q)) return loc_latent - adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for k in range(n_steps): @@ -89,48 +95,80 @@ def guide(): loc_error = param_mse("loc_q", self.analytic_loc_n) log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n) if k % 250 == 0: - logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error)) + logger.debug( + "loc error, log(scale) error: %.4f, %.4f" + % (loc_error, log_sig_error) + ) assert_equal(0.0, loc_error, prec=prec) assert_equal(0.0, log_sig_error, prec=prec) class NormalNormalNormalTests(TestCase): - def setUp(self): # normal-normal-normal; known covariance self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior - self.loc0 = torch.tensor([0.0, 0.5]) # prior mean + self.loc0 = torch.tensor([0.0, 0.5]) # prior mean # known precision of observation noise self.lam = torch.tensor([6.0, 4.0]) - self.data = torch.tensor([[-0.1, 0.3], - [0.00, 0.4], - [0.20, 0.5], - [0.10, 0.7]]) + self.data = torch.tensor([[-0.1, 0.3], [0.00, 0.4], [0.20, 0.5], [0.10, 0.7]]) self.analytic_lam_n = self.lam0 + float(len(self.data)) * self.lam self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n) - self.analytic_loc_n = self.data.sum(0) * (self.lam / self.analytic_lam_n) +\ - self.loc0 * (self.lam0 / self.analytic_lam_n) + self.analytic_loc_n = self.data.sum(0) * ( + self.lam / self.analytic_lam_n + ) + self.loc0 * (self.lam0 / self.analytic_lam_n) def test_elbo_reparameterized(self): self.do_elbo_test(True, True, 3000, 0.02, 0.002, False, False) def test_elbo_nonreparameterized_both_baselines(self): - self.do_elbo_test(False, False, 3000, 0.04, 0.001, use_nn_baseline=True, - use_decaying_avg_baseline=True) + self.do_elbo_test( + False, + False, + 3000, + 0.04, + 0.001, + use_nn_baseline=True, + use_decaying_avg_baseline=True, + ) def test_elbo_nonreparameterized_decaying_baseline(self): - self.do_elbo_test(True, False, 4000, 0.04, 0.0015, use_nn_baseline=False, - use_decaying_avg_baseline=True) + self.do_elbo_test( + True, + False, + 4000, + 0.04, + 0.0015, + use_nn_baseline=False, + use_decaying_avg_baseline=True, + ) def test_elbo_nonreparameterized_nn_baseline(self): - self.do_elbo_test(False, True, 4000, 0.04, 0.0015, use_nn_baseline=True, - use_decaying_avg_baseline=False) - - def do_elbo_test(self, repa1, repa2, n_steps, prec, lr, use_nn_baseline, use_decaying_avg_baseline): + self.do_elbo_test( + False, + True, + 4000, + 0.04, + 0.0015, + use_nn_baseline=True, + use_decaying_avg_baseline=False, + ) + + def do_elbo_test( + self, + repa1, + repa2, + n_steps, + prec, + lr, + use_nn_baseline, + use_decaying_avg_baseline, + ): logger.info(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -") - logger.info("[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" % - (repa1, 
repa2, use_nn_baseline, use_decaying_avg_baseline)) + logger.info( + "[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" + % (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline) + ) pyro.clear_param_store() Normal1 = dist.Normal if repa1 else fakes.NonreparameterizedNormal Normal2 = dist.Normal if repa2 else fakes.NonreparameterizedNormal @@ -148,46 +186,72 @@ def forward(self, x): h = self.sigmoid(self.lin1(x)) return self.lin2(h) - loc_prime_baseline = pyro.module("loc_prime_baseline", VanillaBaselineNN(2, 5)) + loc_prime_baseline = pyro.module( + "loc_prime_baseline", VanillaBaselineNN(2, 5) + ) else: loc_prime_baseline = None def model(): with pyro.plate("plate", 2): - loc_latent_prime = pyro.sample("loc_latent_prime", Normal1(self.loc0, torch.pow(self.lam0, -0.5))) - loc_latent = pyro.sample("loc_latent", Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5))) + loc_latent_prime = pyro.sample( + "loc_latent_prime", Normal1(self.loc0, torch.pow(self.lam0, -0.5)) + ) + loc_latent = pyro.sample( + "loc_latent", Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5)) + ) with pyro.plate("data", len(self.data)): - pyro.sample("obs", - dist.Normal(loc_latent, torch.pow(self.lam, -0.5)) - .expand_by(self.data.shape[:1]), - obs=self.data) + pyro.sample( + "obs", + dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).expand_by( + self.data.shape[:1] + ), + obs=self.data, + ) return loc_latent # note that the exact posterior is not mean field! def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334) - log_sig_q = pyro.param("log_sig_q", - self.analytic_log_sig_n.expand(2) - 0.29) - loc_q_prime = pyro.param("loc_q_prime", - torch.tensor([-0.34, 0.52])) + log_sig_q = pyro.param( + "log_sig_q", self.analytic_log_sig_n.expand(2) - 0.29 + ) + loc_q_prime = pyro.param("loc_q_prime", torch.tensor([-0.34, 0.52])) kappa_q = pyro.param("kappa_q", torch.tensor([0.74])) - log_sig_q_prime = pyro.param("log_sig_q_prime", - -0.5 * torch.log(1.2 * self.lam0)) + log_sig_q_prime = pyro.param( + "log_sig_q_prime", -0.5 * torch.log(1.2 * self.lam0) + ) sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime) with pyro.plate("plate", 2): - loc_latent = pyro.sample("loc_latent", Normal2(loc_q, sig_q), - infer=dict(baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))) - pyro.sample("loc_latent_prime", - Normal1(kappa_q.expand_as(loc_latent) * loc_latent + loc_q_prime, sig_q_prime), - infer=dict(baseline=dict(nn_baseline=loc_prime_baseline, - nn_baseline_input=loc_latent, - use_decaying_avg_baseline=use_decaying_avg_baseline))) + loc_latent = pyro.sample( + "loc_latent", + Normal2(loc_q, sig_q), + infer=dict( + baseline=dict( + use_decaying_avg_baseline=use_decaying_avg_baseline + ) + ), + ) + pyro.sample( + "loc_latent_prime", + Normal1( + kappa_q.expand_as(loc_latent) * loc_latent + loc_q_prime, + sig_q_prime, + ), + infer=dict( + baseline=dict( + nn_baseline=loc_prime_baseline, + nn_baseline_input=loc_latent, + use_decaying_avg_baseline=use_decaying_avg_baseline, + ) + ), + ) with pyro.plate("data", len(self.data)): pass return loc_latent - adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for k in range(n_steps): @@ -197,7 +261,9 @@ def guide(): log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n) loc_prime_error = param_mse("loc_q_prime", 0.5 * self.loc0) kappa_error = param_mse("kappa_q", 0.5 * torch.ones(1)) - 
log_sig_prime_error = param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0)) + log_sig_prime_error = param_mse( + "log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0) + ) if k % 500 == 0: logger.debug("errors: %.4f, %.4f" % (loc_error, log_sig_error)) @@ -221,7 +287,9 @@ def setUp(self): self.n_data = float(len(self.data)) data_sum = self.data.sum() self.alpha_n = self.alpha0 + data_sum # posterior alpha - self.beta_n = self.beta0 - data_sum + torch.tensor(self.n_data) # posterior beta + self.beta_n = ( + self.beta0 - data_sum + torch.tensor(self.n_data) + ) # posterior beta self.log_alpha_n = torch.log(self.alpha_n) self.log_beta_n = torch.log(self.beta_n) @@ -232,7 +300,10 @@ def test_elbo_nonreparameterized(self): self.do_elbo_test(False, 3000, 0.95, 0.0007) def do_elbo_test(self, reparameterized, n_steps, beta1, lr): - logger.info(" - - - - - DO BETA-BERNOULLI ELBO TEST [repa = %s] - - - - - " % reparameterized) + logger.info( + " - - - - - DO BETA-BERNOULLI ELBO TEST [repa = %s] - - - - - " + % reparameterized + ) pyro.clear_param_store() Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta @@ -243,13 +314,14 @@ def model(): return p_latent def guide(): - alpha_q_log = pyro.param("alpha_q_log", - self.log_alpha_n + 0.17) - beta_q_log = pyro.param("beta_q_log", - self.log_beta_n - 0.143) + alpha_q_log = pyro.param("alpha_q_log", self.log_alpha_n + 0.17) + beta_q_log = pyro.param("beta_q_log", self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) - p_latent = pyro.sample("p_latent", Beta(alpha_q, beta_q), - infer=dict(baseline=dict(use_decaying_avg_baseline=True))) + p_latent = pyro.sample( + "p_latent", + Beta(alpha_q, beta_q), + infer=dict(baseline=dict(use_decaying_avg_baseline=True)), + ) with pyro.plate("data", len(self.data)): pass return p_latent @@ -262,7 +334,9 @@ def guide(): alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n) beta_error = param_abs_error("beta_q_log", self.log_beta_n) if k % 500 == 0: - logger.debug("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error)) + logger.debug( + "alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error) + ) assert_equal(0.0, alpha_error, prec=0.03) assert_equal(0.0, beta_error, prec=0.04) @@ -289,7 +363,10 @@ def test_elbo_nonreparameterized(self): self.do_elbo_test(False, 8000, 0.95, 0.0007) def do_elbo_test(self, reparameterized, n_steps, beta1, lr): - logger.info(" - - - - - DO EXPONENTIAL-GAMMA ELBO TEST [repa = %s] - - - - - " % reparameterized) + logger.info( + " - - - - - DO EXPONENTIAL-GAMMA ELBO TEST [repa = %s] - - - - - " + % reparameterized + ) pyro.clear_param_store() Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma @@ -300,15 +377,14 @@ def model(): return lambda_latent def guide(): - alpha_q_log = pyro.param( - "alpha_q_log", - self.log_alpha_n + 0.17) - beta_q_log = pyro.param( - "beta_q_log", - self.log_beta_n - 0.143) + alpha_q_log = pyro.param("alpha_q_log", self.log_alpha_n + 0.17) + beta_q_log = pyro.param("beta_q_log", self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) - pyro.sample("lambda_latent", Gamma(alpha_q, beta_q), - infer=dict(baseline=dict(use_decaying_avg_baseline=True))) + pyro.sample( + "lambda_latent", + Gamma(alpha_q, beta_q), + infer=dict(baseline=dict(use_decaying_avg_baseline=True)), + ) with pyro.plate("data", len(self.data)): pass @@ -320,7 +396,9 @@ def guide(): alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n) beta_error = 
param_abs_error("beta_q_log", self.log_beta_n) if k % 500 == 0: - logger.debug("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error)) + logger.debug( + "alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error) + ) assert_equal(0.0, alpha_error, prec=0.04) assert_equal(0.0, beta_error, prec=0.04) @@ -331,8 +409,8 @@ def guide(): class RaoBlackwellizationTests(TestCase): def setUp(self): # normal-normal; known covariance - self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior - self.loc0 = torch.tensor([0.0, 0.5]) # prior mean + self.lam0 = torch.tensor([0.1, 0.1]) # precision of prior + self.loc0 = torch.tensor([0.0, 0.5]) # prior mean # known precision of observation noise self.lam = torch.tensor([6.0, 4.0]) self.n_outer = 3 @@ -343,43 +421,57 @@ def setUp(self): for _out in range(self.n_outer): data_in = [] for _in in range(self.n_inner): - data_in.append(torch.tensor([-0.1, 0.3]) + torch.empty(torch.Size((2,))).normal_() / self.lam.sqrt()) + data_in.append( + torch.tensor([-0.1, 0.3]) + + torch.empty(torch.Size((2,))).normal_() / self.lam.sqrt() + ) self.sum_data += data_in[-1] self.data.append(data_in) self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n) - self.analytic_loc_n = self.sum_data * (self.lam / self.analytic_lam_n) +\ - self.loc0 * (self.lam0 / self.analytic_lam_n) + self.analytic_loc_n = self.sum_data * ( + self.lam / self.analytic_lam_n + ) + self.loc0 * (self.lam0 / self.analytic_lam_n) # this tests rao-blackwellization in elbo for nested sequential plates def test_nested_iplate_in_elbo(self, n_steps=4000): pyro.clear_param_store() def model(): - loc_latent = pyro.sample("loc_latent", - fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)) - .to_event(1)) + loc_latent = pyro.sample( + "loc_latent", + fakes.NonreparameterizedNormal( + self.loc0, torch.pow(self.lam0, -0.5) + ).to_event(1), + ) for i in pyro.plate("outer", self.n_outer): for j in pyro.plate("inner_%d" % i, self.n_inner): - pyro.sample("obs_%d_%d" % (i, j), - dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), - obs=self.data[i][j]) + pyro.sample( + "obs_%d_%d" % (i, j), + dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), + obs=self.data[i][j], + ) def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.234) - log_sig_q = pyro.param("log_sig_q", - self.analytic_log_sig_n.expand(2) - 0.27) + log_sig_q = pyro.param( + "log_sig_q", self.analytic_log_sig_n.expand(2) - 0.27 + ) sig_q = torch.exp(log_sig_q) - pyro.sample("loc_latent", fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1), - infer=dict(baseline=dict(use_decaying_avg_baseline=True))) + pyro.sample( + "loc_latent", + fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1), + infer=dict(baseline=dict(use_decaying_avg_baseline=True)), + ) for i in pyro.plate("outer", self.n_outer): for j in pyro.plate("inner_%d" % i, self.n_inner): pass guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace() - model_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace), - graph_type="dense").get_trace() + model_trace = pyro.poutine.trace( + pyro.poutine.replay(model, trace=guide_trace), graph_type="dense" + ).get_trace() assert len(list(model_trace.edges)) == 27 assert len(model_trace.nodes) == 16 assert len(list(guide_trace.edges)) == 0 @@ -393,7 +485,10 @@ def guide(): loc_error = param_mse("loc_q", self.analytic_loc_n) log_sig_error = param_mse("log_sig_q", 
self.analytic_log_sig_n) if k % 500 == 0: - logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error)) + logger.debug( + "loc error, log(scale) error: %.4f, %.4f" + % (loc_error, log_sig_error) + ) assert_equal(0.0, loc_error, prec=0.04) assert_equal(0.0, log_sig_error, prec=0.04) @@ -402,69 +497,97 @@ def guide(): # inside of a sequential plate with superfluous random torch.tensors to complexify the # graph structure and introduce additional baselines def test_plate_in_elbo_with_superfluous_rvs(self): - self._test_plate_in_elbo(n_superfluous_top=1, n_superfluous_bottom=1, n_steps=2000, lr=0.0113) + self._test_plate_in_elbo( + n_superfluous_top=1, n_superfluous_bottom=1, n_steps=2000, lr=0.0113 + ) - def _test_plate_in_elbo(self, n_superfluous_top, n_superfluous_bottom, n_steps, lr=0.0012): + def _test_plate_in_elbo( + self, n_superfluous_top, n_superfluous_bottom, n_steps, lr=0.0012 + ): pyro.clear_param_store() self.data_tensor = torch.zeros(9, 2) for _out in range(self.n_outer): for _in in range(self.n_inner): self.data_tensor[3 * _out + _in, :] = self.data[_out][_in] - self.data_as_list = [self.data_tensor[0:4, :], self.data_tensor[4:7, :], - self.data_tensor[7:9, :]] + self.data_as_list = [ + self.data_tensor[0:4, :], + self.data_tensor[4:7, :], + self.data_tensor[7:9, :], + ] def model(): - loc_latent = pyro.sample("loc_latent", - fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)) - .to_event(1)) + loc_latent = pyro.sample( + "loc_latent", + fakes.NonreparameterizedNormal( + self.loc0, torch.pow(self.lam0, -0.5) + ).to_event(1), + ) for i in pyro.plate("outer", 3): x_i = self.data_as_list[i] with pyro.plate("inner_%d" % i, x_i.size(0)): for k in range(n_superfluous_top): - z_i_k = pyro.sample("z_%d_%d" % (i, k), - fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i])) + z_i_k = pyro.sample( + "z_%d_%d" % (i, k), + fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]), + ) assert z_i_k.shape == (4 - i,) - obs_i = pyro.sample("obs_%d" % i, dist.Normal(loc_latent, torch.pow(self.lam, -0.5)) - .to_event(1), obs=x_i) + obs_i = pyro.sample( + "obs_%d" % i, + dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), + obs=x_i, + ) assert obs_i.shape == (4 - i, 2) - for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom): - z_i_k = pyro.sample("z_%d_%d" % (i, k), - fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i])) + for k in range( + n_superfluous_top, n_superfluous_top + n_superfluous_bottom + ): + z_i_k = pyro.sample( + "z_%d_%d" % (i, k), + fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]), + ) assert z_i_k.shape == (4 - i,) pt_loc_baseline = torch.nn.Linear(1, 1) pt_superfluous_baselines = [] for k in range(n_superfluous_top + n_superfluous_bottom): - pt_superfluous_baselines.extend([torch.nn.Linear(2, 4), torch.nn.Linear(2, 3), - torch.nn.Linear(2, 2)]) + pt_superfluous_baselines.extend( + [torch.nn.Linear(2, 4), torch.nn.Linear(2, 3), torch.nn.Linear(2, 2)] + ) def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094) - log_sig_q = pyro.param("log_sig_q", - self.analytic_log_sig_n.expand(2) - 0.07) + log_sig_q = pyro.param( + "log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07 + ) sig_q = torch.exp(log_sig_q) trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline) baseline_value = trivial_baseline(torch.ones(1)).squeeze() - loc_latent = pyro.sample("loc_latent", - fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1), - 
infer=dict(baseline=dict(baseline_value=baseline_value))) + loc_latent = pyro.sample( + "loc_latent", + fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1), + infer=dict(baseline=dict(baseline_value=baseline_value)), + ) for i in pyro.plate("outer", 3): with pyro.plate("inner_%d" % i, 4 - i): for k in range(n_superfluous_top + n_superfluous_bottom): - z_baseline = pyro.module("z_baseline_%d_%d" % (i, k), - pt_superfluous_baselines[3 * k + i]) + z_baseline = pyro.module( + "z_baseline_%d_%d" % (i, k), + pt_superfluous_baselines[3 * k + i], + ) baseline_value = z_baseline(loc_latent.detach()) - mean_i = pyro.param("mean_%d_%d" % (i, k), - 0.5 * torch.ones(4 - i)) - z_i_k = pyro.sample("z_%d_%d" % (i, k), - fakes.NonreparameterizedNormal(mean_i, 1), - infer=dict(baseline=dict(baseline_value=baseline_value))) + mean_i = pyro.param( + "mean_%d_%d" % (i, k), 0.5 * torch.ones(4 - i) + ) + z_i_k = pyro.sample( + "z_%d_%d" % (i, k), + fakes.NonreparameterizedNormal(mean_i, 1), + infer=dict(baseline=dict(baseline_value=baseline_value)), + ) assert z_i_k.shape == (4 - i,) def per_param_callable(param_name): - if 'baseline' in param_name: + if "baseline" in param_name: return {"lr": 0.010, "betas": (0.95, 0.999)} else: return {"lr": lr, "betas": (0.95, 0.999)} @@ -481,14 +604,25 @@ def per_param_callable(param_name): if n_superfluous_top > 0 or n_superfluous_bottom > 0: superfluous_errors = [] for k in range(n_superfluous_top + n_superfluous_bottom): - mean_0_error = torch.sum(torch.pow(pyro.param("mean_0_%d" % k), 2.0)) - mean_1_error = torch.sum(torch.pow(pyro.param("mean_1_%d" % k), 2.0)) - mean_2_error = torch.sum(torch.pow(pyro.param("mean_2_%d" % k), 2.0)) - superfluous_error = torch.max(torch.max(mean_0_error, mean_1_error), mean_2_error) + mean_0_error = torch.sum( + torch.pow(pyro.param("mean_0_%d" % k), 2.0) + ) + mean_1_error = torch.sum( + torch.pow(pyro.param("mean_1_%d" % k), 2.0) + ) + mean_2_error = torch.sum( + torch.pow(pyro.param("mean_2_%d" % k), 2.0) + ) + superfluous_error = torch.max( + torch.max(mean_0_error, mean_1_error), mean_2_error + ) superfluous_errors.append(superfluous_error.detach().cpu().numpy()) if step % 500 == 0: - logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error)) + logger.debug( + "loc error, log(scale) error: %.4f, %.4f" + % (loc_error, log_sig_error) + ) if n_superfluous_top > 0 or n_superfluous_bottom > 0: logger.debug("superfluous error: %.4f" % np.max(superfluous_errors)) diff --git a/tests/nn/test_autoregressive.py b/tests/nn/test_autoregressive.py index a0136ad7d4..573a015939 100644 --- a/tests/nn/test_autoregressive.py +++ b/tests/nn/test_autoregressive.py @@ -19,7 +19,9 @@ def setUp(self): def _test_jacobian(self, input_dim, observed_dim, hidden_dim, param_dim): jacobian = torch.zeros(input_dim, input_dim) if observed_dim > 0: - arn = ConditionalAutoRegressiveNN(input_dim, observed_dim, [hidden_dim], param_dims=[param_dim]) + arn = ConditionalAutoRegressiveNN( + input_dim, observed_dim, [hidden_dim], param_dims=[param_dim] + ) else: arn = AutoRegressiveNN(input_dim, [hidden_dim], param_dims=[param_dim]) @@ -35,9 +37,15 @@ def nonzero(x): epsilon_vector = torch.zeros(1, input_dim) epsilon_vector[0, j] = self.epsilon if observed_dim > 0: - delta = (arn(x + 0.5 * epsilon_vector, y) - arn(x - 0.5 * epsilon_vector, y)) / self.epsilon + delta = ( + arn(x + 0.5 * epsilon_vector, y) + - arn(x - 0.5 * epsilon_vector, y) + ) / self.epsilon else: - delta = (arn(x + 0.5 * epsilon_vector) - arn(x - 0.5 * epsilon_vector)) / 
self.epsilon + delta = ( + arn(x + 0.5 * epsilon_vector) + - arn(x - 0.5 * epsilon_vector) + ) / self.epsilon jacobian[j, k] = float(delta[0, output_index, k]) permutation = arn.get_permutation() @@ -50,8 +58,12 @@ def nonzero(x): assert lower_sum == float(0.0) - def _test_masks(self, input_dim, observed_dim, hidden_dims, permutation, output_dim_multiplier): - masks, mask_skip = create_mask(input_dim, observed_dim, hidden_dims, permutation, output_dim_multiplier) + def _test_masks( + self, input_dim, observed_dim, hidden_dims, permutation, output_dim_multiplier + ): + masks, mask_skip = create_mask( + input_dim, observed_dim, hidden_dims, permutation, output_dim_multiplier + ) # First test that hidden layer masks are adequately connected # Tracing backwards, works out what inputs each output is connected to @@ -61,8 +73,16 @@ def _test_masks(self, input_dim, observed_dim, hidden_dims, permutation, output_ # Loop over variables for idx in range(input_dim): # Calculate correct answer - correct = torch.cat((torch.arange(observed_dim, dtype=torch.long), torch.tensor( - sorted(permutation[0:permutation.index(idx)]), dtype=torch.long) + observed_dim)) + correct = torch.cat( + ( + torch.arange(observed_dim, dtype=torch.long), + torch.tensor( + sorted(permutation[0 : permutation.index(idx)]), + dtype=torch.long, + ) + + observed_dim, + ) + ) # Loop over parameters for each variable for jdx in range(output_dim_multiplier): @@ -81,14 +101,20 @@ def _test_masks(self, input_dim, observed_dim, hidden_dims, permutation, output_ this_connections.add(ldx) prev_connections = this_connections - assert (torch.tensor(list(sorted(prev_connections)), dtype=torch.long) == correct).all() + assert ( + torch.tensor(list(sorted(prev_connections)), dtype=torch.long) + == correct + ).all() # Test the skip-connections mask skip_connections = set() for kdx in range(mask_skip.size(1)): if mask_skip[idx + jdx * input_dim, kdx]: skip_connections.add(kdx) - assert (torch.tensor(list(sorted(skip_connections)), dtype=torch.long) == correct).all() + assert ( + torch.tensor(list(sorted(skip_connections)), dtype=torch.long) + == correct + ).all() def test_jacobians(self): for observed_dim in [0, 5]: @@ -103,10 +129,11 @@ def test_masks(self): # NOTE: the hidden dimension must be greater than the input_dim for the # masks to be well-defined! 
hidden_dim = input_dim * 5 - permutation = torch.randperm(input_dim, device='cpu') + permutation = torch.randperm(input_dim, device="cpu") self._test_masks( input_dim, observed_dim, - [hidden_dim]*num_layers, + [hidden_dim] * num_layers, permutation, - output_dim_multiplier) + output_dim_multiplier, + ) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 98e651d9cc..41198b9f2a 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -19,26 +19,28 @@ def test_svi_smoke(): - class Model(PyroModule): def __init__(self): super().__init__() self.loc = nn.Parameter(torch.zeros(2)) self.scale = PyroParam(torch.ones(2), constraint=constraints.positive) - self.z = PyroSample(lambda self: dist.Normal(self.loc, self.scale).to_event(1)) + self.z = PyroSample( + lambda self: dist.Normal(self.loc, self.scale).to_event(1) + ) def forward(self, data): loc, log_scale = self.z.unbind(-1) with pyro.plate("data"): - pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), - obs=data) + pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data) class Guide(PyroModule): def __init__(self): super().__init__() self.loc = nn.Parameter(torch.zeros(2)) self.scale = PyroParam(torch.ones(2), constraint=constraints.positive) - self.z = PyroSample(lambda self: dist.Normal(self.loc, self.scale).to_event(1)) + self.z = PyroSample( + lambda self: dist.Normal(self.loc, self.scale).to_event(1) + ) def forward(self, *args, **kwargs): return self.z @@ -65,17 +67,16 @@ def forward(self, *args, **kwargs): def test_names(): - class Model(PyroModule): def __init__(self): super().__init__() - self.x = nn.Parameter(torch.tensor(0.)) - self.y = PyroParam(torch.tensor(1.), constraint=constraints.positive) + self.x = nn.Parameter(torch.tensor(0.0)) + self.y = PyroParam(torch.tensor(1.0), constraint=constraints.positive) self.m = nn.Module() self.m.u = nn.Parameter(torch.tensor(2.0)) self.p = PyroModule() - self.p.v = nn.Parameter(torch.tensor(3.)) - self.p.w = PyroParam(torch.tensor(4.), constraint=constraints.positive) + self.p.v = nn.Parameter(torch.tensor(3.0)) + self.p.w = PyroParam(torch.tensor(4.0), constraint=constraints.positive) def forward(self): # trigger .__getattr__() @@ -102,8 +103,11 @@ def forward(self): expected = {"x", "y", "m$$$u", "p.v", "p.w"} with poutine.trace(param_only=True) as param_capture: model() - actual = {name for name, site in param_capture.trace.nodes.items() - if site["type"] == "param"} + actual = { + name + for name, site in param_capture.trace.nodes.items() + if site["type"] == "param" + } assert actual == expected # Check pyro_parameters method @@ -114,7 +118,7 @@ def forward(self): def test_delete(): m = PyroModule() - m.a = PyroParam(torch.tensor(1.)) + m.a = PyroParam(torch.tensor(1.0)) del m.a m.a = PyroParam(torch.tensor(0.1)) assert_equal(m.a.detach(), torch.tensor(0.1)) @@ -129,12 +133,12 @@ def __init__(self, a): class Family(PyroModule): def __init__(self): super().__init__() - self.child1 = Child(torch.tensor(1.)) - self.child2 = Child(torch.tensor(2.)) + self.child1 = Child(torch.tensor(1.0)) + self.child2 = Child(torch.tensor(2.0)) f = Family() - assert_equal(f.child1.a.detach(), torch.tensor(1.)) - assert_equal(f.child2.a.detach(), torch.tensor(2.)) + assert_equal(f.child1.a.detach(), torch.tensor(1.0)) + assert_equal(f.child2.a.detach(), torch.tensor(2.0)) def test_module_cache(): @@ -155,10 +159,10 @@ def forward(self): return self.c.a f = Family() - assert_equal(f().detach(), torch.tensor(1.)) - f.c = Child(3.) 
- assert_equal(f().detach(), torch.tensor(3.)) - assert_equal(f.c().detach(), torch.tensor(3.)) + assert_equal(f().detach(), torch.tensor(1.0)) + f.c = Child(3.0) + assert_equal(f().detach(), torch.tensor(3.0)) + assert_equal(f.c().detach(), torch.tensor(3.0)) def test_submodule_contains_torch_module(): @@ -194,10 +198,16 @@ def forward(self): ((4,), constraints.positive), ((3, 2), constraints.positive), ((5,), constraints.simplex), - ((2, 5,), constraints.simplex), + ( + ( + 2, + 5, + ), + constraints.simplex, + ), ((5, 5), constraints.lower_cholesky), ((2, 5, 5), constraints.lower_cholesky), - ((10, ), constraints.greater_than(-torch.randn(10).exp())), + ((10,), constraints.greater_than(-torch.randn(10).exp())), ((4, 10), constraints.greater_than(-torch.randn(10).exp())), ((4, 10), constraints.greater_than(-torch.randn(4, 10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(10).exp())), @@ -205,20 +215,23 @@ def forward(self): ((3, 2, 10), constraints.greater_than(-torch.randn(3, 1, 10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(3, 2, 10).exp())), ((5,), constraints.real_vector), - ((2, 5,), constraints.real_vector), + ( + ( + 2, + 5, + ), + constraints.real_vector, + ), ((), constraints.unit_interval), - ((4, ), constraints.unit_interval), + ((4,), constraints.unit_interval), ((3, 2), constraints.unit_interval), - ((10,), constraints.interval(-torch.randn(10).exp(), - torch.randn(10).exp())), - ((4, 10), constraints.interval(-torch.randn(10).exp(), - torch.randn(10).exp())), - ((3, 2, 10), constraints.interval(-torch.randn(10).exp(), - torch.randn(10).exp())), + ((10,), constraints.interval(-torch.randn(10).exp(), torch.randn(10).exp())), + ((4, 10), constraints.interval(-torch.randn(10).exp(), torch.randn(10).exp())), + ((3, 2, 10), constraints.interval(-torch.randn(10).exp(), torch.randn(10).exp())), ] -@pytest.mark.parametrize('shape,constraint_', SHAPE_CONSTRAINT) +@pytest.mark.parametrize("shape,constraint_", SHAPE_CONSTRAINT) def test_constraints(shape, constraint_): module = PyroModule() module.x = PyroParam(torch.full(shape, 1e-4), constraint_) @@ -240,22 +253,21 @@ def test_constraints(shape, constraint_): assert constraint_.check(module.x).all() del module.x - assert 'x' not in module._pyro_params - assert not hasattr(module, 'x') - assert not hasattr(module, 'x_unconstrained') + assert "x" not in module._pyro_params + assert not hasattr(module, "x") + assert not hasattr(module, "x_unconstrained") def test_clear(): - class Model(PyroModule): def __init__(self): super().__init__() - self.x = nn.Parameter(torch.tensor(0.)) + self.x = nn.Parameter(torch.tensor(0.0)) self.m = torch.nn.Linear(2, 3) - self.m.weight.data.fill_(1.) - self.m.bias.data.fill_(2.) 
+ self.m.weight.data.fill_(1.0) + self.m.bias.data.fill_(2.0) self.p = PyroModule() - self.p.x = nn.Parameter(torch.tensor(3.)) + self.p.x = nn.Parameter(torch.tensor(3.0)) def forward(self): return [x.clone() for x in [self.x, self.m.weight, self.m.bias, self.p.x]] @@ -285,25 +297,25 @@ def forward(self): def test_sample(): - class Model(nn.Linear, PyroModule): def __init__(self, in_features, out_features): super().__init__(in_features, out_features) self.weight = PyroSample( lambda self: dist.Normal(0, 1) - .expand([self.out_features, - self.in_features]) - .to_event(2)) + .expand([self.out_features, self.in_features]) + .to_event(2) + ) class Guide(nn.Linear, PyroModule): def __init__(self, in_features, out_features): super().__init__(in_features, out_features) self.loc = PyroParam(torch.zeros_like(self.weight)) - self.scale = PyroParam(torch.ones_like(self.weight), - constraint=constraints.positive) + self.scale = PyroParam( + torch.ones_like(self.weight), constraint=constraints.positive + ) self.weight = PyroSample( - lambda self: dist.Normal(self.loc, self.scale) - .to_event(2)) + lambda self: dist.Normal(self.loc, self.scale).to_event(2) + ) data = torch.randn(8) model = Model(8, 2) @@ -331,12 +343,12 @@ def gather(self): } module = MyModule() - module.a = nn.Parameter(torch.tensor(0.)) - module.b = PyroParam(torch.tensor(1.), constraint=constraints.positive) + module.a = nn.Parameter(torch.tensor(0.0)) + module.b = PyroParam(torch.tensor(1.0), constraint=constraints.positive) module.c = PyroSample(dist.Normal(0, 1)) module.p = PyroModule() - module.p.d = nn.Parameter(torch.tensor(3.)) - module.p.e = PyroParam(torch.tensor(4.), constraint=constraints.positive) + module.p.d = nn.Parameter(torch.tensor(3.0)) + module.p.e = PyroParam(torch.tensor(4.0), constraint=constraints.positive) module.p.f = PyroSample(dist.Normal(0, 1)) assert module._pyro_context is module.p._pyro_context @@ -358,9 +370,9 @@ def __init__(self, size): super().__init__() self.x = PyroParam(torch.zeros(size)) self.y = PyroParam(lambda: torch.randn(size)) - self.z = PyroParam(torch.ones(size), - constraint=constraints.positive, - event_dim=1) + self.z = PyroParam( + torch.ones(size), constraint=constraints.positive, event_dim=1 + ) self.s = PyroSample(dist.Normal(0, 1)) self.t = PyroSample(lambda self: dist.Normal(self.s, self.z)) @@ -505,9 +517,16 @@ def assert_identical(a, e): def randomize(model): for m in model.modules(): for name, value in list(m.named_parameters(recurse=False)): - setattr(m, name, PyroSample(prior=dist.Normal(0, 1) - .expand(value.shape) - .to_event(value.dim()))) + setattr( + m, + name, + PyroSample( + prior=dist.Normal(0, 1) + .expand(value.shape) + .to_event(value.dim()) + ), + ) + randomize(actual) randomize(expected) assert_identical(actual, expected) diff --git a/tests/ops/einsum/test_adjoint.py b/tests/ops/einsum/test_adjoint.py index 95a71b2fd6..99672a832e 100644 --- a/tests/ops/einsum/test_adjoint.py +++ b/tests/ops/einsum/test_adjoint.py @@ -11,51 +11,50 @@ from tests.common import assert_equal EQUATIONS = [ - '->', - 'w->', - ',w->', - 'w,w->', - 'w,x->', - 'w,wx,x->', - 'w,wx,xy,yz->', - 'wx,xy,yz,zw->', - 'i->i', - 'wi->i', - 'i,wi->i', - 'wi,wi->i', - 'wi,xi->i', - 'wi,wxi,xi->i', - 'wi,wxi,xyi,yzi->i', - 'wxi,xyi,yzi,zwi->i', - 'ij->ij', - 'iwj->ij', - 'ij,iwj->ij', - 'iwj,iwj->ij', - 'iwj,ixj->ij', - 'iwj,iwxj,ixj->ij', - 'iwj,iwxj,ixyj,iyzj->ij', - 'iwxj,ixyj,iyzj,izwj->ij', - 'ij->ji', - 'iwj->ji', - 'ji,iwj->ji', - 'iwj,iwj->ji', - 'iwj,ixj->ji', - 'iwj,iwxj,ixj->ji', - 
'iwj,iwxj,ixyj,iyzj->ji',
-    'iwxj,ixyj,iyzj,izwj->ji',
+    "->",
+    "w->",
+    ",w->",
+    "w,w->",
+    "w,x->",
+    "w,wx,x->",
+    "w,wx,xy,yz->",
+    "wx,xy,yz,zw->",
+    "i->i",
+    "wi->i",
+    "i,wi->i",
+    "wi,wi->i",
+    "wi,xi->i",
+    "wi,wxi,xi->i",
+    "wi,wxi,xyi,yzi->i",
+    "wxi,xyi,yzi,zwi->i",
+    "ij->ij",
+    "iwj->ij",
+    "ij,iwj->ij",
+    "iwj,iwj->ij",
+    "iwj,ixj->ij",
+    "iwj,iwxj,ixj->ij",
+    "iwj,iwxj,ixyj,iyzj->ij",
+    "iwxj,ixyj,iyzj,izwj->ij",
+    "ij->ji",
+    "iwj->ji",
+    "ji,iwj->ji",
+    "iwj,iwj->ji",
+    "iwj,ixj->ji",
+    "iwj,iwxj,ixj->ji",
+    "iwj,iwxj,ixyj,iyzj->ji",
+    "iwxj,ixyj,iyzj,izwj->ji",
 ]
 
 
-@pytest.mark.parametrize('equation', EQUATIONS)
-@pytest.mark.parametrize('backend', ['map', 'sample', 'marginal'])
+@pytest.mark.parametrize("equation", EQUATIONS)
+@pytest.mark.parametrize("backend", ["map", "sample", "marginal"])
 def test_shape(backend, equation):
-    backend = 'pyro.ops.einsum.torch_{}'.format(backend)
-    inputs, output = equation.split('->')
-    inputs = inputs.split(',')
-    symbols = sorted(set(equation) - set(',->'))
+    backend = "pyro.ops.einsum.torch_{}".format(backend)
+    inputs, output = equation.split("->")
+    inputs = inputs.split(",")
+    symbols = sorted(set(equation) - set(",->"))
     sizes = dict(zip(symbols, itertools.count(2)))
-    input_shapes = [torch.Size(sizes[dim] for dim in dims)
-                    for dims in inputs]
+    input_shapes = [torch.Size(sizes[dim] for dim in dims) for dims in inputs]
     operands = [torch.randn(shape) for shape in input_shapes]
     for input_, x in zip(inputs, operands):
         x._pyro_dims = input_
@@ -63,9 +62,9 @@ def test_shape(backend, equation):
     # check forward pass
     for x in operands:
         require_backward(x)
-    expected = contract(equation, *operands, backend='pyro.ops.einsum.torch_log')
+    expected = contract(equation, *operands, backend="pyro.ops.einsum.torch_log")
     actual = contract(equation, *operands, backend=backend)
-    if backend.endswith('map'):
+    if backend.endswith("map"):
         assert actual.dtype == expected.dtype
         assert actual.shape == expected.shape
     else:
@@ -75,42 +74,43 @@ def test_shape(backend, equation):
     actual._pyro_backward()
     for input_, x in zip(inputs, operands):
         backward_result = x._pyro_backward_result
-        if backend.endswith('marginal'):
+        if backend.endswith("marginal"):
             assert backward_result.shape == x.shape
         else:
             contract_dims = set(input_) - set(output)
             if contract_dims:
                 assert backward_result.size(0) == len(contract_dims)
                 assert set(backward_result._pyro_dims[1:]) == set(output)
-                for sample, dim in zip(backward_result, backward_result._pyro_sample_dims):
+                for sample, dim in zip(
+                    backward_result, backward_result._pyro_sample_dims
+                ):
                     assert sample.min() >= 0
                     assert sample.max() < sizes[dim]
             else:
                 assert backward_result is None
 
 
-@pytest.mark.parametrize('equation', EQUATIONS)
+@pytest.mark.parametrize("equation", EQUATIONS)
 def test_marginal(equation):
-    inputs, output = equation.split('->')
-    inputs = inputs.split(',')
-    operands = [torch.randn(torch.Size((2,) * len(input_)))
-                for input_ in inputs]
+    inputs, output = equation.split("->")
+    inputs = inputs.split(",")
+    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
     for input_, x in zip(inputs, operands):
         x._pyro_dims = input_
 
     # check forward pass
     for x in operands:
         require_backward(x)
-    actual = contract(equation, *operands, backend='pyro.ops.einsum.torch_marginal')
-    expected = contract(equation, *operands,
-                        backend='pyro.ops.einsum.torch_log')
+    actual = contract(equation, *operands, backend="pyro.ops.einsum.torch_marginal")
+    expected = contract(equation, *operands,
backend="pyro.ops.einsum.torch_log") assert_equal(expected, actual) # check backward pass actual._pyro_backward() for input_, operand in zip(inputs, operands): - marginal_equation = ','.join(inputs) + '->' + input_ - expected = contract(marginal_equation, *operands, - backend='pyro.ops.einsum.torch_log') + marginal_equation = ",".join(inputs) + "->" + input_ + expected = contract( + marginal_equation, *operands, backend="pyro.ops.einsum.torch_log" + ) actual = operand._pyro_backward_result assert_equal(expected, actual) diff --git a/tests/ops/einsum/test_torch_log.py b/tests/ops/einsum/test_torch_log.py index 4d8f53d3ab..ed23d6eafb 100644 --- a/tests/ops/einsum/test_torch_log.py +++ b/tests/ops/einsum/test_torch_log.py @@ -11,41 +11,47 @@ from tests.common import assert_equal -@pytest.mark.parametrize('min_size', [1, 2]) -@pytest.mark.parametrize('equation', [ - ',ab->ab', - 'ab,,bc->a', - 'ab,,bc->b', - 'ab,,bc->c', - 'ab,,bc->ac', - 'ab,,b,bc->ac', - 'a,ab->ab', - 'ab,b,bc->a', - 'ab,b,bc->b', - 'ab,b,bc->c', - 'ab,b,bc->ac', - 'ab,bc->ac', - 'ab,bc,cd->', - 'ab,bc,cd->a', - 'ab,bc,cd->b', - 'ab,bc,cd->c', - 'ab,bc,cd->d', - 'ab,bc,cd->ac', - 'ab,bc,cd->ad', - 'ab,bc,cd->bc', - 'a,a,ab,b,b,b,b->a', -]) -@pytest.mark.parametrize('infinite', [False, True], ids=['finite', 'infinite']) +@pytest.mark.parametrize("min_size", [1, 2]) +@pytest.mark.parametrize( + "equation", + [ + ",ab->ab", + "ab,,bc->a", + "ab,,bc->b", + "ab,,bc->c", + "ab,,bc->ac", + "ab,,b,bc->ac", + "a,ab->ab", + "ab,b,bc->a", + "ab,b,bc->b", + "ab,b,bc->c", + "ab,b,bc->ac", + "ab,bc->ac", + "ab,bc,cd->", + "ab,bc,cd->a", + "ab,bc,cd->b", + "ab,bc,cd->c", + "ab,bc,cd->d", + "ab,bc,cd->ac", + "ab,bc,cd->ad", + "ab,bc,cd->bc", + "a,a,ab,b,b,b,b->a", + ], +) +@pytest.mark.parametrize("infinite", [False, True], ids=["finite", "infinite"]) def test_einsum(equation, min_size, infinite): - inputs, output = equation.split('->') - inputs = inputs.split(',') - symbols = sorted(set(equation) - set(',->')) + inputs, output = equation.split("->") + inputs = inputs.split(",") + symbols = sorted(set(equation) - set(",->")) sizes = dict(zip(symbols, itertools.count(min_size))) - shapes = [torch.Size(tuple(sizes[dim] for dim in dims)) - for dims in inputs] - operands = [torch.full(shape, -float('inf')) if infinite else torch.randn(shape) - for shape in shapes] + shapes = [torch.Size(tuple(sizes[dim] for dim in dims)) for dims in inputs] + operands = [ + torch.full(shape, -float("inf")) if infinite else torch.randn(shape) + for shape in shapes + ] - expected = contract(equation, *(torch_exp(x) for x in operands), backend='torch').log() - actual = contract(equation, *operands, backend='pyro.ops.einsum.torch_log') + expected = contract( + equation, *(torch_exp(x) for x in operands), backend="torch" + ).log() + actual = contract(equation, *operands, backend="pyro.ops.einsum.torch_log") assert_equal(actual, expected) diff --git a/tests/ops/test_arrowhead.py b/tests/ops/test_arrowhead.py index 5a3e611e98..f3927157f3 100644 --- a/tests/ops/test_arrowhead.py +++ b/tests/ops/test_arrowhead.py @@ -14,15 +14,15 @@ from tests.common import assert_close -@pytest.mark.parametrize('head_size', [0, 2, 5]) +@pytest.mark.parametrize("head_size", [0, 2, 5]) def test_utilities(head_size): size = 5 cov = torch.randn(size, size) cov = torch.mm(cov, cov.t()) mask = torch.ones(size, size) - mask[head_size:, head_size:] = 0. - mask.view(-1)[::size + 1][head_size:] = 1. 
+    mask[head_size:, head_size:] = 0.0
+    mask.view(-1)[:: size + 1][head_size:] = 1.0
     arrowhead_full = mask * cov
     expected = torch.flip(
         torch.linalg.cholesky(torch.flip(arrowhead_full, (-2, -1))), (-2, -1)
     )
@@ -46,10 +46,13 @@ def test_utilities(head_size):
     # test triu_matvecmul
     v = torch.randn(size)
     assert_close(triu_matvecmul(actual, v), expected.matmul(v))
-    assert_close(triu_matvecmul(actual, v, transpose=True),
-                 expected.t().matmul(v))
+    assert_close(triu_matvecmul(actual, v, transpose=True), expected.t().matmul(v))
 
     # test triu_gram
     actual = triu_gram(actual)
-    expected = arrowhead_full.inverse() if head_size > 0 else arrowhead_full.diag().reciprocal()
+    expected = (
+        arrowhead_full.inverse()
+        if head_size > 0
+        else arrowhead_full.diag().reciprocal()
+    )
     assert_close(actual, expected)
diff --git a/tests/ops/test_contract.py b/tests/ops/test_contract.py
index 86f304d2f3..01db05824c 100644
--- a/tests/ops/test_contract.py
+++ b/tests/ops/test_contract.py
@@ -74,8 +74,11 @@ def checked_fn(*args):
         result = fn(*args)
         for pos, (arg, copy) in enumerate(zip(args, copies)):
             if not deep_equal(arg, copy):
-                raise AssertionError('{} mutated arg {} of type {}.\nOld:\n{}\nNew:\n{}'
-                                     .format(fn.__name__, pos, type(arg).__name__, copy, arg))
+                raise AssertionError(
+                    "{} mutated arg {} of type {}.\nOld:\n{}\nNew:\n{}".format(
+                        fn.__name__, pos, type(arg).__name__, copy, arg
+                    )
+                )
         return result
 
     return checked_fn
@@ -89,30 +92,33 @@ def _normalize(tensor, dims, plates):
     return tensor - total
 
 
-@pytest.mark.parametrize('inputs,dims,expected_num_components', [
-    ([''], set(), 1),
-    (['a'], set(), 1),
-    (['a'], set('a'), 1),
-    (['a', 'a'], set(), 2),
-    (['a', 'a'], set('a'), 1),
-    (['a', 'a', 'b', 'b'], set(), 4),
-    (['a', 'a', 'b', 'b'], set('a'), 3),
-    (['a', 'a', 'b', 'b'], set('b'), 3),
-    (['a', 'a', 'b', 'b'], set('ab'), 2),
-    (['a', 'ab', 'b'], set(), 3),
-    (['a', 'ab', 'b'], set('a'), 2),
-    (['a', 'ab', 'b'], set('b'), 2),
-    (['a', 'ab', 'b'], set('ab'), 1),
-    (['a', 'ab', 'bc', 'c'], set(), 4),
-    (['a', 'ab', 'bc', 'c'], set('c'), 3),
-    (['a', 'ab', 'bc', 'c'], set('b'), 3),
-    (['a', 'ab', 'bc', 'c'], set('a'), 3),
-    (['a', 'ab', 'bc', 'c'], set('ac'), 2),
-    (['a', 'ab', 'bc', 'c'], set('abc'), 1),
-])
+@pytest.mark.parametrize(
+    "inputs,dims,expected_num_components",
+    [
+        ([""], set(), 1),
+        (["a"], set(), 1),
+        (["a"], set("a"), 1),
+        (["a", "a"], set(), 2),
+        (["a", "a"], set("a"), 1),
+        (["a", "a", "b", "b"], set(), 4),
+        (["a", "a", "b", "b"], set("a"), 3),
+        (["a", "a", "b", "b"], set("b"), 3),
+        (["a", "a", "b", "b"], set("ab"), 2),
+        (["a", "ab", "b"], set(), 3),
+        (["a", "ab", "b"], set("a"), 2),
+        (["a", "ab", "b"], set("b"), 2),
+        (["a", "ab", "b"], set("ab"), 1),
+        (["a", "ab", "bc", "c"], set(), 4),
+        (["a", "ab", "bc", "c"], set("c"), 3),
+        (["a", "ab", "bc", "c"], set("b"), 3),
+        (["a", "ab", "bc", "c"], set("a"), 3),
+        (["a", "ab", "bc", "c"], set("ac"), 2),
+        (["a", "ab", "bc", "c"], set("abc"), 1),
+    ],
+)
 def test_partition_terms(inputs, dims, expected_num_components):
     ring = LogRing()
-    symbol_to_size = dict(zip('abc', [2, 3, 4]))
+    symbol_to_size = dict(zip("abc", [2, 3, 4]))
     shapes = [tuple(symbol_to_size[s] for s in input_) for input_ in inputs]
     tensors = [torch.randn(shape) for shape in shapes]
     for input_, tensor in zip(inputs, tensors):
@@ -138,7 +144,9 @@ def test_partition_terms(inputs, dims, expected_num_components):
 
 
 def frame(dim, size):
-    return CondIndepStackFrame(name="plate_{}".format(size), dim=dim, size=size, counter=0)
+    return CondIndepStackFrame(
name="plate_{}".format(size), dim=dim, size=size, counter=0 + ) EXAMPLES = [ @@ -147,44 +155,44 @@ def frame(dim, size): # | 4 x, y are enumerated in dims: # x a, b { - 'shape_tree': { - frozenset(): ['a'], - frozenset('i'): ['abi'], + "shape_tree": { + frozenset(): ["a"], + frozenset("i"): ["abi"], }, - 'sum_dims': set('ab'), - 'target_dims': set(), - 'target_ordinal': frozenset(), - 'expected_dims': (), + "sum_dims": set("ab"), + "target_dims": set(), + "target_ordinal": frozenset(), + "expected_dims": (), }, { - 'shape_tree': { - frozenset(): ['a'], - frozenset('i'): ['abi'], + "shape_tree": { + frozenset(): ["a"], + frozenset("i"): ["abi"], }, - 'sum_dims': set('ab'), - 'target_dims': set('a'), - 'target_ordinal': frozenset(), - 'expected_dims': 'a', + "sum_dims": set("ab"), + "target_dims": set("a"), + "target_ordinal": frozenset(), + "expected_dims": "a", }, { - 'shape_tree': { - frozenset(): ['a'], - frozenset('i'): ['abi'], + "shape_tree": { + frozenset(): ["a"], + frozenset("i"): ["abi"], }, - 'sum_dims': set('ab'), - 'target_dims': set('b'), - 'target_ordinal': frozenset('i'), - 'expected_dims': 'bi', + "sum_dims": set("ab"), + "target_dims": set("b"), + "target_ordinal": frozenset("i"), + "expected_dims": "bi", }, { - 'shape_tree': { - frozenset(): ['a'], - frozenset('i'): ['abi'], + "shape_tree": { + frozenset(): ["a"], + frozenset("i"): ["abi"], }, - 'sum_dims': set('ab'), - 'target_dims': set('ab'), - 'target_ordinal': frozenset('i'), - 'expected_dims': 'abi', + "sum_dims": set("ab"), + "target_dims": set("ab"), + "target_ordinal": frozenset("i"), + "expected_dims": "abi", }, # ------------------------------------------------------ # z @@ -193,88 +201,90 @@ def frame(dim, size): # 2 \ / 3 a, b, c, d # w { - 'shape_tree': { - frozenset(): ['a'], # w - frozenset('i'): ['abi'], # x - frozenset('j'): ['acj'], # y - frozenset('ij'): ['cdij'], # z + "shape_tree": { + frozenset(): ["a"], # w + frozenset("i"): ["abi"], # x + frozenset("j"): ["acj"], # y + frozenset("ij"): ["cdij"], # z }, # query for w - 'sum_dims': set('abcd'), - 'target_dims': set('a'), - 'target_ordinal': frozenset(), - 'expected_dims': 'a', + "sum_dims": set("abcd"), + "target_dims": set("a"), + "target_ordinal": frozenset(), + "expected_dims": "a", }, { - 'shape_tree': { - frozenset(): ['a'], # w - frozenset('i'): ['abi'], # x - frozenset('j'): ['acj'], # y - frozenset('ij'): ['cdij'], # z + "shape_tree": { + frozenset(): ["a"], # w + frozenset("i"): ["abi"], # x + frozenset("j"): ["acj"], # y + frozenset("ij"): ["cdij"], # z }, # query for x - 'sum_dims': set('abcd'), - 'target_dims': set('b'), - 'target_ordinal': frozenset('i'), - 'expected_dims': 'bi', + "sum_dims": set("abcd"), + "target_dims": set("b"), + "target_ordinal": frozenset("i"), + "expected_dims": "bi", }, { - 'shape_tree': { - frozenset(): ['a'], # w - frozenset('i'): ['abi'], # x - frozenset('j'): ['acj'], # y - frozenset('ij'): ['cdij'], # z + "shape_tree": { + frozenset(): ["a"], # w + frozenset("i"): ["abi"], # x + frozenset("j"): ["acj"], # y + frozenset("ij"): ["cdij"], # z }, # query for y - 'sum_dims': set('abcd'), - 'target_dims': set('c'), - 'target_ordinal': frozenset('j'), - 'expected_dims': 'cj', + "sum_dims": set("abcd"), + "target_dims": set("c"), + "target_ordinal": frozenset("j"), + "expected_dims": "cj", }, { - 'shape_tree': { - frozenset(): ['a'], # w - frozenset('i'): ['abi'], # x - frozenset('j'): ['acj'], # y - frozenset('ij'): ['cdij'], # z + "shape_tree": { + frozenset(): ["a"], # w + frozenset("i"): ["abi"], # x + 
frozenset("j"): ["acj"], # y + frozenset("ij"): ["cdij"], # z }, # query for z - 'sum_dims': set('abcd'), - 'target_dims': set('d'), - 'target_ordinal': frozenset('ij'), - 'expected_dims': 'dij', + "sum_dims": set("abcd"), + "target_dims": set("d"), + "target_ordinal": frozenset("ij"), + "expected_dims": "dij", }, ] -@pytest.mark.parametrize('example', EXAMPLES) +@pytest.mark.parametrize("example", EXAMPLES) def test_contract_to_tensor(example): - symbol_to_size = dict(zip('abcdij', [4, 5, 6, 7, 2, 3])) + symbol_to_size = dict(zip("abcdij", [4, 5, 6, 7, 2, 3])) tensor_tree = OrderedDict() - for t, shapes in example['shape_tree'].items(): + for t, shapes in example["shape_tree"].items(): for dims in shapes: tensor = torch.randn(tuple(symbol_to_size[s] for s in dims)) tensor._pyro_dims = dims tensor_tree.setdefault(t, []).append(tensor) - sum_dims = example['sum_dims'] - target_dims = example['target_dims'] - target_ordinal = example['target_ordinal'] - expected_dims = example['expected_dims'] - - actual = assert_immutable(contract_to_tensor)(tensor_tree, sum_dims, target_ordinal, target_dims) + sum_dims = example["sum_dims"] + target_dims = example["target_dims"] + target_ordinal = example["target_ordinal"] + expected_dims = example["expected_dims"] + + actual = assert_immutable(contract_to_tensor)( + tensor_tree, sum_dims, target_ordinal, target_dims + ) assert set(actual._pyro_dims) == set(expected_dims) -@pytest.mark.parametrize('example', EXAMPLES) +@pytest.mark.parametrize("example", EXAMPLES) def test_contract_tensor_tree(example): - symbol_to_size = dict(zip('abcdij', [4, 5, 6, 7, 2, 3])) + symbol_to_size = dict(zip("abcdij", [4, 5, 6, 7, 2, 3])) tensor_tree = OrderedDict() - for t, shapes in example['shape_tree'].items(): + for t, shapes in example["shape_tree"].items(): for dims in shapes: tensor = torch.randn(tuple(symbol_to_size[s] for s in dims)) tensor._pyro_dims = dims tensor_tree.setdefault(t, []).append(tensor) - sum_dims = example['sum_dims'] + sum_dims = example["sum_dims"] tensor_tree = assert_immutable(contract_tensor_tree)(tensor_tree, sum_dims) assert tensor_tree @@ -286,53 +296,53 @@ def test_contract_tensor_tree(example): # Let abcde be enum dims and ijk be plates. 
UBERSUM_EXAMPLES = [ - ('->', ''), - ('a->,a', ''), - ('ab->,a,b,ab,ba', ''), - ('ab,bc->,a,b,c,ab,bc,ac,abc', ''), - ('ab,bc,cd->,a,b,c,d,ab,ac,ad,bc,bd,cd,abc,acd,bcd,abcd', ''), - ('i->,i', 'i'), - (',i->,i', 'i'), - (',i,i->,i', 'i'), - (',i,ia->,i,ia', 'i'), - (',i,i,ia,ia->,i,ia', 'i'), - ('bi,ia->,i,ia,ib,iab', 'i'), - ('abi,b->,b,ai,abi', 'i'), - ('ia,ja,ija->,a,i,ia,j,ja,ija', 'ij'), - ('i,jb,ijab->,i,j,jb,ij,ija,ijb,ijab', 'ij'), - ('ia,jb,ijab->,i,ia,j,jb,ij,ija,ijb,ijab', 'ij'), - (',i,j,a,ij,ia,ja,ija->,a,i,j,ia,ja,ij,ija', 'ij'), - ('a,b,c,di,ei,fj->,a,b,c,di,ei,fj', 'ij'), + ("->", ""), + ("a->,a", ""), + ("ab->,a,b,ab,ba", ""), + ("ab,bc->,a,b,c,ab,bc,ac,abc", ""), + ("ab,bc,cd->,a,b,c,d,ab,ac,ad,bc,bd,cd,abc,acd,bcd,abcd", ""), + ("i->,i", "i"), + (",i->,i", "i"), + (",i,i->,i", "i"), + (",i,ia->,i,ia", "i"), + (",i,i,ia,ia->,i,ia", "i"), + ("bi,ia->,i,ia,ib,iab", "i"), + ("abi,b->,b,ai,abi", "i"), + ("ia,ja,ija->,a,i,ia,j,ja,ija", "ij"), + ("i,jb,ijab->,i,j,jb,ij,ija,ijb,ijab", "ij"), + ("ia,jb,ijab->,i,ia,j,jb,ij,ija,ijb,ijab", "ij"), + (",i,j,a,ij,ia,ja,ija->,a,i,j,ia,ja,ij,ija", "ij"), + ("a,b,c,di,ei,fj->,a,b,c,di,ei,fj", "ij"), # {ij} {ik} # a\ /a # {i} - ('ija,ika->,i,j,k,ij,ik,ijk,ia,ija,ika,ijka', 'ijk'), + ("ija,ika->,i,j,k,ij,ik,ijk,ia,ija,ika,ijka", "ijk"), # {ij} {ik} # a\ /a # {i} {} - (',ia,ija,ika->,i,j,k,ij,ik,ijk,ia,ija,ika,ijka', 'ijk'), + (",ia,ija,ika->,i,j,k,ij,ik,ijk,ia,ija,ika,ijka", "ijk"), # {i} c # |b # {} a - ('ab,bci->,a,b,ab,i,ai,bi,ci,abi,bci,abci', 'i'), + ("ab,bci->,a,b,ab,i,ai,bi,ci,abi,bci,abci", "i"), # {i} cd # |b # {} a - ('ab,bci,bdi->,a,b,ab,i,ai,bi,ci,abi,bci,bdi,cdi,abci,abdi,abcdi', 'i'), + ("ab,bci,bdi->,a,b,ab,i,ai,bi,ci,abi,bci,bdi,cdi,abci,abdi,abcdi", "i"), # {ij} c # |b # {} a - ('ab,bcij->,a,b,ab,i,j,ij,ai,aj,aij,bi,bj,aij,bij,cij,abij,acij,bcij,abcij', 'ij'), + ("ab,bcij->,a,b,ab,i,j,ij,ai,aj,aij,bi,bj,aij,bij,cij,abij,acij,bcij,abcij", "ij"), # {ij} c # |b # {i} a - ('abi,bcij->,i,ai,bi,abi,j,ij,aij,bij,cij,abij,bcij,abcij', 'ij'), + ("abi,bcij->,i,ai,bi,abi,j,ij,aij,bij,cij,abij,bcij,abcij", "ij"), # {ij} e # |d # {i} c # |b # {} a - ('ab,bcdi,deij->,a,b,ci,di,eij', 'ij'), + ("ab,bcdi,deij->,a,b,ci,di,eij", "ij"), # {ijk} g # |f # {ij} e @@ -340,28 +350,28 @@ def test_contract_tensor_tree(example): # {i} c # |b # {} a - ('ab,bcdi,defij,fgijk->,a,b,ci,di,eij,fij,gijk', 'ijk'), + ("ab,bcdi,defij,fgijk->,a,b,ci,di,eij,fij,gijk", "ijk"), # {ik} {ij} {ij} # a\ /b /e # {i} {j} # c\ /d # {} - ('aik,bij,abci,cd,dej,eij->,ai,bi,ej,aik,bij,eij', 'ijk'), + ("aik,bij,abci,cd,dej,eij->,ai,bi,ej,aik,bij,eij", "ijk"), # {ij} {ij} # a| |d # {i} {j} # b\ /c # {} - ('aij,abi,bc,cdj,dij->,bi,cj,aij,dij,adij', 'ij'), + ("aij,abi,bc,cdj,dij->,bi,cj,aij,dij,adij", "ij"), ] def make_example(equation, fill=None, sizes=(2, 3)): - symbols = sorted(set(equation) - set(',->')) + symbols = sorted(set(equation) - set(",->")) sizes = {dim: size for dim, size in zip(symbols, itertools.cycle(sizes))} - inputs, outputs = equation.split('->') - inputs = inputs.split(',') - outputs = outputs.split(',') + inputs, outputs = equation.split("->") + inputs = inputs.split(",") + outputs = outputs.split(",") operands = [] for dims in inputs: shape = tuple(sizes[dim] for dim in dims) @@ -369,7 +379,7 @@ def make_example(equation, fill=None, sizes=(2, 3)): return inputs, outputs, operands, sizes -@pytest.mark.parametrize('equation,plates', UBERSUM_EXAMPLES) +@pytest.mark.parametrize("equation,plates", UBERSUM_EXAMPLES) def test_naive_ubersum(equation, plates): 
inputs, outputs, operands, sizes = make_example(equation) @@ -381,15 +391,20 @@ def test_naive_ubersum(equation, plates): expected_shape = tuple(sizes[dim] for dim in output) assert actual_part.shape == expected_shape if not plates: - equation_part = ','.join(inputs) + '->' + output - expected_part = opt_einsum.contract(equation_part, *operands, - backend='pyro.ops.einsum.torch_log') - assert_equal(expected_part, actual_part, - msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( - output, expected_part.detach().cpu(), actual_part.detach().cpu())) - - -@pytest.mark.parametrize('equation,plates', UBERSUM_EXAMPLES) + equation_part = ",".join(inputs) + "->" + output + expected_part = opt_einsum.contract( + equation_part, *operands, backend="pyro.ops.einsum.torch_log" + ) + assert_equal( + expected_part, + actual_part, + msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( + output, expected_part.detach().cpu(), actual_part.detach().cpu() + ), + ) + + +@pytest.mark.parametrize("equation,plates", UBERSUM_EXAMPLES) def test_ubersum(equation, plates): inputs, outputs, operands, sizes = make_example(equation) @@ -404,18 +419,24 @@ def test_ubersum(equation, plates): for output, expected_part, actual_part in zip(outputs, expected, actual): actual_part = _normalize(actual_part, output, plates) expected_part = _normalize(expected_part, output, plates) - assert_equal(expected_part, actual_part, - msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( - output, expected_part.detach().cpu(), actual_part.detach().cpu())) + assert_equal( + expected_part, + actual_part, + msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( + output, expected_part.detach().cpu(), actual_part.detach().cpu() + ), + ) -@pytest.mark.parametrize('equation,plates', UBERSUM_EXAMPLES) +@pytest.mark.parametrize("equation,plates", UBERSUM_EXAMPLES) def test_einsum_linear(equation, plates): inputs, outputs, log_operands, sizes = make_example(equation) operands = [x.exp() for x in log_operands] try: - log_expected = ubersum(equation, *log_operands, plates=plates, modulo_total=True) + log_expected = ubersum( + equation, *log_operands, plates=plates, modulo_total=True + ) expected = [x.exp() for x in log_expected] except NotImplementedError: pytest.skip() @@ -425,12 +446,16 @@ def test_einsum_linear(equation, plates): assert isinstance(actual, tuple) assert len(actual) == len(outputs) for output, expected_part, actual_part in zip(outputs, expected, actual): - assert_equal(expected_part.log(), actual_part.log(), - msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( - output, expected_part.detach().cpu(), actual_part.detach().cpu())) + assert_equal( + expected_part.log(), + actual_part.log(), + msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format( + output, expected_part.detach().cpu(), actual_part.detach().cpu() + ), + ) -@pytest.mark.parametrize('equation,plates', UBERSUM_EXAMPLES) +@pytest.mark.parametrize("equation,plates", UBERSUM_EXAMPLES) def test_ubersum_jit(equation, plates): inputs, outputs, operands, sizes = make_example(equation) @@ -452,46 +477,53 @@ def jit_ubersum(*operands): assert_equal(e, a) -@pytest.mark.parametrize('equation,plates', [ - ('i->', 'i'), - ('i->i', 'i'), - (',i->', 'i'), - (',i->i', 'i'), - ('ai->', 'i'), - ('ai->i', 'i'), - ('ai->ai', 'i'), - (',ai,abij->aij', 'ij'), - ('a,ai,bij->bij', 'ij'), - ('a,ai,abij->bij', 'ij'), - ('a,abi,bcij->a', 'ij'), - ('a,abi,bcij->bi', 'ij'), - ('a,abi,bcij->bij', 'ij'), - ('a,abi,bcij->cij', 'ij'), - ('ab,bcdi,deij->eij', 'ij'), 
-]) +@pytest.mark.parametrize( + "equation,plates", + [ + ("i->", "i"), + ("i->i", "i"), + (",i->", "i"), + (",i->i", "i"), + ("ai->", "i"), + ("ai->i", "i"), + ("ai->ai", "i"), + (",ai,abij->aij", "ij"), + ("a,ai,bij->bij", "ij"), + ("a,ai,abij->bij", "ij"), + ("a,abi,bcij->a", "ij"), + ("a,abi,bcij->bi", "ij"), + ("a,abi,bcij->bij", "ij"), + ("a,abi,bcij->cij", "ij"), + ("ab,bcdi,deij->eij", "ij"), + ], +) def test_ubersum_total(equation, plates): - inputs, outputs, operands, sizes = make_example(equation, fill=1., sizes=(2,)) + inputs, outputs, operands, sizes = make_example(equation, fill=1.0, sizes=(2,)) output = outputs[0] expected = naive_ubersum(equation, *operands, plates=plates)[0] actual = ubersum(equation, *operands, plates=plates, modulo_total=True)[0] expected = _normalize(expected, output, plates) actual = _normalize(actual, output, plates) - assert_equal(expected, actual, - msg=u"Expected:\n{}\nActual:\n{}".format( - expected.detach().cpu(), actual.detach().cpu())) - - -@pytest.mark.parametrize('a', [2, 1]) -@pytest.mark.parametrize('b', [3, 1]) -@pytest.mark.parametrize('c', [3, 1]) -@pytest.mark.parametrize('d', [4, 1]) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) + assert_equal( + expected, + actual, + msg=u"Expected:\n{}\nActual:\n{}".format( + expected.detach().cpu(), actual.detach().cpu() + ), + ) + + +@pytest.mark.parametrize("a", [2, 1]) +@pytest.mark.parametrize("b", [3, 1]) +@pytest.mark.parametrize("c", [3, 1]) +@pytest.mark.parametrize("d", [4, 1]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_sizes(impl, a, b, c, d): X = torch.randn(a, b) Y = torch.randn(b, c) Z = torch.randn(c, d) - actual = impl('ab,bc,cd->a,b,c,d', X, Y, Z, plates='ad', modulo_total=True) + actual = impl("ab,bc,cd->a,b,c,d", X, Y, Z, plates="ad", modulo_total=True) actual_a, actual_b, actual_c, actual_d = actual assert actual_a.shape == (a,) assert actual_b.shape == (b,) @@ -499,7 +531,7 @@ def test_ubersum_sizes(impl, a, b, c, d): assert actual_d.shape == (d,) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_1(impl): # y {a} z {b} # \ / @@ -508,12 +540,12 @@ def test_ubersum_1(impl): x = torch.randn(c) y = torch.randn(c, d, a) z = torch.randn(e, c, b) - actual, = impl('c,cda,ecb->', x, y, z, plates='ab', modulo_total=True) + (actual,) = impl("c,cda,ecb->", x, y, z, plates="ab", modulo_total=True) expected = logsumexp(x + logsumexp(y, -2).sum(-1) + logsumexp(z, -3).sum(-1), -1) assert_equal(actual, expected) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_2(impl): # y {a} z {b} <--- target # \ / @@ -522,13 +554,13 @@ def test_ubersum_2(impl): x = torch.randn(c) y = torch.randn(c, d, a) z = torch.randn(e, c, b) - actual, = impl('c,cda,ecb->b', x, y, z, plates='ab', modulo_total=True) + (actual,) = impl("c,cda,ecb->b", x, y, z, plates="ab", modulo_total=True) xyz = logsumexp(x + logsumexp(y, -2).sum(-1) + logsumexp(z, -3).sum(-1), -1) expected = xyz.expand(b) assert_equal(actual, expected) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_3(impl): # z {b,c} # | @@ -540,7 +572,7 @@ def test_ubersum_3(impl): x = torch.randn(d) y = torch.randn(b, d) z = torch.randn(b, c, d, e) - actual, = impl('ae,d,bd,bcde->be', w, x, y, z, plates='abc', modulo_total=True) + (actual,) = 
impl("ae,d,bd,bcde->be", w, x, y, z, plates="abc", modulo_total=True) yz = y.reshape(b, d, 1) + z.sum(-3) # eliminate c assert yz.shape == (b, d, e) yz = yz.sum(0) # eliminate b @@ -553,7 +585,7 @@ def test_ubersum_3(impl): assert_equal(actual, expected) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_4(impl): # x,y {b} <--- target # | @@ -561,7 +593,7 @@ def test_ubersum_4(impl): a, b, c, d = 2, 3, 4, 5 x = torch.randn(a, b) y = torch.randn(d, b, c) - actual, = impl('ab,dbc->dc', x, y, plates='d', modulo_total=True) + (actual,) = impl("ab,dbc->dc", x, y, plates="d", modulo_total=True) x_b1 = logsumexp(x, 0).unsqueeze(-1) assert x_b1.shape == (b, 1) y_db1 = logsumexp(y, 2, keepdim=True) @@ -574,7 +606,7 @@ def test_ubersum_4(impl): assert_equal(actual, expected) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_5(impl): # z {ij} <--- target # | @@ -585,7 +617,7 @@ def test_ubersum_5(impl): x = torch.randn(a) y = torch.randn(a, b, i) z = torch.randn(b, c, i, j) - actual, = impl('a,abi,bcij->cij', x, y, z, plates='ij', modulo_total=True) + (actual,) = impl("a,abi,bcij->cij", x, y, z, plates="ij", modulo_total=True) # contract plate j s1 = logsumexp(z, 1) @@ -605,12 +637,13 @@ def test_ubersum_5(impl): q2 = x2 - s2.unsqueeze(-2) assert q2.shape == (a, b, i) - expected = opt_einsum.contract('a,a,abi,bcij->cij', x, p2, q2, q1, - backend='pyro.ops.einsum.torch_log') + expected = opt_einsum.contract( + "a,a,abi,bcij->cij", x, p2, q2, q1, backend="pyro.ops.einsum.torch_log" + ) assert_equal(actual, expected) -@pytest.mark.parametrize('impl,implemented', [(naive_ubersum, True), (ubersum, False)]) +@pytest.mark.parametrize("impl,implemented", [(naive_ubersum, True), (ubersum, False)]) def test_ubersum_collide_implemented(impl, implemented): # Non-tree plates cause exponential blowup, # so ubersum() refuses to evaluate them. @@ -624,12 +657,14 @@ def test_ubersum_collide_implemented(impl, implemented): x = torch.randn(a, c) y = torch.randn(b, d) z = torch.randn(a, b, c, d) - raises = pytest.raises(NotImplementedError, match='Expected tree-structured plate nesting') + raises = pytest.raises( + NotImplementedError, match="Expected tree-structured plate nesting" + ) with optional(raises, not implemented): - impl('ac,bd,abcd->', x, y, z, plates='ab', modulo_total=True) + impl("ac,bd,abcd->", x, y, z, plates="ab", modulo_total=True) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_collide_ok_1(impl): # The following is ok because it splits into connected components # {x,z1} and {y,z2}, thereby avoiding exponential blowup. @@ -644,10 +679,10 @@ def test_ubersum_collide_ok_1(impl): y = torch.randn(b, d) z1 = torch.randn(a, b, c) z2 = torch.randn(a, b, d) - impl('ac,bd,abc,abd->', x, y, z1, z2, plates='ab', modulo_total=True) + impl("ac,bd,abc,abd->", x, y, z1, z2, plates="ab", modulo_total=True) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_collide_ok_2(impl): # The following is ok because z1 can be contracted to x and # z2 can be contracted to y. 
@@ -663,10 +698,10 @@ def test_ubersum_collide_ok_2(impl): y = torch.randn(b, d) z1 = torch.randn(a, b, c) z2 = torch.randn(a, b, d) - impl('cd,ac,bd,abc,abd->', w, x, y, z1, z2, plates='ab', modulo_total=True) + impl("cd,ac,bd,abc,abd->", w, x, y, z1, z2, plates="ab", modulo_total=True) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_collide_ok_3(impl): # The following is ok because x, y, and z can be independently contracted to w. # @@ -680,72 +715,73 @@ def test_ubersum_collide_ok_3(impl): x = torch.randn(a, c) y = torch.randn(b, c) z = torch.randn(a, b, c) - impl('c,ac,bc,abc->', w, x, y, z, plates='ab', modulo_total=True) + impl("c,ac,bc,abc->", w, x, y, z, plates="ab", modulo_total=True) UBERSUM_SHAPE_ERRORS = [ - ('ab,bc->', [(2, 3), (4, 5)], ''), - ('ab,bc->', [(2, 3), (4, 5)], 'b'), + ("ab,bc->", [(2, 3), (4, 5)], ""), + ("ab,bc->", [(2, 3), (4, 5)], "b"), ] -@pytest.mark.parametrize('equation,shapes,plates', UBERSUM_SHAPE_ERRORS) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("equation,shapes,plates", UBERSUM_SHAPE_ERRORS) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_size_error(impl, equation, shapes, plates): operands = [torch.randn(shape) for shape in shapes] - with pytest.raises(ValueError, match='Dimension size mismatch|Size of label'): + with pytest.raises(ValueError, match="Dimension size mismatch|Size of label"): impl(equation, *operands, plates=plates, modulo_total=True) UBERSUM_BATCH_ERRORS = [ - ('ai->a', 'i'), - (',ai->a', 'i'), - ('bi,abi->b', 'i'), - (',bi,abi->b', 'i'), - ('aij->ai', 'ij'), - ('aij->aj', 'ij'), + ("ai->a", "i"), + (",ai->a", "i"), + ("bi,abi->b", "i"), + (",bi,abi->b", "i"), + ("aij->ai", "ij"), + ("aij->aj", "ij"), ] -@pytest.mark.parametrize('equation,plates', UBERSUM_BATCH_ERRORS) -@pytest.mark.parametrize('impl', [naive_ubersum, ubersum]) +@pytest.mark.parametrize("equation,plates", UBERSUM_BATCH_ERRORS) +@pytest.mark.parametrize("impl", [naive_ubersum, ubersum]) def test_ubersum_plate_error(impl, equation, plates): - inputs, outputs = equation.split('->') - operands = [torch.randn(torch.Size((2,) * len(input_))) - for input_ in inputs.split(',')] - with pytest.raises(ValueError, match='It is nonsensical to preserve a plated dim'): + inputs, outputs = equation.split("->") + operands = [ + torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs.split(",") + ] + with pytest.raises(ValueError, match="It is nonsensical to preserve a plated dim"): impl(equation, *operands, plates=plates, modulo_total=True) ADJOINT_EXAMPLES = [ - ('a->', ''), - ('a,a->', ''), - ('ab,bc->', ''), - ('a,abi->', 'i'), - ('a,abi,bcij->', 'ij'), - ('a,abi,bcij,bdik->', 'ijk'), - ('ai,ai->i', 'i'), - ('ai,abij->i', 'ij'), - ('ai,abij,acik->i', 'ijk'), + ("a->", ""), + ("a,a->", ""), + ("ab,bc->", ""), + ("a,abi->", "i"), + ("a,abi,bcij->", "ij"), + ("a,abi,bcij,bdik->", "ijk"), + ("ai,ai->i", "i"), + ("ai,abij->i", "ij"), + ("ai,abij,acik->i", "ijk"), ] -@pytest.mark.parametrize('equation,plates', ADJOINT_EXAMPLES) -@pytest.mark.parametrize('backend', ['map', 'sample', 'marginal']) +@pytest.mark.parametrize("equation,plates", ADJOINT_EXAMPLES) +@pytest.mark.parametrize("backend", ["map", "sample", "marginal"]) def test_adjoint_shape(backend, equation, plates): - backend = 'pyro.ops.einsum.torch_{}'.format(backend) - inputs, output = equation.split('->') - inputs = inputs.split(',') - operands = 
[torch.randn(torch.Size((2,) * len(input_))) - for input_ in inputs] + backend = "pyro.ops.einsum.torch_{}".format(backend) + inputs, output = equation.split("->") + inputs = inputs.split(",") + operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs] for input_, x in zip(inputs, operands): x._pyro_dims = input_ # run forward-backward algorithm for x in operands: require_backward(x) - result, = ubersum(equation, *operands, plates=plates, - modulo_total=True, backend=backend) + (result,) = ubersum( + equation, *operands, plates=plates, modulo_total=True, backend=backend + ) result._pyro_backward() for input_, x in zip(inputs, operands): @@ -757,29 +793,43 @@ def test_adjoint_shape(backend, equation, plates): assert backward_result is None -@pytest.mark.parametrize('equation,plates', ADJOINT_EXAMPLES) +@pytest.mark.parametrize("equation,plates", ADJOINT_EXAMPLES) def test_adjoint_marginal(equation, plates): - inputs, output = equation.split('->') - inputs = inputs.split(',') - operands = [torch.randn(torch.Size((2,) * len(input_))) - for input_ in inputs] + inputs, output = equation.split("->") + inputs = inputs.split(",") + operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs] for input_, x in zip(inputs, operands): x._pyro_dims = input_ # check forward pass for x in operands: require_backward(x) - actual, = ubersum(equation, *operands, plates=plates, modulo_total=True, - backend='pyro.ops.einsum.torch_marginal') - expected, = ubersum(equation, *operands, plates=plates, modulo_total=True, - backend='pyro.ops.einsum.torch_log') + (actual,) = ubersum( + equation, + *operands, + plates=plates, + modulo_total=True, + backend="pyro.ops.einsum.torch_marginal" + ) + (expected,) = ubersum( + equation, + *operands, + plates=plates, + modulo_total=True, + backend="pyro.ops.einsum.torch_log" + ) assert_equal(expected, actual) # check backward pass actual._pyro_backward() for input_, operand in zip(inputs, operands): - marginal_equation = ','.join(inputs) + '->' + input_ - expected, = ubersum(marginal_equation, *operands, plates=plates, modulo_total=True, - backend='pyro.ops.einsum.torch_log') + marginal_equation = ",".join(inputs) + "->" + input_ + (expected,) = ubersum( + marginal_equation, + *operands, + plates=plates, + modulo_total=True, + backend="pyro.ops.einsum.torch_log" + ) actual = operand._pyro_backward_result assert_equal(expected, actual) diff --git a/tests/ops/test_gamma_gaussian.py b/tests/ops/test_gamma_gaussian.py index ae913f62f0..7bf420121c 100644 --- a/tests/ops/test_gamma_gaussian.py +++ b/tests/ops/test_gamma_gaussian.py @@ -25,17 +25,29 @@ @pytest.mark.parametrize("extra_shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("log_normalizer_shape,info_vec_shape,precision_shape,alpha_shape,beta_shape", [ - ((), (), (), (), ()), - ((5,), (), (), (), ()), - ((), (5,), (), (), ()), - ((), (), (5,), (), ()), - ((), (), (), (5,), ()), - ((), (), (), (), (5,)), - ((3, 1, 1), (1, 4, 1), (1, 1, 5), (3, 4, 1), (1, 4, 5)), -], ids=str) +@pytest.mark.parametrize( + "log_normalizer_shape,info_vec_shape,precision_shape,alpha_shape,beta_shape", + [ + ((), (), (), (), ()), + ((5,), (), (), (), ()), + ((), (5,), (), (), ()), + ((), (), (5,), (), ()), + ((), (), (), (5,), ()), + ((), (), (), (), (5,)), + ((3, 1, 1), (1, 4, 1), (1, 1, 5), (3, 4, 1), (1, 4, 5)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) -def test_expand(extra_shape, log_normalizer_shape, info_vec_shape, precision_shape, alpha_shape, beta_shape, dim): +def 
test_expand( + extra_shape, + log_normalizer_shape, + info_vec_shape, + precision_shape, + alpha_shape, + beta_shape, + dim, +): rank = dim + dim log_normalizer = torch.randn(log_normalizer_shape) info_vec = torch.randn(info_vec_shape + (dim,)) @@ -46,15 +58,20 @@ def test_expand(extra_shape, log_normalizer_shape, info_vec_shape, precision_sha gamma_gaussian = GammaGaussian(log_normalizer, info_vec, precision, alpha, beta) expected_shape = extra_shape + broadcast_shape( - log_normalizer_shape, info_vec_shape, precision_shape, alpha_shape, beta_shape) + log_normalizer_shape, info_vec_shape, precision_shape, alpha_shape, beta_shape + ) actual = gamma_gaussian.expand(expected_shape) assert actual.batch_shape == expected_shape -@pytest.mark.parametrize("old_shape,new_shape", [ - ((6,), (3, 2)), - ((5, 6), (5, 3, 2)), -], ids=str) +@pytest.mark.parametrize( + "old_shape,new_shape", + [ + ((6,), (3, 2)), + ((5, 6), (5, 3, 2)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) def test_reshape(old_shape, new_shape, dim): gamma_gaussian = random_gamma_gaussian(old_shape, dim) @@ -68,11 +85,15 @@ def test_reshape(old_shape, new_shape, dim): assert_close_gamma_gaussian(g, gamma_gaussian) -@pytest.mark.parametrize("shape,cat_dim,split", [ - ((4, 7, 6), -1, (2, 1, 3)), - ((4, 7, 6), -2, (1, 1, 2, 3)), - ((4, 7, 6), 1, (1, 1, 2, 3)), -], ids=str) +@pytest.mark.parametrize( + "shape,cat_dim,split", + [ + ((4, 7, 6), -1, (2, 1, 3)), + ((4, 7, 6), -2, (1, 1, 2, 3)), + ((4, 7, 6), 1, (1, 1, 2, 3)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) def test_cat(shape, cat_dim, split, dim): assert sum(split) == shape[cat_dim] @@ -82,11 +103,11 @@ def test_cat(shape, cat_dim, split, dim): for size in split: beg, end = end, end + size if cat_dim == -1: - part = gamma_gaussian[..., beg: end] + part = gamma_gaussian[..., beg:end] elif cat_dim == -2: - part = gamma_gaussian[..., beg: end, :] + part = gamma_gaussian[..., beg:end, :] elif cat_dim == 1: - part = gamma_gaussian[:, beg: end] + part = gamma_gaussian[:, beg:end] else: raise ValueError parts.append(part) @@ -116,7 +137,9 @@ def test_add(shape, dim): y = random_gamma_gaussian(shape, dim) value = torch.randn(dim) s = torch.randn(()).exp() - assert_close((x + y).log_density(value, s), x.log_density(value, s) + y.log_density(value, s)) + assert_close( + (x + y).log_density(value, s), x.log_density(value, s) + y.log_density(value, s) + ) @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) @@ -136,10 +159,14 @@ def test_marginalize(batch_shape, left, right): dim = left + right g = random_gamma_gaussian(batch_shape, dim) s = torch.randn(batch_shape).exp() - assert_close(g.marginalize(left=left).event_logsumexp().log_density(s), - g.event_logsumexp().log_density(s)) - assert_close(g.marginalize(right=right).event_logsumexp().log_density(s), - g.event_logsumexp().log_density(s)) + assert_close( + g.marginalize(left=left).event_logsumexp().log_density(s), + g.event_logsumexp().log_density(s), + ) + assert_close( + g.marginalize(right=right).event_logsumexp().log_density(s), + g.event_logsumexp().log_density(s), + ) @pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) @@ -151,8 +178,10 @@ def test_marginalize_condition(sample_shape, batch_shape, left, right): g = random_gamma_gaussian(batch_shape, dim) x = torch.randn(sample_shape + (1,) * len(batch_shape) + (right,)) s = torch.randn(batch_shape).exp() - assert_close(g.marginalize(left=left).log_density(x, s), - 
g.condition(x).event_logsumexp().log_density(s)) + assert_close( + g.marginalize(left=left).log_density(x, s), + g.condition(x).event_logsumexp().log_density(s), + ) @pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) @@ -186,8 +215,13 @@ def test_logsumexp(batch_shape, dim): num_samples = 200000 scale = 10 - samples = torch.rand((num_samples,) + (1,) * len(batch_shape) + (dim,)) * scale - scale / 2 - expected = g.log_density(samples, s).logsumexp(0) + math.log(scale ** dim / num_samples) + samples = ( + torch.rand((num_samples,) + (1,) * len(batch_shape) + (dim,)) * scale + - scale / 2 + ) + expected = g.log_density(samples, s).logsumexp(0) + math.log( + scale ** dim / num_samples + ) actual = g.event_logsumexp().log_density(s) assert_close(actual, expected, atol=0.05, rtol=0.05) @@ -205,7 +239,9 @@ def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim): s_log_prob = gamma.log_prob(s) scaled_prec = mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1) - mvn_log_prob = dist.MultivariateNormal(mvn.loc, precision_matrix=scaled_prec).log_prob(value) + mvn_log_prob = dist.MultivariateNormal( + mvn.loc, precision_matrix=scaled_prec + ).log_prob(value) expected_log_prob = s_log_prob + mvn_log_prob assert_close(actual_log_prob, expected_log_prob) @@ -226,33 +262,47 @@ def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_dim, y_di y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2) loc = y_pred + y_mvn.loc scaled_prec = y_mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1) - expected_log_prob = dist.MultivariateNormal(loc, precision_matrix=scaled_prec).log_prob(y) + expected_log_prob = dist.MultivariateNormal( + loc, precision_matrix=scaled_prec + ).log_prob(y) assert_close(actual_log_prob, expected_log_prob) -@pytest.mark.parametrize("x_batch_shape,y_batch_shape", [ - ((), ()), - ((3,), ()), - ((), (3,)), - ((2, 1), (3,)), - ((2, 3), (2, 3,)), -], ids=str) -@pytest.mark.parametrize("x_dim,y_dim,dot_dims", [ - (0, 0, 0), - (0, 2, 0), - (1, 0, 0), - (2, 1, 0), - (3, 3, 3), - (3, 2, 1), - (3, 2, 2), - (5, 4, 2), -], ids=str) -@pytest.mark.parametrize("x_rank,y_rank", [ - (1, 1), (4, 1), (1, 4), (4, 4) -], ids=str) -def test_gamma_gaussian_tensordot(dot_dims, - x_batch_shape, x_dim, x_rank, - y_batch_shape, y_dim, y_rank): +@pytest.mark.parametrize( + "x_batch_shape,y_batch_shape", + [ + ((), ()), + ((3,), ()), + ((), (3,)), + ((2, 1), (3,)), + ( + (2, 3), + ( + 2, + 3, + ), + ), + ], + ids=str, +) +@pytest.mark.parametrize( + "x_dim,y_dim,dot_dims", + [ + (0, 0, 0), + (0, 2, 0), + (1, 0, 0), + (2, 1, 0), + (3, 3, 3), + (3, 2, 1), + (3, 2, 2), + (5, 4, 2), + ], + ids=str, +) +@pytest.mark.parametrize("x_rank,y_rank", [(1, 1), (4, 1), (1, 4), (4, 4)], ids=str) +def test_gamma_gaussian_tensordot( + dot_dims, x_batch_shape, x_dim, x_rank, y_batch_shape, y_dim, y_rank +): x_rank = min(x_rank, x_dim) y_rank = min(y_rank, y_dim) x = random_gamma_gaussian(x_batch_shape, x_dim, x_rank) @@ -276,9 +326,15 @@ def test_gamma_gaussian_tensordot(dot_dims, precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision, (na, 0, na, 0)) info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0)) covariance = torch.inverse(precision) - loc = covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1) if info_vec.size(-1) > 0 else info_vec + loc = ( + covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1) + if info_vec.size(-1) > 0 + else info_vec + ) z_covariance = torch.inverse(z.precision) - z_loc = z_covariance.matmul(z.info_vec.view(z.info_vec.shape + (int(z.dim() 
> 0),))).sum(-1) + z_loc = z_covariance.matmul( + z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0),)) + ).sum(-1) assert_close(loc[..., :na], z_loc[..., :na]) assert_close(loc[..., x_dim:], z_loc[..., na:]) assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na]) @@ -294,7 +350,9 @@ def test_gamma_gaussian_tensordot(dot_dims, value_b = torch.rand((num_samples,) + z.batch_shape + (nb,)) * scale - scale / 2 value_x = pad(value_b, (na, 0)) value_y = pad(value_b, (0, nc)) - expect = torch.logsumexp(x.log_density(value_x, s) + y.log_density(value_y, s), dim=0) + expect = torch.logsumexp( + x.log_density(value_x, s) + y.log_density(value_y, s), dim=0 + ) expect += math.log(scale ** nb / num_samples) actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(),)), s) - assert_close(actual.clamp(max=10.), expect.clamp(max=10.), atol=0.1, rtol=0.1) + assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index 5224ba3a10..b602a0e07b 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -21,15 +21,21 @@ @pytest.mark.parametrize("extra_shape", [(), (4,), (3, 2)], ids=str) -@pytest.mark.parametrize("log_normalizer_shape,info_vec_shape,precision_shape", [ - ((), (), ()), - ((5,), (), ()), - ((), (5,), ()), - ((), (), (5,)), - ((3, 1, 1), (1, 4, 1), (1, 1, 5)), -], ids=str) +@pytest.mark.parametrize( + "log_normalizer_shape,info_vec_shape,precision_shape", + [ + ((), (), ()), + ((5,), (), ()), + ((), (5,), ()), + ((), (), (5,)), + ((3, 1, 1), (1, 4, 1), (1, 1, 5)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) -def test_expand(extra_shape, log_normalizer_shape, info_vec_shape, precision_shape, dim): +def test_expand( + extra_shape, log_normalizer_shape, info_vec_shape, precision_shape, dim +): rank = dim + dim log_normalizer = torch.randn(log_normalizer_shape) info_vec = torch.randn(info_vec_shape + (dim,)) @@ -38,15 +44,20 @@ def test_expand(extra_shape, log_normalizer_shape, info_vec_shape, precision_sha gaussian = Gaussian(log_normalizer, info_vec, precision) expected_shape = extra_shape + broadcast_shape( - log_normalizer_shape, info_vec_shape, precision_shape) + log_normalizer_shape, info_vec_shape, precision_shape + ) actual = gaussian.expand(expected_shape) assert actual.batch_shape == expected_shape -@pytest.mark.parametrize("old_shape,new_shape", [ - ((6,), (3, 2)), - ((5, 6), (5, 3, 2)), -], ids=str) +@pytest.mark.parametrize( + "old_shape,new_shape", + [ + ((6,), (3, 2)), + ((5, 6), (5, 3, 2)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) def test_reshape(old_shape, new_shape, dim): gaussian = random_gaussian(old_shape, dim) @@ -60,11 +71,15 @@ def test_reshape(old_shape, new_shape, dim): assert_close_gaussian(g, gaussian) -@pytest.mark.parametrize("shape,cat_dim,split", [ - ((4, 7, 6), -1, (2, 1, 3)), - ((4, 7, 6), -2, (1, 1, 2, 3)), - ((4, 7, 6), 1, (1, 1, 2, 3)), -], ids=str) +@pytest.mark.parametrize( + "shape,cat_dim,split", + [ + ((4, 7, 6), -1, (2, 1, 3)), + ((4, 7, 6), -2, (1, 1, 2, 3)), + ((4, 7, 6), 1, (1, 1, 2, 3)), + ], + ids=str, +) @pytest.mark.parametrize("dim", [1, 2, 3]) def test_cat(shape, cat_dim, split, dim): assert sum(split) == shape[cat_dim] @@ -74,11 +89,11 @@ def test_cat(shape, cat_dim, split, dim): for size in split: beg, end = end, end + size if cat_dim == -1: - part = gaussian[..., beg: end] + part = gaussian[..., beg:end] elif cat_dim == -2: - part = gaussian[..., beg: end, :] + part = gaussian[..., 
beg:end, :] elif cat_dim == 1: - part = gaussian[:, beg: end] + part = gaussian[:, beg:end] else: raise ValueError parts.append(part) @@ -107,7 +122,9 @@ def test_add(shape, dim): x = random_gaussian(shape, dim) y = random_gaussian(shape, dim) value = torch.randn(dim) - assert_close((x + y).log_density(value), x.log_density(value) + y.log_density(value)) + assert_close( + (x + y).log_density(value), x.log_density(value) + y.log_density(value) + ) @pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) @@ -162,10 +179,8 @@ def test_marginalize_shape(batch_shape, left, right): def test_marginalize(batch_shape, left, right): dim = left + right g = random_gaussian(batch_shape, dim) - assert_close(g.marginalize(left=left).event_logsumexp(), - g.event_logsumexp()) - assert_close(g.marginalize(right=right).event_logsumexp(), - g.event_logsumexp()) + assert_close(g.marginalize(left=left).event_logsumexp(), g.event_logsumexp()) + assert_close(g.marginalize(right=right).event_logsumexp(), g.event_logsumexp()) @pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) @@ -176,8 +191,9 @@ def test_marginalize_condition(sample_shape, batch_shape, left, right): dim = left + right g = random_gaussian(batch_shape, dim) x = torch.randn(sample_shape + (1,) * len(batch_shape) + (right,)) - assert_close(g.marginalize(left=left).log_density(x), - g.condition(x).event_logsumexp()) + assert_close( + g.marginalize(left=left).log_density(x), g.condition(x).event_logsumexp() + ) @pytest.mark.parametrize("sample_shape", [(), (4,), (3, 2)], ids=str) @@ -217,8 +233,13 @@ def test_logsumexp(batch_shape, dim): num_samples = 200000 scale = 10 - samples = torch.rand((num_samples,) + (1,) * len(batch_shape) + (dim,)) * scale - scale / 2 - expected = gaussian.log_density(samples).logsumexp(0) + math.log(scale ** dim / num_samples) + samples = ( + torch.rand((num_samples,) + (1,) * len(batch_shape) + (dim,)) * scale + - scale / 2 + ) + expected = gaussian.log_density(samples).logsumexp(0) + math.log( + scale ** dim / num_samples + ) actual = gaussian.event_logsumexp() assert_close(actual, expected, atol=0.05, rtol=0.05) @@ -303,34 +324,47 @@ def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim): mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc, Mx_cov + y_mvn.covariance_matrix) expected = mvn_to_gaussian(mvn) - actual = gaussian_tensordot(mvn_to_gaussian(x_mvn), - matrix_and_mvn_to_gaussian(matrix, y_mvn), dims=x_dim) + actual = gaussian_tensordot( + mvn_to_gaussian(x_mvn), matrix_and_mvn_to_gaussian(matrix, y_mvn), dims=x_dim + ) assert_close_gaussian(expected, actual) -@pytest.mark.parametrize("x_batch_shape,y_batch_shape", [ - ((), ()), - ((3,), ()), - ((), (3,)), - ((2, 1), (3,)), - ((2, 3), (2, 3,)), -], ids=str) -@pytest.mark.parametrize("x_dim,y_dim,dot_dims", [ - (0, 0, 0), - (0, 2, 0), - (1, 0, 0), - (2, 1, 0), - (3, 3, 3), - (3, 2, 1), - (3, 2, 2), - (5, 4, 2), -], ids=str) -@pytest.mark.parametrize("x_rank,y_rank", [ - (1, 1), (4, 1), (1, 4), (4, 4) -], ids=str) -def test_gaussian_tensordot(dot_dims, - x_batch_shape, x_dim, x_rank, - y_batch_shape, y_dim, y_rank): +@pytest.mark.parametrize( + "x_batch_shape,y_batch_shape", + [ + ((), ()), + ((3,), ()), + ((), (3,)), + ((2, 1), (3,)), + ( + (2, 3), + ( + 2, + 3, + ), + ), + ], + ids=str, +) +@pytest.mark.parametrize( + "x_dim,y_dim,dot_dims", + [ + (0, 0, 0), + (0, 2, 0), + (1, 0, 0), + (2, 1, 0), + (3, 3, 3), + (3, 2, 1), + (3, 2, 2), + (5, 4, 2), + ], + ids=str, +) +@pytest.mark.parametrize("x_rank,y_rank", [(1, 
1), (4, 1), (1, 4), (4, 4)], ids=str) +def test_gaussian_tensordot( + dot_dims, x_batch_shape, x_dim, x_rank, y_batch_shape, y_dim, y_rank +): x_rank = min(x_rank, x_dim) y_rank = min(y_rank, y_dim) x = random_gaussian(x_batch_shape, x_dim, x_rank) @@ -354,9 +388,15 @@ def test_gaussian_tensordot(dot_dims, precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision, (na, 0, na, 0)) info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0)) covariance = torch.inverse(precision) - loc = covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1) if info_vec.size(-1) > 0 else info_vec + loc = ( + covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1) + if info_vec.size(-1) > 0 + else info_vec + ) z_covariance = torch.inverse(z.precision) - z_loc = z_covariance.matmul(z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0),))).sum(-1) + z_loc = z_covariance.matmul( + z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0),)) + ).sum(-1) assert_close(loc[..., :na], z_loc[..., :na]) assert_close(loc[..., x_dim:], z_loc[..., na:]) assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na]) @@ -377,4 +417,4 @@ def test_gaussian_tensordot(dot_dims, actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(),))) # TODO(fehiepsi): find some condition to make this test stable, so we can compare large value # log densities. - assert_close(actual.clamp(max=10.), expect.clamp(max=10.), atol=0.1, rtol=0.1) + assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1) diff --git a/tests/ops/test_indexing.py b/tests/ops/test_indexing.py index 4bf76a791a..133158f449 100644 --- a/tests/ops/test_indexing.py +++ b/tests/ops/test_indexing.py @@ -25,92 +25,92 @@ def z(*args): SHAPE_EXAMPLES = [ - ('Vindex(z(()))[...]', ()), - ('Vindex(z(2))[...]', (2,)), - ('Vindex(z(2))[...,0]', ()), - ('Vindex(z(2))[...,:]', (2,)), - ('Vindex(z(2))[...,z(3)]', (3,)), - ('Vindex(z(2))[0]', ()), - ('Vindex(z(2))[:]', (2,)), - ('Vindex(z(2))[z(3)]', (3,)), - ('Vindex(z(2,3))[...]', (2, 3)), - ('Vindex(z(2,3))[...,0]', (2,)), - ('Vindex(z(2,3))[...,:]', (2, 3)), - ('Vindex(z(2,3))[...,z(2)]', (2,)), - ('Vindex(z(2,3))[...,z(4,1)]', (4, 2)), - ('Vindex(z(2,3))[...,0,0]', ()), - ('Vindex(z(2,3))[...,0,:]', (3,)), - ('Vindex(z(2,3))[...,0,z(4)]', (4,)), - ('Vindex(z(2,3))[...,:,0]', (2,)), - ('Vindex(z(2,3))[...,:,:]', (2, 3)), - ('Vindex(z(2,3))[...,:,z(4)]', (4, 2)), - ('Vindex(z(2,3))[...,z(4),0]', (4,)), - ('Vindex(z(2,3))[...,z(4),:]', (4, 3)), - ('Vindex(z(2,3))[...,z(4),z(4)]', (4,)), - ('Vindex(z(2,3))[...,z(5,1),z(4)]', (5, 4)), - ('Vindex(z(2,3))[...,z(4),z(5,1)]', (5, 4)), - ('Vindex(z(2,3))[0,0]', ()), - ('Vindex(z(2,3))[0,:]', (3,)), - ('Vindex(z(2,3))[0,z(4)]', (4,)), - ('Vindex(z(2,3))[:,0]', (2,)), - ('Vindex(z(2,3))[:,:]', (2, 3)), - ('Vindex(z(2,3))[:,z(4)]', (4, 2)), - ('Vindex(z(2,3))[z(4),0]', (4,)), - ('Vindex(z(2,3))[z(4),:]', (4, 3)), - ('Vindex(z(2,3))[z(4)]', (4, 3)), - ('Vindex(z(2,3))[z(4),z(4)]', (4,)), - ('Vindex(z(2,3))[z(5,1),z(4)]', (5, 4)), - ('Vindex(z(2,3))[z(4),z(5,1)]', (5, 4)), - ('Vindex(z(2,3,4))[...]', (2, 3, 4)), - ('Vindex(z(2,3,4))[...,z(3)]', (2, 3)), - ('Vindex(z(2,3,4))[...,z(2,1)]', (2, 3)), - ('Vindex(z(2,3,4))[...,z(2,3)]', (2, 3)), - ('Vindex(z(2,3,4))[...,z(5,1,1)]', (5, 2, 3)), - ('Vindex(z(2,3,4))[...,z(2),0]', (2,)), - ('Vindex(z(2,3,4))[...,z(5,1),0]', (5, 2)), - ('Vindex(z(2,3,4))[...,z(2),:]', (2, 4)), - ('Vindex(z(2,3,4))[...,z(5,1),:]', (5, 2, 4)), - ('Vindex(z(2,3,4))[...,z(5),0,0]', (5,)), - ('Vindex(z(2,3,4))[...,z(5),0,:]', (5, 4)), - 
('Vindex(z(2,3,4))[...,z(5),:,0]', (5, 3)), - ('Vindex(z(2,3,4))[...,z(5),:,:]', (5, 3, 4)), - ('Vindex(z(2,3,4))[0,0,z(5)]', (5,)), - ('Vindex(z(2,3,4))[0,:,z(5)]', (5, 3)), - ('Vindex(z(2,3,4))[0,z(5),0]', (5,)), - ('Vindex(z(2,3,4))[0,z(5),:]', (5, 4)), - ('Vindex(z(2,3,4))[0,z(5),z(5)]', (5,)), - ('Vindex(z(2,3,4))[0,z(5,1),z(6)]', (5, 6)), - ('Vindex(z(2,3,4))[0,z(6),z(5,1)]', (5, 6)), - ('Vindex(z(2,3,4))[:,0,z(5)]', (5, 2)), - ('Vindex(z(2,3,4))[:,:,z(5)]', (5, 2, 3)), - ('Vindex(z(2,3,4))[:,z(5),0]', (5, 2)), - ('Vindex(z(2,3,4))[:,z(5),:]', (5, 2, 4)), - ('Vindex(z(2,3,4))[:,z(5),z(5)]', (5, 2)), - ('Vindex(z(2,3,4))[:,z(5,1),z(6)]', (5, 6, 2)), - ('Vindex(z(2,3,4))[:,z(6),z(5,1)]', (5, 6, 2)), - ('Vindex(z(2,3,4))[z(5),0,0]', (5,)), - ('Vindex(z(2,3,4))[z(5),0,:]', (5, 4)), - ('Vindex(z(2,3,4))[z(5),:,0]', (5, 3)), - ('Vindex(z(2,3,4))[z(5),:,:]', (5, 3, 4)), - ('Vindex(z(2,3,4))[z(5),0,z(5)]', (5,)), - ('Vindex(z(2,3,4))[z(5,1),0,z(6)]', (5, 6)), - ('Vindex(z(2,3,4))[z(6),0,z(5,1)]', (5, 6)), - ('Vindex(z(2,3,4))[z(5),:,z(5)]', (5, 3)), - ('Vindex(z(2,3,4))[z(5,1),:,z(6)]', (5, 6, 3)), - ('Vindex(z(2,3,4))[z(6),:,z(5,1)]', (5, 6, 3)), + ("Vindex(z(()))[...]", ()), + ("Vindex(z(2))[...]", (2,)), + ("Vindex(z(2))[...,0]", ()), + ("Vindex(z(2))[...,:]", (2,)), + ("Vindex(z(2))[...,z(3)]", (3,)), + ("Vindex(z(2))[0]", ()), + ("Vindex(z(2))[:]", (2,)), + ("Vindex(z(2))[z(3)]", (3,)), + ("Vindex(z(2,3))[...]", (2, 3)), + ("Vindex(z(2,3))[...,0]", (2,)), + ("Vindex(z(2,3))[...,:]", (2, 3)), + ("Vindex(z(2,3))[...,z(2)]", (2,)), + ("Vindex(z(2,3))[...,z(4,1)]", (4, 2)), + ("Vindex(z(2,3))[...,0,0]", ()), + ("Vindex(z(2,3))[...,0,:]", (3,)), + ("Vindex(z(2,3))[...,0,z(4)]", (4,)), + ("Vindex(z(2,3))[...,:,0]", (2,)), + ("Vindex(z(2,3))[...,:,:]", (2, 3)), + ("Vindex(z(2,3))[...,:,z(4)]", (4, 2)), + ("Vindex(z(2,3))[...,z(4),0]", (4,)), + ("Vindex(z(2,3))[...,z(4),:]", (4, 3)), + ("Vindex(z(2,3))[...,z(4),z(4)]", (4,)), + ("Vindex(z(2,3))[...,z(5,1),z(4)]", (5, 4)), + ("Vindex(z(2,3))[...,z(4),z(5,1)]", (5, 4)), + ("Vindex(z(2,3))[0,0]", ()), + ("Vindex(z(2,3))[0,:]", (3,)), + ("Vindex(z(2,3))[0,z(4)]", (4,)), + ("Vindex(z(2,3))[:,0]", (2,)), + ("Vindex(z(2,3))[:,:]", (2, 3)), + ("Vindex(z(2,3))[:,z(4)]", (4, 2)), + ("Vindex(z(2,3))[z(4),0]", (4,)), + ("Vindex(z(2,3))[z(4),:]", (4, 3)), + ("Vindex(z(2,3))[z(4)]", (4, 3)), + ("Vindex(z(2,3))[z(4),z(4)]", (4,)), + ("Vindex(z(2,3))[z(5,1),z(4)]", (5, 4)), + ("Vindex(z(2,3))[z(4),z(5,1)]", (5, 4)), + ("Vindex(z(2,3,4))[...]", (2, 3, 4)), + ("Vindex(z(2,3,4))[...,z(3)]", (2, 3)), + ("Vindex(z(2,3,4))[...,z(2,1)]", (2, 3)), + ("Vindex(z(2,3,4))[...,z(2,3)]", (2, 3)), + ("Vindex(z(2,3,4))[...,z(5,1,1)]", (5, 2, 3)), + ("Vindex(z(2,3,4))[...,z(2),0]", (2,)), + ("Vindex(z(2,3,4))[...,z(5,1),0]", (5, 2)), + ("Vindex(z(2,3,4))[...,z(2),:]", (2, 4)), + ("Vindex(z(2,3,4))[...,z(5,1),:]", (5, 2, 4)), + ("Vindex(z(2,3,4))[...,z(5),0,0]", (5,)), + ("Vindex(z(2,3,4))[...,z(5),0,:]", (5, 4)), + ("Vindex(z(2,3,4))[...,z(5),:,0]", (5, 3)), + ("Vindex(z(2,3,4))[...,z(5),:,:]", (5, 3, 4)), + ("Vindex(z(2,3,4))[0,0,z(5)]", (5,)), + ("Vindex(z(2,3,4))[0,:,z(5)]", (5, 3)), + ("Vindex(z(2,3,4))[0,z(5),0]", (5,)), + ("Vindex(z(2,3,4))[0,z(5),:]", (5, 4)), + ("Vindex(z(2,3,4))[0,z(5),z(5)]", (5,)), + ("Vindex(z(2,3,4))[0,z(5,1),z(6)]", (5, 6)), + ("Vindex(z(2,3,4))[0,z(6),z(5,1)]", (5, 6)), + ("Vindex(z(2,3,4))[:,0,z(5)]", (5, 2)), + ("Vindex(z(2,3,4))[:,:,z(5)]", (5, 2, 3)), + ("Vindex(z(2,3,4))[:,z(5),0]", (5, 2)), + ("Vindex(z(2,3,4))[:,z(5),:]", (5, 2, 4)), + 
("Vindex(z(2,3,4))[:,z(5),z(5)]", (5, 2)), + ("Vindex(z(2,3,4))[:,z(5,1),z(6)]", (5, 6, 2)), + ("Vindex(z(2,3,4))[:,z(6),z(5,1)]", (5, 6, 2)), + ("Vindex(z(2,3,4))[z(5),0,0]", (5,)), + ("Vindex(z(2,3,4))[z(5),0,:]", (5, 4)), + ("Vindex(z(2,3,4))[z(5),:,0]", (5, 3)), + ("Vindex(z(2,3,4))[z(5),:,:]", (5, 3, 4)), + ("Vindex(z(2,3,4))[z(5),0,z(5)]", (5,)), + ("Vindex(z(2,3,4))[z(5,1),0,z(6)]", (5, 6)), + ("Vindex(z(2,3,4))[z(6),0,z(5,1)]", (5, 6)), + ("Vindex(z(2,3,4))[z(5),:,z(5)]", (5, 3)), + ("Vindex(z(2,3,4))[z(5,1),:,z(6)]", (5, 6, 3)), + ("Vindex(z(2,3,4))[z(6),:,z(5,1)]", (5, 6, 3)), ] -@pytest.mark.parametrize('expression,expected_shape', SHAPE_EXAMPLES, ids=str) +@pytest.mark.parametrize("expression,expected_shape", SHAPE_EXAMPLES, ids=str) def test_shape(expression, expected_shape): result = eval(expression) assert result.shape == expected_shape -@pytest.mark.parametrize('event_shape', [(), (7,)], ids=str) -@pytest.mark.parametrize('j_shape', [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) -@pytest.mark.parametrize('i_shape', [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) -@pytest.mark.parametrize('x_shape', [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) +@pytest.mark.parametrize("event_shape", [(), (7,)], ids=str) +@pytest.mark.parametrize("j_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) +@pytest.mark.parametrize("i_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) +@pytest.mark.parametrize("x_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) def test_value(x_shape, i_shape, j_shape, event_shape): x = torch.rand(x_shape + (5, 6) + event_shape) i = dist.Categorical(torch.ones(5)).sample(i_shape) @@ -125,34 +125,45 @@ def test_value(x_shape, i_shape, j_shape, event_shape): i = i.expand(shape) j = j.expand(shape) expected = x.new_empty(shape + event_shape) - for ind in (itertools.product(*map(range, shape)) if shape else [()]): + for ind in itertools.product(*map(range, shape)) if shape else [()]: expected[ind] = x[ind + (i[ind].item(), j[ind].item())] assert_equal(actual, expected) -@pytest.mark.parametrize('prev_enum_dim,curr_enum_dim', [(-3, -4), (-4, -5), (-5, -3)]) +@pytest.mark.parametrize("prev_enum_dim,curr_enum_dim", [(-3, -4), (-4, -5), (-5, -3)]) def test_hmm_example(prev_enum_dim, curr_enum_dim): hidden_dim = 8 probs_x = torch.rand(hidden_dim, hidden_dim, hidden_dim) x_prev = torch.arange(hidden_dim).reshape((-1,) + (1,) * (-1 - prev_enum_dim)) x_curr = torch.arange(hidden_dim).reshape((-1,) + (1,) * (-1 - curr_enum_dim)) - expected = probs_x[x_prev.unsqueeze(-1), x_curr.unsqueeze(-1), torch.arange(hidden_dim)] + expected = probs_x[ + x_prev.unsqueeze(-1), x_curr.unsqueeze(-1), torch.arange(hidden_dim) + ] actual = Vindex(probs_x)[x_prev, x_curr, :] assert_equal(actual, expected) -@pytest.mark.parametrize("args,expected", [ - (0, 0), - (1, 1), - (None, None), - (slice(1, 2, 3), slice(1, 2, 3)), - (Ellipsis, Ellipsis), - ((0, 1, None, slice(1, 2, 3), Ellipsis), (0, 1, None, slice(1, 2, 3), Ellipsis)), - (((0, 1), (None, slice(1, 2, 3)), Ellipsis), (0, 1, None, slice(1, 2, 3), Ellipsis)), - ((Ellipsis, None), (Ellipsis, None)), - ((Ellipsis, (Ellipsis, None)), (Ellipsis, None)), - ((Ellipsis, (Ellipsis, None, None)), (Ellipsis, None, None)), -]) +@pytest.mark.parametrize( + "args,expected", + [ + (0, 0), + (1, 1), + (None, None), + (slice(1, 2, 3), slice(1, 2, 3)), + (Ellipsis, Ellipsis), + ( + (0, 1, None, slice(1, 2, 3), Ellipsis), + (0, 1, None, slice(1, 2, 3), Ellipsis), + ), + ( + ((0, 1), (None, slice(1, 2, 3)), Ellipsis), + (0, 
1, None, slice(1, 2, 3), Ellipsis), + ), + ((Ellipsis, None), (Ellipsis, None)), + ((Ellipsis, (Ellipsis, None)), (Ellipsis, None)), + ((Ellipsis, (Ellipsis, None, None)), (Ellipsis, None, None)), + ], +) def test_index(args, expected): assert Index(tensor_mock)[args] == expected diff --git a/tests/ops/test_integrator.py b/tests/ops/test_integrator.py index 0f0f85d6cb..db1053d86c 100644 --- a/tests/ops/test_integrator.py +++ b/tests/ops/test_integrator.py @@ -16,8 +16,10 @@ TEST_EXAMPLES = [] EXAMPLE_IDS = [] -ModelArgs = namedtuple('model_args', ['step_size', 'num_steps', 'q_i', 'p_i', 'q_f', 'p_f', 'prec']) -Example = namedtuple('test_case', ['model', 'args']) +ModelArgs = namedtuple( + "model_args", ["step_size", "num_steps", "q_i", "p_i", "q_f", "p_f", "prec"] +) +Example = namedtuple("test_case", ["model", "args"]) def register_model(init_args): @@ -25,25 +27,29 @@ def register_model(init_args): Register the model along with each of the model arguments as test examples. """ + def register_fn(model): for args in init_args: test_example = Example(model, args) TEST_EXAMPLES.append(test_example) EXAMPLE_IDS.append(model.__name__) + return register_fn -@register_model([ - ModelArgs( - step_size=0.01, - num_steps=100, - q_i={'x': torch.tensor([0.0])}, - p_i={'x': torch.tensor([1.0])}, - q_f={'x': torch.sin(torch.tensor([1.0]))}, - p_f={'x': torch.cos(torch.tensor([1.0]))}, - prec=1e-4 - ) -]) +@register_model( + [ + ModelArgs( + step_size=0.01, + num_steps=100, + q_i={"x": torch.tensor([0.0])}, + p_i={"x": torch.tensor([1.0])}, + q_f={"x": torch.sin(torch.tensor([1.0]))}, + p_f={"x": torch.cos(torch.tensor([1.0]))}, + prec=1e-4, + ) + ] +) class HarmonicOscillator: @staticmethod def kinetic_grad(p): @@ -51,24 +57,26 @@ def kinetic_grad(p): @staticmethod def energy(q, p): - return 0.5 * p['x'] ** 2 + 0.5 * q['x'] ** 2 + return 0.5 * p["x"] ** 2 + 0.5 * q["x"] ** 2 @staticmethod def potential_fn(q): - return 0.5 * q['x'] ** 2 - - -@register_model([ - ModelArgs( - step_size=0.01, - num_steps=628, - q_i={'x': torch.tensor([1.0]), 'y': torch.tensor([0.0])}, - p_i={'x': torch.tensor([0.0]), 'y': torch.tensor([1.0])}, - q_f={'x': torch.tensor([1.0]), 'y': torch.tensor([0.0])}, - p_f={'x': torch.tensor([0.0]), 'y': torch.tensor([1.0])}, - prec=5.0e-3 - ) -]) + return 0.5 * q["x"] ** 2 + + +@register_model( + [ + ModelArgs( + step_size=0.01, + num_steps=628, + q_i={"x": torch.tensor([1.0]), "y": torch.tensor([0.0])}, + p_i={"x": torch.tensor([0.0]), "y": torch.tensor([1.0])}, + q_f={"x": torch.tensor([1.0]), "y": torch.tensor([0.0])}, + p_f={"x": torch.tensor([0.0]), "y": torch.tensor([1.0])}, + prec=5.0e-3, + ) + ] +) class CircularPlanetaryMotion: @staticmethod def kinetic_grad(p): @@ -76,25 +84,30 @@ def kinetic_grad(p): @staticmethod def energy(q, p): - return 0.5 * p['x'] ** 2 + 0.5 * p['y'] ** 2 - \ - 1.0 / torch.pow(q['x'] ** 2 + q['y'] ** 2, 0.5) + return ( + 0.5 * p["x"] ** 2 + + 0.5 * p["y"] ** 2 + - 1.0 / torch.pow(q["x"] ** 2 + q["y"] ** 2, 0.5) + ) @staticmethod def potential_fn(q): - return - 1.0 / torch.pow(q['x'] ** 2 + q['y'] ** 2, 0.5) - - -@register_model([ - ModelArgs( - step_size=0.1, - num_steps=1810, - q_i={'x': torch.tensor([0.02])}, - p_i={'x': torch.tensor([0.0])}, - q_f={'x': torch.tensor([-0.02])}, - p_f={'x': torch.tensor([0.0])}, - prec=1.0e-4 - ) -]) + return -1.0 / torch.pow(q["x"] ** 2 + q["y"] ** 2, 0.5) + + +@register_model( + [ + ModelArgs( + step_size=0.1, + num_steps=1810, + q_i={"x": torch.tensor([0.02])}, + p_i={"x": torch.tensor([0.0])}, + q_f={"x": 
torch.tensor([-0.02])}, + p_f={"x": torch.tensor([0.0])}, + prec=1.0e-4, + ) + ] +) class QuarticOscillator: @staticmethod def kinetic_grad(p): @@ -102,37 +115,41 @@ def kinetic_grad(p): @staticmethod def energy(q, p): - return 0.5 * p['x'] ** 2 + 0.25 * torch.pow(q['x'], 4.0) + return 0.5 * p["x"] ** 2 + 0.25 * torch.pow(q["x"], 4.0) @staticmethod def potential_fn(q): - return 0.25 * torch.pow(q['x'], 4.0) + return 0.25 * torch.pow(q["x"], 4.0) -@pytest.mark.parametrize('example', TEST_EXAMPLES, ids=EXAMPLE_IDS) +@pytest.mark.parametrize("example", TEST_EXAMPLES, ids=EXAMPLE_IDS) def test_trajectory(example): model, args = example - q_f, p_f, _, _ = velocity_verlet(args.q_i, - args.p_i, - model.potential_fn, - model.kinetic_grad, - args.step_size, - args.num_steps) + q_f, p_f, _, _ = velocity_verlet( + args.q_i, + args.p_i, + model.potential_fn, + model.kinetic_grad, + args.step_size, + args.num_steps, + ) logger.info("initial q: {}".format(args.q_i)) logger.info("final q: {}".format(q_f)) assert_equal(q_f, args.q_f, args.prec) assert_equal(p_f, args.p_f, args.prec) -@pytest.mark.parametrize('example', TEST_EXAMPLES, ids=EXAMPLE_IDS) +@pytest.mark.parametrize("example", TEST_EXAMPLES, ids=EXAMPLE_IDS) def test_energy_conservation(example): model, args = example - q_f, p_f, _, _ = velocity_verlet(args.q_i, - args.p_i, - model.potential_fn, - model.kinetic_grad, - args.step_size, - args.num_steps) + q_f, p_f, _, _ = velocity_verlet( + args.q_i, + args.p_i, + model.potential_fn, + model.kinetic_grad, + args.step_size, + args.num_steps, + ) energy_initial = model.energy(args.q_i, args.p_i) energy_final = model.energy(q_f, p_f) logger.info("initial energy: {}".format(energy_initial.item())) @@ -140,20 +157,24 @@ def test_energy_conservation(example): assert_equal(energy_final, energy_initial) -@pytest.mark.parametrize('example', TEST_EXAMPLES, ids=EXAMPLE_IDS) +@pytest.mark.parametrize("example", TEST_EXAMPLES, ids=EXAMPLE_IDS) def test_time_reversibility(example): model, args = example - q_forward, p_forward, _, _ = velocity_verlet(args.q_i, - args.p_i, - model.potential_fn, - model.kinetic_grad, - args.step_size, - args.num_steps) + q_forward, p_forward, _, _ = velocity_verlet( + args.q_i, + args.p_i, + model.potential_fn, + model.kinetic_grad, + args.step_size, + args.num_steps, + ) p_reverse = {key: -val for key, val in p_forward.items()} - q_f, p_f, _, _ = velocity_verlet(q_forward, - p_reverse, - model.potential_fn, - model.kinetic_grad, - args.step_size, - args.num_steps) + q_f, p_f, _, _ = velocity_verlet( + q_forward, + p_reverse, + model.potential_fn, + model.kinetic_grad, + args.step_size, + args.num_steps, + ) assert_equal(q_f, args.q_i, 1e-5) diff --git a/tests/ops/test_jit.py b/tests/ops/test_jit.py index ce4cb6219f..45cb42f0c4 100644 --- a/tests/ops/test_jit.py +++ b/tests/ops/test_jit.py @@ -8,38 +8,35 @@ def test_varying_len_args(): - def fn(*args): return sum(args) jit_fn = pyro.ops.jit.trace(fn) examples = [ - [torch.tensor(1.)], - [torch.tensor(2.), torch.tensor(3.)], - [torch.tensor(4.), torch.tensor(5.), torch.tensor(6.)], + [torch.tensor(1.0)], + [torch.tensor(2.0), torch.tensor(3.0)], + [torch.tensor(4.0), torch.tensor(5.0), torch.tensor(6.0)], ] for args in examples: assert_equal(jit_fn(*args), fn(*args)) def test_varying_kwargs(): - - def fn(x, scale=1.): + def fn(x, scale=1.0): return x * scale jit_fn = pyro.ops.jit.trace(fn) - x = torch.tensor(1.) 
- for scale in [-1., 0., 1., 10.]: + x = torch.tensor(1.0) + for scale in [-1.0, 0.0, 1.0, 10.0]: assert_equal(jit_fn(x, scale=scale), fn(x, scale=scale)) def test_varying_unhashable_kwargs(): - def fn(x, config={}): - return x * config.get(scale, 1.) + return x * config.get(scale, 1.0) jit_fn = pyro.ops.jit.trace(fn) - x = torch.tensor(1.) - for scale in [-1., 0., 1., 10.]: - config = {'scale': scale} + x = torch.tensor(1.0) + for scale in [-1.0, 0.0, 1.0, 10.0]: + config = {"scale": scale} assert_equal(jit_fn(x, config=config), fn(x, config=config)) diff --git a/tests/ops/test_linalg.py b/tests/ops/test_linalg.py index 34452f0def..5b4497567d 100644 --- a/tests/ops/test_linalg.py +++ b/tests/ops/test_linalg.py @@ -8,14 +8,25 @@ from tests.common import assert_equal -@pytest.mark.parametrize("A", [ - torch.tensor([[17.]]), - torch.tensor([[1., 2.], [2., -3.]]), - torch.tensor([[1., 2, 0], [2, -2, 4], [0, 4, 5]]), - torch.tensor([[1., 2, 0, 7], [2, -2, 4, -1], [0, 4, 5, 8], [7, -1, 8, 1]]), - torch.tensor([[1., 2, 0, 7, 0], [2, -2, 4, -1, 2], [0, 4, 5, 8, -4], [7, -1, 8, 1, -3], [0, 2, -4, -3, -1]]), - torch.eye(40) - ]) +@pytest.mark.parametrize( + "A", + [ + torch.tensor([[17.0]]), + torch.tensor([[1.0, 2.0], [2.0, -3.0]]), + torch.tensor([[1.0, 2, 0], [2, -2, 4], [0, 4, 5]]), + torch.tensor([[1.0, 2, 0, 7], [2, -2, 4, -1], [0, 4, 5, 8], [7, -1, 8, 1]]), + torch.tensor( + [ + [1.0, 2, 0, 7, 0], + [2, -2, 4, -1, 2], + [0, 4, 5, 8, -4], + [7, -1, 8, 1, -3], + [0, 2, -4, -3, -1], + ] + ), + torch.eye(40), + ], +) @pytest.mark.parametrize("use_sym", [True, False]) def test_sym_rinverse(A, use_sym): d = A.shape[-1] diff --git a/tests/ops/test_newton.py b/tests/ops/test_newton.py index d264b3ae35..237cb9f4fe 100644 --- a/tests/ops/test_newton.py +++ b/tests/ops/test_newton.py @@ -24,15 +24,15 @@ def random_inside_unit_circle(shape, requires_grad=False): return x -@pytest.mark.parametrize('batch_shape', [(), (1,), (2,), (10,), (3, 2), (2, 3)]) -@pytest.mark.parametrize('trust_radius', [None, 2.0, 100.0]) -@pytest.mark.parametrize('dims', [1, 2, 3]) +@pytest.mark.parametrize("batch_shape", [(), (1,), (2,), (10,), (3, 2), (2, 3)]) +@pytest.mark.parametrize("trust_radius", [None, 2.0, 100.0]) +@pytest.mark.parametrize("dims", [1, 2, 3]) def test_newton_step(batch_shape, trust_radius, dims): batch_shape = torch.Size(batch_shape) mode = 0.5 * random_inside_unit_circle(batch_shape + (dims,), requires_grad=True) x = 0.5 * random_inside_unit_circle(batch_shape + (dims,), requires_grad=True) if trust_radius is not None: - assert trust_radius >= 2, '(x, mode) may be farther apart than trust_radius' + assert trust_radius >= 2, "(x, mode) may be farther apart than trust_radius" # create a quadratic loss function flat_x = x.reshape(-1, dims) @@ -51,30 +51,40 @@ def test_newton_step(batch_shape, trust_radius, dims): assert cov.shape == hessian.shape # check values - assert_equal(x_updated, mode, prec=1e-6, - msg='{} vs {}'.format(x_updated, mode)) + assert_equal(x_updated, mode, prec=1e-6, msg="{} vs {}".format(x_updated, mode)) flat_cov = cov.reshape(flat_hessian.shape) - assert_equal(flat_cov, flat_cov.transpose(-1, -2), - msg='covariance is not symmetric: {}'.format(flat_cov)) + assert_equal( + flat_cov, + flat_cov.transpose(-1, -2), + msg="covariance is not symmetric: {}".format(flat_cov), + ) actual_eye = torch.bmm(flat_cov, flat_hessian) expected_eye = torch.eye(dims).expand(actual_eye.shape) - assert_equal(actual_eye, expected_eye, prec=1e-4, - msg='bad covariance {}'.format(actual_eye)) + 
assert_equal( + actual_eye, expected_eye, prec=1e-4, msg="bad covariance {}".format(actual_eye) + ) # check gradients for i in itertools.product(*map(range, mode.shape)): expected_grad = torch.zeros(mode.shape) expected_grad[i] = 1 actual_grad = grad(x_updated[i], [mode], create_graph=True)[0] - assert_equal(actual_grad, expected_grad, prec=1e-5, msg='\n'.join([ - 'bad gradient at index {}'.format(i), - 'expected {}'.format(expected_grad), - 'actual {}'.format(actual_grad), - ])) - - -@pytest.mark.parametrize('trust_radius', [None, 0.1, 1.0, 10.0]) -@pytest.mark.parametrize('dims', [1, 2, 3]) + assert_equal( + actual_grad, + expected_grad, + prec=1e-5, + msg="\n".join( + [ + "bad gradient at index {}".format(i), + "expected {}".format(expected_grad), + "actual {}".format(actual_grad), + ] + ), + ) + + +@pytest.mark.parametrize("trust_radius", [None, 0.1, 1.0, 10.0]) +@pytest.mark.parametrize("dims", [1, 2, 3]) def test_newton_step_trust(trust_radius, dims): batch_size = 100 batch_shape = torch.Size((batch_size,)) @@ -96,13 +106,15 @@ def test_newton_step_trust(trust_radius, dims): # check values if trust_radius is None: - assert ((x - x_updated).pow(2).sum(-1) > 1.0).any(), 'test is too weak' + assert ((x - x_updated).pow(2).sum(-1) > 1.0).any(), "test is too weak" else: - assert ((x - x_updated).pow(2).sum(-1) <= 1e-8 + trust_radius**2).all(), 'trust region violated' + assert ( + (x - x_updated).pow(2).sum(-1) <= 1e-8 + trust_radius ** 2 + ).all(), "trust region violated" -@pytest.mark.parametrize('trust_radius', [None, 0.1, 1.0, 10.0]) -@pytest.mark.parametrize('dims', [1, 2, 3]) +@pytest.mark.parametrize("trust_radius", [None, 0.1, 1.0, 10.0]) +@pytest.mark.parametrize("dims", [1, 2, 3]) def test_newton_step_converges(trust_radius, dims): batch_size = 100 batch_shape = torch.Size((batch_size,)) @@ -124,6 +136,6 @@ def loss_fn(x): loss = loss_fn(x) x, cov = newton_step(loss, x, trust_radius=trust_radius) if ((x - mode).pow(2).sum(-1) < 1e-4).all(): - logger.debug('Newton iteration converged after {} steps'.format(2 + i)) + logger.debug("Newton iteration converged after {} steps".format(2 + i)) return - pytest.fail('Newton iteration did not converge') + pytest.fail("Newton iteration did not converge") diff --git a/tests/ops/test_packed.py b/tests/ops/test_packed.py index 2eeaa66985..56564d6aed 100644 --- a/tests/ops/test_packed.py +++ b/tests/ops/test_packed.py @@ -12,17 +12,17 @@ from tests.common import assert_equal EXAMPLE_DIMS = [ - ''.join(dims) + "".join(dims) for num_dims in range(5) - for dims in itertools.permutations('abcd'[:num_dims]) + for dims in itertools.permutations("abcd"[:num_dims]) ] -@pytest.mark.parametrize('dims', EXAMPLE_DIMS) +@pytest.mark.parametrize("dims", EXAMPLE_DIMS) def test_unpack_pack(dims): dim_to_symbol = {} symbol_to_dim = {} - for symbol, dim in zip('abcd', range(-1, -5, -1)): + for symbol, dim in zip("abcd", range(-1, -5, -1)): dim_to_symbol[dim] = symbol symbol_to_dim[symbol] = dim shape = tuple(range(2, 2 + len(dims))) @@ -32,7 +32,7 @@ def test_unpack_pack(dims): unpack_pack_x = packed.unpack(pack_x, symbol_to_dim) assert_equal(unpack_pack_x, x) - sort_dims = ''.join(sorted(dims)) + sort_dims = "".join(sorted(dims)) if sort_dims != pack_x._pyro_dims: sort_pack_x = pack_x.permute(*(pack_x._pyro_dims.index(d) for d in sort_dims)) sort_pack_x._pyro_dims = sort_dims @@ -56,13 +56,13 @@ def make_inputs(shapes, num_numbers=0): inputs.append(random.random()) dim_to_symbol = {} symbol_to_dim = {} - for dim, symbol in zip(range(-num_symbols, 0), 
'abcdefghijklmnopqrstuvwxyz'): + for dim, symbol in zip(range(-num_symbols, 0), "abcdefghijklmnopqrstuvwxyz"): dim_to_symbol[dim] = symbol symbol_to_dim[symbol] = dim return inputs, dim_to_symbol, symbol_to_dim -@pytest.mark.parametrize('shapes', EXAMPLE_SHAPES) +@pytest.mark.parametrize("shapes", EXAMPLE_SHAPES) def test_broadcast_all(shapes): inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes) packed_inputs = [packed.pack(x, dim_to_symbol) for x in inputs] diff --git a/tests/ops/test_special.py b/tests/ops/test_special.py index 2387be8dc9..426297d465 100644 --- a/tests/ops/test_special.py +++ b/tests/ops/test_special.py @@ -20,14 +20,26 @@ def test_safe_log(): assert_equal(grad(actual.sum(), [x])[0], grad(expected.sum(), [x])[0]) # Test gradients. - x = torch.tensor(0., requires_grad=True) + x = torch.tensor(0.0, requires_grad=True) assert not torch.isfinite(grad(x.log(), [x])[0]) assert torch.isfinite(grad(safe_log(x), [x])[0]) -@pytest.mark.parametrize("tol", [ - 1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.1, 1., -]) +@pytest.mark.parametrize( + "tol", + [ + 1e-8, + 1e-6, + 1e-4, + 1e-2, + 0.02, + 0.05, + 0.1, + 0.2, + 0.1, + 1.0, + ], +) def test_log_beta_stirling(tol): x = torch.logspace(-5, 5, 200) y = x.unsqueeze(-1) @@ -39,11 +51,23 @@ def test_log_beta_stirling(tol): assert (expected < actual + tol).all() -@pytest.mark.parametrize("tol", [ - 1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.1, 1., -]) +@pytest.mark.parametrize( + "tol", + [ + 1e-8, + 1e-6, + 1e-4, + 1e-2, + 0.02, + 0.05, + 0.1, + 0.2, + 0.1, + 1.0, + ], +) def test_log_binomial_stirling(tol): - k = torch.arange(200.) + k = torch.arange(200.0) n_minus_k = k.unsqueeze(-1) n = k + n_minus_k @@ -54,8 +78,8 @@ def test_log_binomial_stirling(tol): assert (actual - expected).abs().max() < tol -@pytest.mark.parametrize('order', [0, 1, 5, 10, 20]) -@pytest.mark.parametrize('value', [0.01, .1, 1., 10., 100.]) +@pytest.mark.parametrize("order", [0, 1, 5, 10, 20]) +@pytest.mark.parametrize("value", [0.01, 0.1, 1.0, 10.0, 100.0]) def test_log_I1(order, value): value = tensor([value]) expected = torch.tensor([iv(i, value.numpy()) for i in range(order + 1)]).log() @@ -64,8 +88,8 @@ def test_log_I1(order, value): def test_log_I1_shapes(): - assert_equal(log_I1(10, tensor(.6)).shape, torch.Size([11, 1])) - assert_equal(log_I1(10, tensor([.6])).shape, torch.Size([11, 1])) - assert_equal(log_I1(10, tensor([[.6]])).shape, torch.Size([11, 1, 1])) - assert_equal(log_I1(10, tensor([.6, .2])).shape, torch.Size([11, 2])) - assert_equal(log_I1(0, tensor(.6)).shape, torch.Size((1, 1))) + assert_equal(log_I1(10, tensor(0.6)).shape, torch.Size([11, 1])) + assert_equal(log_I1(10, tensor([0.6])).shape, torch.Size([11, 1])) + assert_equal(log_I1(10, tensor([[0.6]])).shape, torch.Size([11, 1, 1])) + assert_equal(log_I1(10, tensor([0.6, 0.2])).shape, torch.Size([11, 2])) + assert_equal(log_I1(0, tensor(0.6)).shape, torch.Size((1, 1))) diff --git a/tests/ops/test_ssm_gp.py b/tests/ops/test_ssm_gp.py index 22f2b214d3..1a792f80d5 100644 --- a/tests/ops/test_ssm_gp.py +++ b/tests/ops/test_ssm_gp.py @@ -8,10 +8,12 @@ from tests.common import assert_equal -@pytest.mark.parametrize('num_gps', [1, 2, 3]) -@pytest.mark.parametrize('nu', [0.5, 1.5, 2.5]) +@pytest.mark.parametrize("num_gps", [1, 2, 3]) +@pytest.mark.parametrize("nu", [0.5, 1.5, 2.5]) def test_matern_kernel(num_gps, nu): - mk = MaternKernel(nu=nu, num_gps=num_gps, length_scale_init=0.1 + torch.rand(num_gps)) + mk = MaternKernel( + nu=nu, num_gps=num_gps, length_scale_init=0.1 + 
torch.rand(num_gps) + ) dt = torch.rand(1).item() forward = mk.transition_matrix(dt) @@ -19,7 +21,9 @@ def test_matern_kernel(num_gps, nu): forward_backward = torch.matmul(forward, backward) # going forward dt in time and then backward dt in time should bring us back to the identity - eye = torch.eye(mk.state_dim).unsqueeze(0).expand(num_gps, mk.state_dim, mk.state_dim) + eye = ( + torch.eye(mk.state_dim).unsqueeze(0).expand(num_gps, mk.state_dim, mk.state_dim) + ) assert_equal(forward_backward, eye) # let's just check that these are PSD diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index 6d6e3b0c9c..1027f18d25 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -24,7 +24,7 @@ from tests.common import assert_close, assert_equal, xfail_if_not_implemented -@pytest.mark.parametrize('replacement', [True, False]) +@pytest.mark.parametrize("replacement", [True, False]) def test_resample(replacement): x = torch.empty(10000, 2) x[:, 0].normal_(3, 4) @@ -38,21 +38,23 @@ def test_resample(replacement): assert_equal(torch.unique(z.reshape(-1)).numel(), z.numel()) assert_equal(y.shape, torch.Size([num_samples, 2])) assert_equal(z.shape, torch.Size([2, num_samples])) - assert_equal(y.mean(dim=0), torch.tensor([3., 5.]), prec=0.2) - assert_equal(z.mean(dim=1), torch.tensor([3., 5.]), prec=0.2) - assert_equal(y.std(dim=0), torch.tensor([4., 6.]), prec=0.2) - assert_equal(z.std(dim=1), torch.tensor([4., 6.]), prec=0.2) + assert_equal(y.mean(dim=0), torch.tensor([3.0, 5.0]), prec=0.2) + assert_equal(z.mean(dim=1), torch.tensor([3.0, 5.0]), prec=0.2) + assert_equal(y.std(dim=0), torch.tensor([4.0, 6.0]), prec=0.2) + assert_equal(z.std(dim=1), torch.tensor([4.0, 6.0]), prec=0.2) @pytest.mark.init(rng_seed=3) def test_quantile(): - x = torch.tensor([0., 1., 2.]) + x = torch.tensor([0.0, 1.0, 2.0]) y = torch.rand(2000) z = torch.randn(2000) - assert_equal(quantile(x, probs=[0., 0.4, 0.5, 1.]), torch.tensor([0., 0.8, 1., 2.])) + assert_equal( + quantile(x, probs=[0.0, 0.4, 0.5, 1.0]), torch.tensor([0.0, 0.8, 1.0, 2.0]) + ) assert_equal(quantile(y, probs=0.2), torch.tensor(0.2), prec=0.02) - assert_equal(quantile(z, probs=0.8413), torch.tensor(1.), prec=0.02) + assert_equal(quantile(z, probs=0.8413), torch.tensor(1.0), prec=0.02) def test_pi(): @@ -66,7 +68,7 @@ def test_hpdi(): assert_equal(hpdi(x, prob=0.8), pi(x, prob=0.8), prec=0.01) x = torch.empty(20000).exponential_(1) - assert_equal(hpdi(x, prob=0.2), torch.tensor([0., 0.22]), prec=0.01) + assert_equal(hpdi(x, prob=0.2), torch.tensor([0.0, 0.22]), prec=0.01) def _quantile(x, dim=0): @@ -81,8 +83,8 @@ def _hpdi(x, dim=0): return hpdi(x, prob=0.8, dim=dim) -@pytest.mark.parametrize('statistics', [_quantile, _pi, _hpdi]) -@pytest.mark.parametrize('sample_shape', [(), (3,), (2, 3)]) +@pytest.mark.parametrize("statistics", [_quantile, _pi, _hpdi]) +@pytest.mark.parametrize("sample_shape", [(), (3,), (2, 3)]) def test_statistics_A_ok_with_sample_shape(statistics, sample_shape): xs = torch.rand((10,) + torch.Size(sample_shape)) y = statistics(xs) @@ -102,21 +104,27 @@ def test_statistics_A_ok_with_sample_shape(statistics, sample_shape): def test_autocorrelation(): - x = torch.arange(10.) 
+ x = torch.arange(10.0) with xfail_if_not_implemented(): actual = autocorrelation(x) - assert_equal(actual, - torch.tensor([1, 0.78, 0.52, 0.21, -0.13, - -0.52, -0.94, -1.4, -1.91, -2.45]), prec=0.01) + assert_equal( + actual, + torch.tensor([1, 0.78, 0.52, 0.21, -0.13, -0.52, -0.94, -1.4, -1.91, -2.45]), + prec=0.01, + ) def test_autocovariance(): - x = torch.arange(10.) + x = torch.arange(10.0) with xfail_if_not_implemented(): actual = autocovariance(x) - assert_equal(actual, - torch.tensor([8.25, 6.42, 4.25, 1.75, -1.08, - -4.25, -7.75, -11.58, -15.75, -20.25]), prec=0.01) + assert_equal( + actual, + torch.tensor( + [8.25, 6.42, 4.25, 1.75, -1.08, -4.25, -7.75, -11.58, -15.75, -20.25] + ), + prec=0.01, + ) def test_cummin(): @@ -124,13 +132,13 @@ def test_cummin(): y = torch.empty(x.shape) y[0] = x[0] for i in range(1, x.size(0)): - y[i] = min(x[i], y[i-1]) + y[i] = min(x[i], y[i - 1]) assert_equal(_cummin(x), y) -@pytest.mark.parametrize('statistics', [autocorrelation, autocovariance, _cummin]) -@pytest.mark.parametrize('sample_shape', [(), (3,), (2, 3)]) +@pytest.mark.parametrize("statistics", [autocorrelation, autocovariance, _cummin]) +@pytest.mark.parametrize("sample_shape", [(), (3,), (2, 3)]) def test_statistics_B_ok_with_sample_shape(statistics, sample_shape): xs = torch.rand((10,) + torch.Size(sample_shape)) with xfail_if_not_implemented(): @@ -154,8 +162,8 @@ def test_statistics_B_ok_with_sample_shape(statistics, sample_shape): def test_gelman_rubin(): # only need to test precision for small data x = torch.empty(2, 10) - x[0, :] = torch.arange(10.) - x[1, :] = torch.arange(10.) + 1 + x[0, :] = torch.arange(10.0) + x[1, :] = torch.arange(10.0) + 1 r_hat = gelman_rubin(x) assert_equal(r_hat.item(), 0.98, prec=0.01) @@ -169,15 +177,17 @@ def test_split_gelman_rubin_agree_with_gelman_rubin(): def test_effective_sample_size(): - x = torch.arange(1000.).reshape(100, 10) + x = torch.arange(1000.0).reshape(100, 10) with xfail_if_not_implemented(): # test against arviz assert_equal(effective_sample_size(x).item(), 52.64, prec=0.01) -@pytest.mark.parametrize('diagnostics', [gelman_rubin, split_gelman_rubin, effective_sample_size]) -@pytest.mark.parametrize('sample_shape', [(), (3,), (2, 3)]) +@pytest.mark.parametrize( + "diagnostics", [gelman_rubin, split_gelman_rubin, effective_sample_size] +) +@pytest.mark.parametrize("sample_shape", [(), (3,), (2, 3)]) def test_diagnostics_ok_with_sample_shape(diagnostics, sample_shape): sample_shape = torch.Size(sample_shape) xs = torch.rand((4, 100) + sample_shape) @@ -204,7 +214,7 @@ def test_diagnostics_ok_with_sample_shape(diagnostics, sample_shape): def test_waic(): - x = - torch.arange(1., 101).log().reshape(25, 4) + x = -torch.arange(1.0, 101).log().reshape(25, 4) w_pw, p_pw = waic(x, pointwise=True) w, p = waic(x) w1, p1 = waic(x.t(), dim=1) @@ -226,7 +236,7 @@ def test_weighted_waic(): c = 1 + torch.rand(10) expanded_x = torch.stack([a, b, c, a, b, a, c, a, c]).log() x = torch.stack([a, b, c]).log() - log_weights = torch.tensor([4., 2, 3]).log() + log_weights = torch.tensor([4.0, 2, 3]).log() # assume weights are unnormalized log_weights = log_weights - torch.randn(1) @@ -249,8 +259,8 @@ def test_weighted_waic(): assert_equal(p1, p3) -@pytest.mark.parametrize('k', [0.2, 0.5]) -@pytest.mark.parametrize('sigma', [0.8, 1.3]) +@pytest.mark.parametrize("k", [0.2, 0.5]) +@pytest.mark.parametrize("sigma", [0.8, 1.3]) def test_fit_generalized_pareto(k, sigma, n_samples=5000): with warnings.catch_warnings(): warnings.filterwarnings("ignore", 
category=RuntimeWarning) @@ -262,8 +272,8 @@ def test_fit_generalized_pareto(k, sigma, n_samples=5000): assert_equal(sigma, fit_sigma, prec=0.02) -@pytest.mark.parametrize('event_shape', [(), (4,), (3, 2)]) -@pytest.mark.parametrize('num_samples', [1, 2, 3, 4, 10]) +@pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)]) +@pytest.mark.parametrize("num_samples", [1, 2, 3, 4, 10]) def test_crps_empirical(num_samples, event_shape): truth = torch.randn(event_shape) pred = truth + 0.1 * torch.randn((num_samples,) + event_shape) @@ -271,6 +281,7 @@ def test_crps_empirical(num_samples, event_shape): actual = crps_empirical(pred, truth) assert actual.shape == truth.shape - expected = ((pred - truth).abs().mean(0) - - 0.5 * (pred - pred.unsqueeze(1)).abs().mean([0, 1])) + expected = (pred - truth).abs().mean(0) - 0.5 * ( + pred - pred.unsqueeze(1) + ).abs().mean([0, 1]) assert_close(actual, expected) diff --git a/tests/ops/test_tensor_utils.py b/tests/ops/test_tensor_utils.py index ed8a6468f2..e3531d2cf2 100644 --- a/tests/ops/test_tensor_utils.py +++ b/tests/ops/test_tensor_utils.py @@ -24,11 +24,13 @@ ) from tests.common import assert_close, assert_equal -pytestmark = pytest.mark.stage('unit') +pytestmark = pytest.mark.stage("unit") -@pytest.mark.parametrize('batch_size', [1, 2, 3]) -@pytest.mark.parametrize('block_size', [torch.Size([2, 2]), torch.Size([3, 1]), torch.Size([4, 2])]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +@pytest.mark.parametrize( + "block_size", [torch.Size([2, 2]), torch.Size([3, 1]), torch.Size([4, 2])] +) def test_block_diag_embed(batch_size, block_size): m = torch.randn(block_size).unsqueeze(0).expand((batch_size,) + block_size) b = block_diag_embed(m) @@ -43,9 +45,11 @@ def test_block_diag_embed(batch_size, block_size): assert_equal(b[bottom:top, left:right], m[k]) -@pytest.mark.parametrize('batch_shape', [torch.Size([]), torch.Size([7])]) -@pytest.mark.parametrize('mat_size,block_size', [(torch.Size([2, 2]), 2), (torch.Size([3, 1]), 1), - (torch.Size([6, 3]), 3)]) +@pytest.mark.parametrize("batch_shape", [torch.Size([]), torch.Size([7])]) +@pytest.mark.parametrize( + "mat_size,block_size", + [(torch.Size([2, 2]), 2), (torch.Size([3, 1]), 1), (torch.Size([6, 3]), 3)], +) def test_block_diag(batch_shape, mat_size, block_size): mat = torch.randn(batch_shape + (block_size,) + mat_size) mat_embed = block_diag_embed(mat) @@ -75,7 +79,10 @@ def test_periodic_features(duration): min_period = torch.distributions.Uniform(2, max_period).sample().item() for min_period in [min_period, 2]: actual = periodic_features(duration, max_period, min_period) - assert actual.shape == (duration, 2 * math.ceil(max_period / min_period) - 2) + assert actual.shape == ( + duration, + 2 * math.ceil(max_period / min_period) - 2, + ) assert (-1 <= actual).all() assert (actual <= 1).all() @@ -93,12 +100,14 @@ def test_periodic_cumsum(period, size, left_shape, right_shape): for t in range(period): assert_equal(actual[dots + (t,)], tensor[dots + (t,)]) for t in range(period, size): - assert_close(actual[dots + (t,)], tensor[dots + (t,)] + actual[dots + (t - period,)]) + assert_close( + actual[dots + (t,)], tensor[dots + (t,)] + actual[dots + (t - period,)] + ) -@pytest.mark.parametrize('m', [2, 3, 4, 5, 6, 10]) -@pytest.mark.parametrize('n', [2, 3, 4, 5, 6, 10]) -@pytest.mark.parametrize('mode', ['full', 'valid', 'same']) +@pytest.mark.parametrize("m", [2, 3, 4, 5, 6, 10]) +@pytest.mark.parametrize("n", [2, 3, 4, 5, 6, 10]) +@pytest.mark.parametrize("mode", ["full", "valid", "same"]) def 
test_convolve_shape(m, n, mode): signal = torch.randn(m) kernel = torch.randn(n) @@ -107,23 +116,27 @@ def test_convolve_shape(m, n, mode): assert actual.shape == expected.shape -@pytest.mark.parametrize('m', [2, 3, 4, 5, 6, 10]) -@pytest.mark.parametrize('n', [2, 3, 4, 5, 6, 10]) -@pytest.mark.parametrize('batch_shape', [(), (4,), (2, 3)], ids=str) -@pytest.mark.parametrize('mode', ['full', 'valid', 'same']) +@pytest.mark.parametrize("m", [2, 3, 4, 5, 6, 10]) +@pytest.mark.parametrize("n", [2, 3, 4, 5, 6, 10]) +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("mode", ["full", "valid", "same"]) def test_convolve(batch_shape, m, n, mode): signal = torch.randn(*batch_shape, m) kernel = torch.randn(*batch_shape, n) actual = convolve(signal, kernel, mode) - expected = torch.stack([ - torch.tensor(np.convolve(s, k, mode=mode)) - for s, k in zip(signal.reshape(-1, m), kernel.reshape(-1, n)) - ]).reshape(*batch_shape, -1) + expected = torch.stack( + [ + torch.tensor(np.convolve(s, k, mode=mode)) + for s, k in zip(signal.reshape(-1, m), kernel.reshape(-1, n)) + ] + ).reshape(*batch_shape, -1) assert_close(actual, expected) -@pytest.mark.parametrize('size', [torch.Size([2, 2]), torch.Size([4, 3, 3]), torch.Size([4, 1, 2, 2])]) -@pytest.mark.parametrize('n', [1, 2, 3, 7, 8]) +@pytest.mark.parametrize( + "size", [torch.Size([2, 2]), torch.Size([4, 3, 3]), torch.Size([4, 1, 2, 2])] +) +@pytest.mark.parametrize("n", [1, 2, 3, 7, 8]) def test_repeated_matmul(size, n): M = torch.randn(size) result = repeated_matmul(M, n) @@ -135,19 +148,19 @@ def test_repeated_matmul(size, n): serial_result = torch.matmul(serial_result, M) -@pytest.mark.parametrize('shape', [(3, 4), (5,), (2, 1, 6)]) +@pytest.mark.parametrize("shape", [(3, 4), (5,), (2, 1, 6)]) def test_dct(shape): x = torch.randn(shape) actual = dct(x) - expected = torch.from_numpy(fftpack.dct(x.numpy(), norm='ortho')) + expected = torch.from_numpy(fftpack.dct(x.numpy(), norm="ortho")) assert_close(actual, expected) -@pytest.mark.parametrize('shape', [(3, 4), (5,), (2, 1, 6)]) +@pytest.mark.parametrize("shape", [(3, 4), (5,), (2, 1, 6)]) def test_idct(shape): x = torch.randn(shape) actual = idct(x) - expected = torch.from_numpy(fftpack.idct(x.numpy(), norm='ortho')) + expected = torch.from_numpy(fftpack.idct(x.numpy(), norm="ortho")) assert_close(actual, expected) @@ -168,10 +181,13 @@ def test_next_fast_len(): assert next_fast_len(size) == fftpack.next_fast_len(size) -@pytest.mark.parametrize('batch_shape,event_shape', [ - ((), (5,)), - ((3,), (4,)), -]) +@pytest.mark.parametrize( + "batch_shape,event_shape", + [ + ((), (5,)), + ((3,), (4,)), + ], +) def test_precision_to_scale_tril(batch_shape, event_shape): x = torch.randn(batch_shape + event_shape + event_shape) precision = x.matmul(x.transpose(-2, -1)) diff --git a/tests/ops/test_welford.py b/tests/ops/test_welford.py index 24137ae1ad..d345c7a75f 100644 --- a/tests/ops/test_welford.py +++ b/tests/ops/test_welford.py @@ -10,9 +10,7 @@ from tests.common import assert_equal -@pytest.mark.parametrize('n_samples,dim_size', [(1000, 1), - (1000, 7), - (1, 1)]) +@pytest.mark.parametrize("n_samples,dim_size", [(1000, 1), (1000, 7), (1, 1)]) @pytest.mark.init(rng_seed=7) def test_welford_diagonal(n_samples, dim_size): w = WelfordCovariance(diagonal=True) @@ -32,9 +30,7 @@ def test_welford_diagonal(n_samples, dim_size): assert_equal(estimates, sample_variance) -@pytest.mark.parametrize('n_samples,dim_size', [(1000, 1), - (1000, 7), - (1, 1)]) 
+@pytest.mark.parametrize("n_samples,dim_size", [(1000, 1), (1000, 7), (1, 1)]) @pytest.mark.init(rng_seed=7) def test_welford_dense(n_samples, dim_size): w = WelfordCovariance(diagonal=False) @@ -52,13 +48,11 @@ def test_welford_dense(n_samples, dim_size): assert_equal(estimates, sample_cov) -@pytest.mark.parametrize('n_samples,dim_size,head_size', [ - (1000, 5, 0), - (1000, 5, 1), - (1000, 5, 4), - (1000, 5, 5) -]) -@pytest.mark.parametrize('regularize', [True, False]) +@pytest.mark.parametrize( + "n_samples,dim_size,head_size", + [(1000, 5, 0), (1000, 5, 1), (1000, 5, 4), (1000, 5, 5)], +) +@pytest.mark.parametrize("regularize", [True, False]) def test_welford_arrowhead(n_samples, dim_size, head_size, regularize): adapt_scheme = WelfordArrowheadCovariance(head_size=head_size) loc = torch.zeros(dim_size) @@ -70,14 +64,18 @@ def test_welford_arrowhead(n_samples, dim_size, head_size, regularize): for sample in samples: adapt_scheme.update(sample) top, bottom_diag = adapt_scheme.get_covariance(regularize=regularize) - actual = torch.cat([top, torch.cat([top[:, head_size:].t(), bottom_diag.diag()], -1)]) + actual = torch.cat( + [top, torch.cat([top[:, head_size:].t(), bottom_diag.diag()], -1)] + ) mask = torch.ones(dim_size, dim_size) - mask[head_size:, head_size:] = 0. - mask.view(-1)[::dim_size + 1][head_size:] = 1. + mask[head_size:, head_size:] = 0.0 + mask.view(-1)[:: dim_size + 1][head_size:] = 1.0 expected = np.cov(samples.cpu().numpy(), bias=False, rowvar=False) expected = torch.from_numpy(expected).type_as(mask) if regularize: - expected = (expected * n_samples + 1e-3 * torch.eye(dim_size) * 5) / (n_samples + 5) + expected = (expected * n_samples + 1e-3 * torch.eye(dim_size) * 5) / ( + n_samples + 5 + ) expected = expected * mask assert_equal(actual, expected) diff --git a/tests/optim/test_multi.py b/tests/optim/test_multi.py index 3bd523dd92..07de5a7f53 100644 --- a/tests/optim/test_multi.py +++ b/tests/optim/test_multi.py @@ -18,24 +18,31 @@ from tests.common import assert_equal FACTORIES = [ - lambda: PyroMultiOptimizer(pyro.optim.Adam({'lr': 0.05})), - lambda: TorchMultiOptimizer(torch.optim.Adam, {'lr': 0.05}), - lambda: Newton(trust_radii={'z': 0.2}), - lambda: MixedMultiOptimizer([(['y'], PyroMultiOptimizer(pyro.optim.Adam({'lr': 0.05}))), - (['x', 'z'], Newton())]), - lambda: MixedMultiOptimizer([(['y'], pyro.optim.Adam({'lr': 0.05})), - (['x', 'z'], Newton())]), + lambda: PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.05})), + lambda: TorchMultiOptimizer(torch.optim.Adam, {"lr": 0.05}), + lambda: Newton(trust_radii={"z": 0.2}), + lambda: MixedMultiOptimizer( + [ + (["y"], PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.05}))), + (["x", "z"], Newton()), + ] + ), + lambda: MixedMultiOptimizer( + [(["y"], pyro.optim.Adam({"lr": 0.05})), (["x", "z"], Newton())] + ), ] -@pytest.mark.parametrize('factory', FACTORIES) +@pytest.mark.parametrize("factory", FACTORIES) def test_optimizers(factory): optim = factory() def model(loc, cov): x = pyro.param("x", torch.randn(2)) y = pyro.param("y", torch.randn(3, 2)) - z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1)) + z = pyro.param( + "z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1) + ) pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x) with pyro.plate("y_plate", 3): pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y) @@ -47,26 +54,36 @@ def model(loc, cov): for step in range(200): tr = poutine.trace(model).get_trace(loc, cov) loss = -tr.log_prob_sum() - params = 
{name: site['value'].unconstrained() - for name, site in tr.nodes.items() - if site['type'] == 'param'} + params = { + name: site["value"].unconstrained() + for name, site in tr.nodes.items() + if site["type"] == "param" + } optim.step(loss, params) for name in ["x", "y", "z"]: actual = pyro.param(name) expected = loc.expand(actual.shape) - assert_equal(actual, expected, prec=1e-2, - msg='{} in correct: {} vs {}'.format(name, actual, expected)) + assert_equal( + actual, + expected, + prec=1e-2, + msg="{} in correct: {} vs {}".format(name, actual, expected), + ) def test_multi_optimizer_disjoint_ok(): - parts = [(['w', 'x'], pyro.optim.Adam({'lr': 0.1})), - (['y', 'z'], pyro.optim.Adam({'lr': 0.01}))] + parts = [ + (["w", "x"], pyro.optim.Adam({"lr": 0.1})), + (["y", "z"], pyro.optim.Adam({"lr": 0.01})), + ] MixedMultiOptimizer(parts) def test_multi_optimizer_overlap_error(): - parts = [(['x', 'y'], pyro.optim.Adam({'lr': 0.1})), - (['y', 'z'], pyro.optim.Adam({'lr': 0.01}))] + parts = [ + (["x", "y"], pyro.optim.Adam({"lr": 0.1})), + (["y", "z"], pyro.optim.Adam({"lr": 0.01})), + ] with pytest.raises(ValueError): MixedMultiOptimizer(parts) diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py index d894777337..8902f2218e 100644 --- a/tests/optim/test_optim.py +++ b/tests/optim/test_optim.py @@ -18,7 +18,6 @@ class OptimTests(TestCase): - def setUp(self): # normal-normal; known covariance self.lam0 = torch.tensor([0.1]) # precision of prior @@ -43,20 +42,16 @@ def model(): return loc_latent def guide(): - loc_q = pyro.param( - "loc_q", - torch.zeros(1, requires_grad=True)) - log_sig_q = pyro.param( - "log_sig_q", - torch.zeros(1, requires_grad=True)) + loc_q = pyro.param("loc_q", torch.zeros(1, requires_grad=True)) + log_sig_q = pyro.param("log_sig_q", torch.zeros(1, requires_grad=True)) sig_q = torch.exp(log_sig_q) pyro.sample("loc_latent", Normal(loc_q, sig_q)) def optim_params(param_name): if param_name == fixed_param: - return {'lr': 0.00} + return {"lr": 0.00} elif param_name == free_param: - return {'lr': 0.01} + return {"lr": 0.01} adam = optim.Adam(optim_params) adam2 = optim.Adam(optim_params) @@ -64,89 +59,127 @@ def optim_params(param_name): svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO()) svi.step() - adam_initial_step_count = list(adam.get_state()['loc_q']['state'].items())[0][1]['step'] - adam.save('adam.unittest.save') + adam_initial_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][ + 1 + ]["step"] + adam.save("adam.unittest.save") svi.step() - adam_final_step_count = list(adam.get_state()['loc_q']['state'].items())[0][1]['step'] - adam2.load('adam.unittest.save') + adam_final_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][1][ + "step" + ] + adam2.load("adam.unittest.save") svi2.step() - adam2_step_count_after_load_and_step = list(adam2.get_state()['loc_q']['state'].items())[0][1]['step'] + adam2_step_count_after_load_and_step = list( + adam2.get_state()["loc_q"]["state"].items() + )[0][1]["step"] assert adam_initial_step_count == 1 assert adam_final_step_count == 2 assert adam2_step_count_after_load_and_step == 2 free_param_unchanged = torch.equal(pyro.param(free_param).data, torch.zeros(1)) - fixed_param_unchanged = torch.equal(pyro.param(fixed_param).data, torch.zeros(1)) + fixed_param_unchanged = torch.equal( + pyro.param(fixed_param).data, torch.zeros(1) + ) assert fixed_param_unchanged and not free_param_unchanged -@pytest.mark.parametrize('scheduler', [optim.LambdaLR({'optimizer': torch.optim.SGD, 'optim_args': 
{'lr': 0.01}, - 'lr_lambda': lambda epoch: 2. ** epoch}), - optim.StepLR({'optimizer': torch.optim.SGD, 'optim_args': {'lr': 0.01}, - 'gamma': 2, 'step_size': 1}), - optim.ExponentialLR({'optimizer': torch.optim.SGD, 'optim_args': {'lr': 0.01}, - 'gamma': 2}), - optim.ReduceLROnPlateau({'optimizer': torch.optim.SGD, 'optim_args': {'lr': 1.0}, - 'factor': 0.1, 'patience': 1})]) +@pytest.mark.parametrize( + "scheduler", + [ + optim.LambdaLR( + { + "optimizer": torch.optim.SGD, + "optim_args": {"lr": 0.01}, + "lr_lambda": lambda epoch: 2.0 ** epoch, + } + ), + optim.StepLR( + { + "optimizer": torch.optim.SGD, + "optim_args": {"lr": 0.01}, + "gamma": 2, + "step_size": 1, + } + ), + optim.ExponentialLR( + {"optimizer": torch.optim.SGD, "optim_args": {"lr": 0.01}, "gamma": 2} + ), + optim.ReduceLROnPlateau( + { + "optimizer": torch.optim.SGD, + "optim_args": {"lr": 1.0}, + "factor": 0.1, + "patience": 1, + } + ), + ], +) def test_dynamic_lr(scheduler): pyro.clear_param_store() def model(): - sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3))) - return pyro.sample('obs', Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1)) + sample = pyro.sample("latent", Normal(torch.tensor(0.0), torch.tensor(0.3))) + return pyro.sample( + "obs", Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1) + ) def guide(): - loc = pyro.param('loc', torch.tensor(0.)) - scale = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive) - pyro.sample('latent', Normal(loc, scale)) + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive) + pyro.sample("latent", Normal(loc, scale)) svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO()) for epoch in range(4): svi.step() svi.step() - loc = pyro.param('loc').unconstrained() + loc = pyro.param("loc").unconstrained() opt_loc = scheduler.optim_objs[loc].optimizer opt_scale = scheduler.optim_objs[loc].optimizer - if issubclass(scheduler.pt_scheduler_constructor, torch.optim.lr_scheduler.ReduceLROnPlateau): - scheduler.step(1.) 
+ if issubclass( + scheduler.pt_scheduler_constructor, + torch.optim.lr_scheduler.ReduceLROnPlateau, + ): + scheduler.step(1.0) if epoch == 2: - assert opt_loc.state_dict()['param_groups'][0]['lr'] == 0.1 - assert opt_scale.state_dict()['param_groups'][0]['lr'] == 0.1 + assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.1 + assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.1 if epoch == 4: - assert opt_loc.state_dict()['param_groups'][0]['lr'] == 0.01 - assert opt_scale.state_dict()['param_groups'][0]['lr'] == 0.01 + assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.01 + assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.01 continue - assert opt_loc.state_dict()['param_groups'][0]['initial_lr'] == 0.01 - assert opt_scale.state_dict()['param_groups'][0]['initial_lr'] == 0.01 + assert opt_loc.state_dict()["param_groups"][0]["initial_lr"] == 0.01 + assert opt_scale.state_dict()["param_groups"][0]["initial_lr"] == 0.01 if epoch == 0: scheduler.step() - assert opt_loc.state_dict()['param_groups'][0]['lr'] == 0.02 - assert opt_scale.state_dict()['param_groups'][0]['lr'] == 0.02 - assert abs(pyro.param('loc').item()) > 1e-5 - assert abs(pyro.param('scale').item() - 0.5) > 1e-5 + assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.02 + assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.02 + assert abs(pyro.param("loc").item()) > 1e-5 + assert abs(pyro.param("scale").item() - 0.5) > 1e-5 if epoch == 2: scheduler.step() - assert opt_loc.state_dict()['param_groups'][0]['lr'] == 0.04 - assert opt_scale.state_dict()['param_groups'][0]['lr'] == 0.04 + assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.04 + assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.04 -@pytest.mark.parametrize('factory', [optim.Adam, optim.ClippedAdam, optim.DCTAdam, optim.RMSprop, optim.SGD]) +@pytest.mark.parametrize( + "factory", [optim.Adam, optim.ClippedAdam, optim.DCTAdam, optim.RMSprop, optim.SGD] +) def test_autowrap(factory): instance = factory({}) assert instance.pt_optim_constructor.__name__ == factory.__name__ -@pytest.mark.parametrize('pyro_optim', [optim.Adam, optim.SGD]) -@pytest.mark.parametrize('clip', ['clip_norm', 'clip_value']) -@pytest.mark.parametrize('value', [1., 3., 5.]) +@pytest.mark.parametrize("pyro_optim", [optim.Adam, optim.SGD]) +@pytest.mark.parametrize("clip", ["clip_norm", "clip_value"]) +@pytest.mark.parametrize("value", [1.0, 3.0, 5.0]) def test_clip_norm(pyro_optim, clip, value): - x1 = torch.tensor(0., requires_grad=True) - x2 = torch.tensor(0., requires_grad=True) - opt_c = pyro_optim({"lr": 1.}, {clip: value}) - opt = pyro_optim({"lr": 1.}) + x1 = torch.tensor(0.0, requires_grad=True) + x2 = torch.tensor(0.0, requires_grad=True) + opt_c = pyro_optim({"lr": 1.0}, {clip: value}) + opt = pyro_optim({"lr": 1.0}) for step in range(3): - x1.backward(Uniform(value, value + 3.).sample()) + x1.backward(Uniform(value, value + 3.0).sample()) x2.backward(torch.tensor(value)) opt_c([x1]) opt([x2]) @@ -157,28 +190,32 @@ def test_clip_norm(pyro_optim, clip, value): opt.optim_objs[x2].zero_grad() -@pytest.mark.parametrize('clip_norm', [1., 3., 5.]) +@pytest.mark.parametrize("clip_norm", [1.0, 3.0, 5.0]) def test_clippedadam_clip(clip_norm): - x1 = torch.tensor(0., requires_grad=True) - x2 = torch.tensor(0., requires_grad=True) - opt_ca = optim.clipped_adam.ClippedAdam(params=[x1], lr=1., lrd=1., clip_norm=clip_norm) - opt_a = torch.optim.Adam(params=[x2], lr=1.) 
+ x1 = torch.tensor(0.0, requires_grad=True) + x2 = torch.tensor(0.0, requires_grad=True) + opt_ca = optim.clipped_adam.ClippedAdam( + params=[x1], lr=1.0, lrd=1.0, clip_norm=clip_norm + ) + opt_a = torch.optim.Adam(params=[x2], lr=1.0) for step in range(3): opt_ca.zero_grad() opt_a.zero_grad() - x1.backward(Uniform(clip_norm, clip_norm + 3.).sample()) + x1.backward(Uniform(clip_norm, clip_norm + 3.0).sample()) x2.backward(torch.tensor(clip_norm)) opt_ca.step() opt_a.step() assert_equal(x1, x2) -@pytest.mark.parametrize('clip_norm', [1., 3., 5.]) +@pytest.mark.parametrize("clip_norm", [1.0, 3.0, 5.0]) def test_clippedadam_pass(clip_norm): - x1 = torch.tensor(0., requires_grad=True) - x2 = torch.tensor(0., requires_grad=True) - opt_ca = optim.clipped_adam.ClippedAdam(params=[x1], lr=1., lrd=1., clip_norm=clip_norm) - opt_a = torch.optim.Adam(params=[x2], lr=1.) + x1 = torch.tensor(0.0, requires_grad=True) + x2 = torch.tensor(0.0, requires_grad=True) + opt_ca = optim.clipped_adam.ClippedAdam( + params=[x1], lr=1.0, lrd=1.0, clip_norm=clip_norm + ) + opt_a = torch.optim.Adam(params=[x2], lr=1.0) for step in range(3): g = Uniform(-clip_norm, clip_norm).sample() opt_ca.zero_grad() @@ -190,16 +227,16 @@ def test_clippedadam_pass(clip_norm): assert_equal(x1, x2) -@pytest.mark.parametrize('lrd', [1., 3., 5.]) +@pytest.mark.parametrize("lrd", [1.0, 3.0, 5.0]) def test_clippedadam_lrd(lrd): - x1 = torch.tensor(0., requires_grad=True) + x1 = torch.tensor(0.0, requires_grad=True) orig_lr = 1.0 opt_ca = optim.clipped_adam.ClippedAdam(params=[x1], lr=orig_lr, lrd=lrd) for step in range(3): - g = Uniform(-5., 5.).sample() + g = Uniform(-5.0, 5.0).sample() x1.backward(g) opt_ca.step() - assert opt_ca.param_groups[0]['lr'] == orig_lr * lrd**(step + 1) + assert opt_ca.param_groups[0]["lr"] == orig_lr * lrd ** (step + 1) def test_dctadam_param_subsample(): @@ -213,11 +250,13 @@ def test_dctadam_param_subsample(): def model(): with pyro.plate("outer", outer_size, subsample_size=outer_subsize, dim=-3): with pyro.plate("inner", inner_size, subsample_size=inner_subsize, dim=-1): - pyro.param("loc", - torch.randn(outer_size, middle_size, inner_size, event_size), - event_dim=1) + pyro.param( + "loc", + torch.randn(outer_size, middle_size, inner_size, event_size), + event_dim=1, + ) - optimizer = optim.DCTAdam({"lr": 1., "subsample_aware": True}) + optimizer = optim.DCTAdam({"lr": 1.0, "subsample_aware": True}) model() param = pyro.param("loc").unconstrained() param.sum().backward() @@ -250,8 +289,7 @@ def forward(self, features, data): loc = self.loc(features) scale = self.scale with pyro.plate("data", len(data)): - pyro.sample("obs", dist.Normal(loc, scale), - obs=data) + pyro.sample("obs", dist.Normal(loc, scale), obs=data) model = Model() params = list(model.parameters()) diff --git a/tests/params/test_module.py b/tests/params/test_module.py index fbcf98810d..5ab792b86a 100644 --- a/tests/params/test_module.py +++ b/tests/params/test_module.py @@ -12,7 +12,6 @@ class outest(nn.Module): - def __init__(self): super().__init__() self.l0 = outer() @@ -24,7 +23,6 @@ def forward(self, s): class outer(torch.nn.Module): - def __init__(self): super().__init__() self.l0 = inner() @@ -35,7 +33,6 @@ def forward(self, s): class inner(torch.nn.Module): - def __init__(self): super().__init__() self.l0 = nn.Linear(2, 2) @@ -45,11 +42,7 @@ def forward(self, s): pass -sequential = nn.Sequential( - nn.Conv2d(1, 20, 5), - nn.ReLU(), - nn.Conv2d(20, 64, 5) - ) +sequential = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), 
nn.Conv2d(20, 64, 5)) @pytest.mark.parametrize("nn_module", [outest, outer]) @@ -74,9 +67,9 @@ def forward(self, s): pass with pytest.warns(UserWarning): - pyro.module('net', net()) - assert 'net$$$x' in pyro.get_param_store().keys() - assert 'net$$$y' not in pyro.get_param_store().keys() + pyro.module("net", net()) + assert "net$$$x" in pyro.get_param_store().keys() + assert "net$$$y" not in pyro.get_param_store().keys() @pytest.mark.parametrize("nn_module", [sequential]) diff --git a/tests/params/test_param.py b/tests/params/test_param.py index 326cd47b38..98acb4d701 100644 --- a/tests/params/test_param.py +++ b/tests/params/test_param.py @@ -15,7 +15,6 @@ class ParamStoreDictTests(TestCase): - def setUp(self): pyro.clear_param_store() self.linear_module = nn.Linear(3, 2) @@ -31,7 +30,7 @@ def test_save_and_load(self): cost = torch.sum(torch.pow(lin(x), 2.0)) * torch.pow(myparam, 4.0) cost.backward() params = list(self.linear_module.parameters()) + [myparam] - optim = torch.optim.Adam(params, lr=.01) + optim = torch.optim.Adam(params, lr=0.01) myparam_copy_stale = copy(pyro.param("myparam").detach().cpu().numpy()) optim.step() @@ -42,25 +41,39 @@ def test_save_and_load(self): assert len(list(param_store_params.keys())) == 5 assert len(list(param_store_param_to_name.values())) == 5 - pyro.get_param_store().save('paramstore.unittest.out') + pyro.get_param_store().save("paramstore.unittest.out") pyro.clear_param_store() assert len(list(pyro.get_param_store()._params)) == 0 assert len(list(pyro.get_param_store()._param_to_name)) == 0 - pyro.get_param_store().load('paramstore.unittest.out') + pyro.get_param_store().load("paramstore.unittest.out") def modules_are_equal(): - weights_equal = np.sum(np.fabs(self.linear_module3.weight.detach().cpu().numpy() - - self.linear_module.weight.detach().cpu().numpy())) == 0.0 - bias_equal = np.sum(np.fabs(self.linear_module3.bias.detach().cpu().numpy() - - self.linear_module.bias.detach().cpu().numpy())) == 0.0 - return (weights_equal and bias_equal) + weights_equal = ( + np.sum( + np.fabs( + self.linear_module3.weight.detach().cpu().numpy() + - self.linear_module.weight.detach().cpu().numpy() + ) + ) + == 0.0 + ) + bias_equal = ( + np.sum( + np.fabs( + self.linear_module3.bias.detach().cpu().numpy() + - self.linear_module.bias.detach().cpu().numpy() + ) + ) + == 0.0 + ) + return weights_equal and bias_equal assert not modules_are_equal() pyro.module("mymodule", self.linear_module3, update_module_params=False) - assert id(self.linear_module3.weight) != id(pyro.param('mymodule$$$weight')) + assert id(self.linear_module3.weight) != id(pyro.param("mymodule$$$weight")) assert not modules_are_equal() pyro.module("mymodule", self.linear_module3, update_module_params=True) - assert id(self.linear_module3.weight) == id(pyro.param('mymodule$$$weight')) + assert id(self.linear_module3.weight) == id(pyro.param("mymodule$$$weight")) assert modules_are_equal() myparam = pyro.param("myparam") @@ -68,7 +81,9 @@ def modules_are_equal(): assert myparam_copy_stale != myparam.detach().cpu().numpy() assert myparam_copy == myparam.detach().cpu().numpy() assert sorted(param_store_params.keys()) == sorted(store._params.keys()) - assert sorted(param_store_param_to_name.values()) == sorted(store._param_to_name.values()) + assert sorted(param_store_param_to_name.values()) == sorted( + store._param_to_name.values() + ) assert sorted(store._params.keys()) == sorted(store._param_to_name.values()) @@ -79,58 +94,58 @@ def test_dict_interface(): param_store.clear() assert not 
param_store assert len(param_store) == 0 - assert 'x' not in param_store - assert 'y' not in param_store + assert "x" not in param_store + assert "y" not in param_store assert list(param_store.items()) == [] assert list(param_store.keys()) == [] assert list(param_store.values()) == [] # add x - param_store['x'] = torch.zeros(1, 2, 3) + param_store["x"] = torch.zeros(1, 2, 3) assert param_store assert len(param_store) == 1 - assert 'x' in param_store - assert 'y' not in param_store - assert list(param_store.keys()) == ['x'] - assert [key for key, value in param_store.items()] == ['x'] + assert "x" in param_store + assert "y" not in param_store + assert list(param_store.keys()) == ["x"] + assert [key for key, value in param_store.items()] == ["x"] assert len(list(param_store.values())) == 1 - assert param_store['x'].shape == (1, 2, 3) - assert_equal(param_store.setdefault('x', torch.ones(1, 2, 3)), torch.zeros(1, 2, 3)) - assert param_store['x'].unconstrained() is param_store['x'] + assert param_store["x"].shape == (1, 2, 3) + assert_equal(param_store.setdefault("x", torch.ones(1, 2, 3)), torch.zeros(1, 2, 3)) + assert param_store["x"].unconstrained() is param_store["x"] # add y - param_store.setdefault('y', torch.ones(4, 5), constraint=constraints.positive) + param_store.setdefault("y", torch.ones(4, 5), constraint=constraints.positive) assert param_store assert len(param_store) == 2 - assert 'x' in param_store - assert 'y' in param_store - assert sorted(param_store.keys()) == ['x', 'y'] - assert sorted(key for key, value in param_store.items()) == ['x', 'y'] + assert "x" in param_store + assert "y" in param_store + assert sorted(param_store.keys()) == ["x", "y"] + assert sorted(key for key, value in param_store.items()) == ["x", "y"] assert len(list(param_store.values())) == 2 - assert param_store['x'].shape == (1, 2, 3) - assert param_store['y'].shape == (4, 5) - assert_equal(param_store.setdefault('y', torch.zeros(4, 5)), torch.ones(4, 5)) - assert_equal(param_store['y'].unconstrained(), torch.zeros(4, 5)) + assert param_store["x"].shape == (1, 2, 3) + assert param_store["y"].shape == (4, 5) + assert_equal(param_store.setdefault("y", torch.zeros(4, 5)), torch.ones(4, 5)) + assert_equal(param_store["y"].unconstrained(), torch.zeros(4, 5)) # remove x - del param_store['x'] + del param_store["x"] assert param_store assert len(param_store) == 1 - assert 'x' not in param_store - assert 'y' in param_store - assert list(param_store.keys()) == ['y'] - assert list(key for key, value in param_store.items()) == ['y'] + assert "x" not in param_store + assert "y" in param_store + assert list(param_store.keys()) == ["y"] + assert list(key for key, value in param_store.items()) == ["y"] assert len(list(param_store.values())) == 1 - assert param_store['y'].shape == (4, 5) - assert_equal(param_store.setdefault('y', torch.zeros(4, 5)), torch.ones(4, 5)) - assert_equal(param_store['y'].unconstrained(), torch.zeros(4, 5)) + assert param_store["y"].shape == (4, 5) + assert_equal(param_store.setdefault("y", torch.zeros(4, 5)), torch.ones(4, 5)) + assert_equal(param_store["y"].unconstrained(), torch.zeros(4, 5)) # remove y - del param_store['y'] + del param_store["y"] assert not param_store assert len(param_store) == 0 - assert 'x' not in param_store - assert 'y' not in param_store + assert "x" not in param_store + assert "y" not in param_store assert list(param_store.keys()) == [] assert list(key for key, value in param_store.items()) == [] assert len(list(param_store.values())) == 0 diff --git 
a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 2753e64ed6..280a6f2963 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -20,12 +20,14 @@ from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS -Model = namedtuple('TestModel', ['model', 'model_args', 'model_id']) +Model = namedtuple("TestModel", ["model", "model_args", "model_id"]) TEST_MODELS = [] MODEL_IDS = [] -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +ROOT_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) PROF_DIR = os.path.join(ROOT_DIR, ".benchmarks") if not os.path.exists(PROF_DIR): os.makedirs(PROF_DIR) @@ -38,13 +40,26 @@ def register_fn(model): TEST_MODELS.append(test_model) MODEL_IDS.append(model_id) return model + return register_fn -@register_model(reparameterized=True, Elbo=TraceGraph_ELBO, id='PoissonGamma::reparam=True_TraceGraph') -@register_model(reparameterized=True, Elbo=Trace_ELBO, id='PoissonGamma::reparam=True_Trace') -@register_model(reparameterized=False, Elbo=TraceGraph_ELBO, id='PoissonGamma::reparam=False_TraceGraph') -@register_model(reparameterized=False, Elbo=Trace_ELBO, id='PoissonGamma::reparam=False_Trace') +@register_model( + reparameterized=True, + Elbo=TraceGraph_ELBO, + id="PoissonGamma::reparam=True_TraceGraph", +) +@register_model( + reparameterized=True, Elbo=Trace_ELBO, id="PoissonGamma::reparam=True_Trace" +) +@register_model( + reparameterized=False, + Elbo=TraceGraph_ELBO, + id="PoissonGamma::reparam=False_TraceGraph", +) +@register_model( + reparameterized=False, Elbo=Trace_ELBO, id="PoissonGamma::reparam=False_Trace" +) def poisson_gamma_model(reparameterized, Elbo): pyro.set_rng_seed(0) alpha0 = torch.tensor(1.0) @@ -72,18 +87,20 @@ def guide(): alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q)) - adam = optim.Adam({"lr": .0002, "betas": (0.97, 0.999)}) + adam = optim.Adam({"lr": 0.0002, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=Elbo()) for k in range(3000): svi.step() -@register_model(kernel=NUTS, step_size=0.02, num_samples=300, id='BernoulliBeta::NUTS') -@register_model(kernel=HMC, step_size=0.02, num_steps=3, num_samples=1000, id='BernoulliBeta::HMC') +@register_model(kernel=NUTS, step_size=0.02, num_samples=300, id="BernoulliBeta::NUTS") +@register_model( + kernel=HMC, step_size=0.02, num_steps=3, num_samples=1000, id="BernoulliBeta::HMC" +) def bernoulli_beta_hmc(**kwargs): def model(data): - alpha = pyro.param('alpha', torch.tensor([1.1, 1.1])) - beta = pyro.param('beta', torch.tensor([1.1, 1.1])) + alpha = pyro.param("alpha", torch.tensor([1.1, 1.1])) + beta = pyro.param("beta", torch.tensor([1.1, 1.1])) p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta)) pyro.sample("obs", dist.Bernoulli(p_latent), obs=data) return p_latent @@ -91,16 +108,16 @@ def model(data): pyro.set_rng_seed(0) true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - kernel = kwargs.pop('kernel') - num_samples = kwargs.pop('num_samples') + kernel = kwargs.pop("kernel") + num_samples = kwargs.pop("num_samples") mcmc_kernel = kernel(model, **kwargs) mcmc = MCMC(mcmc_kernel, num_samples=num_samples, warmup_steps=100) mcmc.run(data) - return mcmc.get_samples()['p_latent'] + return mcmc.get_samples()["p_latent"] -@register_model(num_steps=2000, whiten=False, id='VSGP::MultiClass_whiten=False') 
-@register_model(num_steps=2000, whiten=True, id='VSGP::MultiClass_whiten=True') +@register_model(num_steps=2000, whiten=False, id="VSGP::MultiClass_whiten=False") +@register_model(num_steps=2000, whiten=True, id="VSGP::MultiClass_whiten=True") def vsgp_multiclass(num_steps, whiten): # adapted from http://gpflow.readthedocs.io/en/latest/notebooks/multiclass.html pyro.set_rng_seed(0) @@ -109,14 +126,15 @@ def vsgp_multiclass(num_steps, whiten): f = torch.linalg.cholesky(K).matmul(torch.randn(100, 3)) y = f.argmax(dim=-1) - kernel = gp.kernels.Sum(gp.kernels.Matern32(1), - gp.kernels.WhiteNoise(1, variance=torch.tensor(0.01))) + kernel = gp.kernels.Sum( + gp.kernels.Matern32(1), gp.kernels.WhiteNoise(1, variance=torch.tensor(0.01)) + ) likelihood = gp.likelihoods.MultiClass(num_classes=3) Xu = X[::5].clone() - gpmodule = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood, - latent_shape=torch.Size([3]), - whiten=whiten) + gpmodule = gp.models.VariationalSparseGP( + X, y, kernel, Xu, likelihood, latent_shape=torch.Size([3]), whiten=whiten + ) gpmodule.Xu.requires_grad_(False) gpmodule.kernel.kern1.variance_unconstrained.requires_grad_(False) @@ -125,7 +143,7 @@ def vsgp_multiclass(num_steps, whiten): gp.util.train(gpmodule, optimizer, num_steps=num_steps) -@pytest.mark.parametrize('model, model_args, id', TEST_MODELS, ids=MODEL_IDS) +@pytest.mark.parametrize("model, model_args, id", TEST_MODELS, ids=MODEL_IDS) @pytest.mark.benchmark( min_rounds=5, disable_gc=True, @@ -139,6 +157,7 @@ def test_benchmark(benchmark, model, model_args, id): def profile_fn(test_model): def wrapped(): test_model.model(**test_model.model_args) + return wrapped @@ -147,12 +166,24 @@ def wrapped(): This script is invoked to run cProfile on one of the models specified above. """ parser = argparse.ArgumentParser(description="Profiling different Pyro models.") - parser.add_argument("-m", "--models", nargs="*", - help="model name to match against model id, partial match (e.g. *NAME*) is acceptable.") - parser.add_argument("-b", "--suffix", default="current_branch", - help="suffix to append to the cprofile output dump.") - parser.add_argument("-d", "--benchmark_dir", default=PROF_DIR, - help="directory to save profiling benchmarks.") + parser.add_argument( + "-m", + "--models", + nargs="*", + help="model name to match against model id, partial match (e.g. 
*NAME*) is acceptable.", + ) + parser.add_argument( + "-b", + "--suffix", + default="current_branch", + help="suffix to append to the cprofile output dump.", + ) + parser.add_argument( + "-d", + "--benchmark_dir", + default=PROF_DIR, + help="directory to save profiling benchmarks.", + ) args = parser.parse_args() search_regexp = [re.compile(".*" + m + ".*") for m in args.models] profile_ids = [] @@ -168,6 +199,8 @@ def wrapped(): pr = cProfile.Profile() fn = profile_fn(test_model) pr.runctx("fn()", globals(), locals()) - profile_file = os.path.join(args.benchmark_dir, test_model.model_id + "#" + args.suffix + ".prof") + profile_file = os.path.join( + args.benchmark_dir, test_model.model_id + "#" + args.suffix + ".prof" + ) pr.dump_stats(profile_file) print("Results in - {}".format(profile_file)) diff --git a/tests/poutine/test_counterfactual.py b/tests/poutine/test_counterfactual.py index 286c771eba..81db18bedc 100644 --- a/tests/poutine/test_counterfactual.py +++ b/tests/poutine/test_counterfactual.py @@ -16,18 +16,21 @@ def _item(x): return x -@pytest.mark.parametrize('intervene,observe,flip', [ - (True, False, False), - (False, True, False), - (True, True, False), - (True, True, True), -]) +@pytest.mark.parametrize( + "intervene,observe,flip", + [ + (True, False, False), + (False, True, False), + (True, True, False), + (True, True, True), + ], +) def test_counterfactual_query(intervene, observe, flip): # x -> y -> z -> w sites = ["x", "y", "z", "w"] - observations = {"x": 1., "y": None, "z": 1., "w": 1.} - interventions = {"x": None, "y": 0., "z": 2., "w": 1.} + observations = {"x": 1.0, "y": None, "z": 1.0, "w": 1.0} + interventions = {"x": None, "y": 0.0, "z": 2.0, "w": 1.0} def model(): x = _item(pyro.sample("x", dist.Normal(0, 1))) @@ -43,8 +46,8 @@ def model(): model = poutine.condition(model, data=observations) elif flip and intervene and observe: model = poutine.do( - poutine.condition(model, data=observations), - data=interventions) + poutine.condition(model, data=observations), data=interventions + ) tr = poutine.trace(model).get_trace() actual_values = tr.nodes["_RETURN"]["value"] @@ -52,43 +55,44 @@ def model(): # case 1: purely observational query like poutine.condition if not intervene and observe: if observations[name] is not None: - assert tr.nodes[name]['is_observed'] + assert tr.nodes[name]["is_observed"] assert_equal(observations[name], actual_values[name]) - assert_equal(observations[name], tr.nodes[name]['value']) + assert_equal(observations[name], tr.nodes[name]["value"]) if interventions[name] != observations[name]: assert_not_equal(interventions[name], actual_values[name]) # case 2: purely interventional query like old poutine.do elif intervene and not observe: - assert not tr.nodes[name]['is_observed'] + assert not tr.nodes[name]["is_observed"] if interventions[name] is not None: assert_equal(interventions[name], actual_values[name]) - assert_not_equal(observations[name], tr.nodes[name]['value']) - assert_not_equal(interventions[name], tr.nodes[name]['value']) + assert_not_equal(observations[name], tr.nodes[name]["value"]) + assert_not_equal(interventions[name], tr.nodes[name]["value"]) # case 3: counterfactual query mixing intervention and observation elif intervene and observe: if observations[name] is not None: - assert tr.nodes[name]['is_observed'] - assert_equal(observations[name], tr.nodes[name]['value']) + assert tr.nodes[name]["is_observed"] + assert_equal(observations[name], tr.nodes[name]["value"]) if interventions[name] is not None: 
assert_equal(interventions[name], actual_values[name]) if interventions[name] != observations[name]: - assert_not_equal(interventions[name], tr.nodes[name]['value']) + assert_not_equal(interventions[name], tr.nodes[name]["value"]) def test_plate_duplication_smoke(): - def model(N): with pyro.plate("x_plate", N): - z1 = pyro.sample("z1", dist.MultivariateNormal(torch.zeros(2), torch.eye(2))) - z2 = pyro.sample("z2", dist.MultivariateNormal(torch.zeros(2), torch.eye(2))) - return pyro.sample("x", dist.MultivariateNormal(z1+z2, torch.eye(2))) + z1 = pyro.sample( + "z1", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)) + ) + z2 = pyro.sample( + "z2", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)) + ) + return pyro.sample("x", dist.MultivariateNormal(z1 + z2, torch.eye(2))) - fix_z1 = torch.tensor([[-6.1258, -6.1524], - [-4.1513, -4.3080]]) + fix_z1 = torch.tensor([[-6.1258, -6.1524], [-4.1513, -4.3080]]) - obs_x = torch.tensor([[-6.1258, -6.1524], - [-4.1513, -4.3080]]) + obs_x = torch.tensor([[-6.1258, -6.1524], [-4.1513, -4.3080]]) do_model = poutine.do(model, data={"z1": fix_z1}) do_model = poutine.condition(do_model, data={"x": obs_x}) diff --git a/tests/poutine/test_mapdata.py b/tests/poutine/test_mapdata.py index ce38c023c8..8d7543e2d4 100644 --- a/tests/poutine/test_mapdata.py +++ b/tests/poutine/test_mapdata.py @@ -23,8 +23,13 @@ def test_nested_iplate(): def model(means, stds): a_plate = pyro.plate("a", len(means), mean_batch_size) b_plate = pyro.plate("b", len(stds), std_batch_size) - return [[pyro.sample("x_{}{}".format(i, j), dist.Normal(means[i], stds[j])) - for j in b_plate] for i in a_plate] + return [ + [ + pyro.sample("x_{}{}".format(i, j), dist.Normal(means[i], stds[j])) + for j in b_plate + ] + for i in a_plate + ] xs = model(means, stds) assert len(xs) == mean_batch_size @@ -39,7 +44,7 @@ def model(means, stds): def plate_model(subsample_size): loc = torch.zeros(20) scale = torch.ones(20) - with pyro.plate('plate', 20, subsample_size) as batch: + with pyro.plate("plate", 20, subsample_size) as batch: pyro.sample("x", dist.Normal(loc[batch], scale[batch])) result = list(batch.data) return result @@ -49,7 +54,7 @@ def iplate_model(subsample_size): loc = torch.zeros(20) scale = torch.ones(20) result = [] - for i in pyro.plate('plate', 20, subsample_size): + for i in pyro.plate("plate", 20, subsample_size): pyro.sample("x_{}".format(i), dist.Normal(loc[i], scale[i])) result.append(i) return result @@ -63,24 +68,35 @@ def nested_iplate_model(subsample_size): for i in pyro.plate("outer", 20, subsample_size): result.append([]) for j in inner_iplate: - pyro.sample("x_{}_{}".format(i, j), dist.Normal(loc[i] + loc[j], scale[i] + scale[j])) + pyro.sample( + "x_{}_{}".format(i, j), + dist.Normal(loc[i] + loc[j], scale[i] + scale[j]), + ) result[-1].append(j) return result -@pytest.mark.parametrize('subsample_size', [5, 20]) -@pytest.mark.parametrize('model', [plate_model, iplate_model, nested_iplate_model], - ids=['plate', 'iplate', 'nested_iplate']) +@pytest.mark.parametrize("subsample_size", [5, 20]) +@pytest.mark.parametrize( + "model", + [plate_model, iplate_model, nested_iplate_model], + ids=["plate", "iplate", "nested_iplate"], +) def test_cond_indep_stack(model, subsample_size): tr = poutine.trace(model).get_trace(subsample_size) for name, node in tr.nodes.items(): if name.startswith("x"): - assert node["cond_indep_stack"], "missing cond_indep_stack at node {}".format(name) + assert node[ + "cond_indep_stack" + ], "missing cond_indep_stack at node {}".format(name) 
-@pytest.mark.parametrize('subsample_size', [5, 20]) -@pytest.mark.parametrize('model', [plate_model, iplate_model, nested_iplate_model], - ids=['plate', 'iplate', 'nested_iplate']) +@pytest.mark.parametrize("subsample_size", [5, 20]) +@pytest.mark.parametrize( + "model", + [plate_model, iplate_model, nested_iplate_model], + ids=["plate", "iplate", "nested_iplate"], +) def test_replay(model, subsample_size): pyro.set_rng_seed(0) @@ -96,20 +112,21 @@ def test_replay(model, subsample_size): def plate_custom_model(subsample): - with pyro.plate('plate', 20, subsample=subsample) as batch: + with pyro.plate("plate", 20, subsample=subsample) as batch: result = batch return result def iplate_custom_model(subsample): result = [] - for i in pyro.plate('plate', 20, subsample=subsample): + for i in pyro.plate("plate", 20, subsample=subsample): result.append(i) return result -@pytest.mark.parametrize('model', [plate_custom_model, iplate_custom_model], - ids=['plate', 'iplate']) +@pytest.mark.parametrize( + "model", [plate_custom_model, iplate_custom_model], ids=["plate", "iplate"] +) def test_custom_subsample(model): pyro.set_rng_seed(0) @@ -133,25 +150,30 @@ def iplate_cuda_model(subsample_size): @requires_cuda -@pytest.mark.parametrize('subsample_size', [5, 20]) -@pytest.mark.parametrize('model', [plate_cuda_model, iplate_cuda_model], ids=["plate", "iplate"]) +@pytest.mark.parametrize("subsample_size", [5, 20]) +@pytest.mark.parametrize( + "model", [plate_cuda_model, iplate_cuda_model], ids=["plate", "iplate"] +) def test_cuda(model, subsample_size): tr = poutine.trace(model).get_trace(subsample_size) assert tr.log_prob_sum().is_cuda -@pytest.mark.parametrize('model', [plate_model, iplate_model], ids=['plate', 'iplate']) -@pytest.mark.parametrize("behavior,model_size,guide_size", [ - ("error", 20, 5), - ("error", 5, 20), - ("error", 5, None), - ("ok", 20, 20), - ("ok", 20, None), - ("ok", 5, 5), - ("ok", None, 20), - ("ok", None, 5), - ("ok", None, None), -]) +@pytest.mark.parametrize("model", [plate_model, iplate_model], ids=["plate", "iplate"]) +@pytest.mark.parametrize( + "behavior,model_size,guide_size", + [ + ("error", 20, 5), + ("error", 5, 20), + ("error", 5, None), + ("ok", 20, 20), + ("ok", 20, None), + ("ok", 5, 5), + ("ok", None, 20), + ("ok", None, 5), + ("ok", None, None), + ], +) def test_model_guide_mismatch(behavior, model_size, guide_size, model): model = poutine.trace(model) expected_ind = model(guide_size) diff --git a/tests/poutine/test_nesting.py b/tests/poutine/test_nesting.py index ede0456c32..8d04a05f4c 100644 --- a/tests/poutine/test_nesting.py +++ b/tests/poutine/test_nesting.py @@ -12,7 +12,6 @@ def test_nested_reset(): - def nested_model(): pyro.sample("internal0", dist.Bernoulli(0.5)) with poutine.escape(escape_fn=lambda msg: msg["name"] == "internal2"): diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 99fbfb6336..55745ed4fd 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -25,22 +25,17 @@ def eq(x, y, prec=1e-10): - return (torch.norm(x - y).item() < prec) + return torch.norm(x - y).item() < prec # XXX name is a bit silly class NormalNormalNormalHandlerTestCase(TestCase): - def setUp(self): pyro.clear_param_store() def model(): - latent1 = pyro.sample("latent1", - Normal(torch.zeros(2), - torch.ones(2))) - latent2 = pyro.sample("latent2", - Normal(latent1, - 5 * torch.ones(2))) + latent1 = pyro.sample("latent1", Normal(torch.zeros(2), torch.ones(2))) + latent2 = pyro.sample("latent2", Normal(latent1, 5 
* torch.ones(2))) x_dist = Normal(latent2, torch.ones(2)) pyro.sample("obs", x_dist, obs=torch.ones(2)) return latent1 @@ -58,21 +53,24 @@ def guide(): self.model = model self.guide = guide - self.model_sites = ["latent1", "latent2", - "obs", - "_INPUT", "_RETURN"] - - self.guide_sites = ["latent1", "latent2", - "loc1", "scale1", - "loc2", "scale2", - "_INPUT", "_RETURN"] + self.model_sites = ["latent1", "latent2", "obs", "_INPUT", "_RETURN"] + + self.guide_sites = [ + "latent1", + "latent2", + "loc1", + "scale1", + "loc2", + "scale2", + "_INPUT", + "_RETURN", + ] self.full_sample_sites = {"latent1": "latent1", "latent2": "latent2"} self.partial_sample_sites = {"latent1": "latent1"} class TraceHandlerTests(NormalNormalNormalHandlerTestCase): - def test_trace_full(self): guide_trace = poutine.trace(self.guide).get_trace() model_trace = poutine.trace(self.model).get_trace() @@ -81,15 +79,20 @@ def test_trace_full(self): for name in guide_trace.nodes.keys(): assert name in self.guide_sites - assert guide_trace.nodes[name]["type"] in \ - ("args", "return", "sample", "param") + assert guide_trace.nodes[name]["type"] in ( + "args", + "return", + "sample", + "param", + ) if guide_trace.nodes[name]["type"] == "sample": assert not guide_trace.nodes[name]["is_observed"] def test_trace_return(self): model_trace = poutine.trace(self.model).get_trace() - assert_equal(model_trace.nodes["latent1"]["value"], - model_trace.nodes["_RETURN"]["value"]) + assert_equal( + model_trace.nodes["latent1"]["value"], model_trace.nodes["_RETURN"]["value"] + ) def test_trace_param_only(self): model_trace = poutine.trace(self.model, param_only=True).get_trace() @@ -97,13 +100,15 @@ def test_trace_param_only(self): class ReplayHandlerTests(NormalNormalNormalHandlerTestCase): - def test_replay_full(self): guide_trace = poutine.trace(self.guide).get_trace() - model_trace = poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace() + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace() for name in self.full_sample_sites.keys(): - assert_equal(model_trace.nodes[name]["value"], - guide_trace.nodes[name]["value"]) + assert_equal( + model_trace.nodes[name]["value"], guide_trace.nodes[name]["value"] + ) def test_replay_full_repeat(self): model_trace = poutine.trace(self.model).get_trace() @@ -119,12 +124,13 @@ def test_replay_full_repeat(self): class BlockHandlerTests(NormalNormalNormalHandlerTestCase): - def test_block_hide_fn(self): model_trace = poutine.trace( - poutine.block(self.model, - hide_fn=lambda msg: "latent" in msg["name"], - expose=["latent1"]) + poutine.block( + self.model, + hide_fn=lambda msg: "latent" in msg["name"], + expose=["latent1"], + ) ).get_trace() assert "latent1" not in model_trace assert "latent2" not in model_trace @@ -132,9 +138,11 @@ def test_block_hide_fn(self): def test_block_expose_fn(self): model_trace = poutine.trace( - poutine.block(self.model, - expose_fn=lambda msg: "latent" in msg["name"], - hide=["latent1"]) + poutine.block( + self.model, + expose_fn=lambda msg: "latent" in msg["name"], + hide=["latent1"], + ) ).get_trace() assert "latent1" in model_trace assert "latent2" in model_trace @@ -149,20 +157,24 @@ def test_block_full(self): assert guide_trace.nodes[name]["type"] in ("args", "return") def test_block_full_hide(self): - model_trace = poutine.trace(poutine.block(self.model, - hide=self.model_sites)).get_trace() - guide_trace = poutine.trace(poutine.block(self.guide, - hide=self.guide_sites)).get_trace() + model_trace = 
poutine.trace( + poutine.block(self.model, hide=self.model_sites) + ).get_trace() + guide_trace = poutine.trace( + poutine.block(self.guide, hide=self.guide_sites) + ).get_trace() for name in model_trace.nodes.keys(): assert model_trace.nodes[name]["type"] in ("args", "return") for name in guide_trace.nodes.keys(): assert guide_trace.nodes[name]["type"] in ("args", "return") def test_block_full_expose(self): - model_trace = poutine.trace(poutine.block(self.model, - expose=self.model_sites)).get_trace() - guide_trace = poutine.trace(poutine.block(self.guide, - expose=self.guide_sites)).get_trace() + model_trace = poutine.trace( + poutine.block(self.model, expose=self.model_sites) + ).get_trace() + guide_trace = poutine.trace( + poutine.block(self.guide, expose=self.guide_sites) + ).get_trace() for name in self.model_sites: assert name in model_trace for name in self.guide_sites: @@ -170,18 +182,22 @@ def test_block_full_expose(self): def test_block_full_hide_expose(self): try: - poutine.block(self.model, - hide=self.partial_sample_sites.keys(), - expose=self.partial_sample_sites.keys())() + poutine.block( + self.model, + hide=self.partial_sample_sites.keys(), + expose=self.partial_sample_sites.keys(), + )() assert False except AssertionError: assert True def test_block_partial_hide(self): model_trace = poutine.trace( - poutine.block(self.model, hide=self.partial_sample_sites.keys())).get_trace() + poutine.block(self.model, hide=self.partial_sample_sites.keys()) + ).get_trace() guide_trace = poutine.trace( - poutine.block(self.guide, hide=self.partial_sample_sites.keys())).get_trace() + poutine.block(self.guide, hide=self.partial_sample_sites.keys()) + ).get_trace() for name in self.full_sample_sites.keys(): if name in self.partial_sample_sites: assert name not in model_trace @@ -192,9 +208,11 @@ def test_block_partial_hide(self): def test_block_partial_expose(self): model_trace = poutine.trace( - poutine.block(self.model, expose=self.partial_sample_sites.keys())).get_trace() + poutine.block(self.model, expose=self.partial_sample_sites.keys()) + ).get_trace() guide_trace = poutine.trace( - poutine.block(self.guide, expose=self.partial_sample_sites.keys())).get_trace() + poutine.block(self.guide, expose=self.partial_sample_sites.keys()) + ).get_trace() for name in self.full_sample_sites.keys(): if name in self.partial_sample_sites: assert name in model_trace @@ -206,7 +224,8 @@ def test_block_partial_expose(self): def test_block_tutorial_case(self): model_trace = poutine.trace(self.model).get_trace() guide_trace = poutine.trace( - poutine.block(self.guide, hide_types=["observe"])).get_trace() + poutine.block(self.guide, hide_types=["observe"]) + ).get_trace() assert "latent1" in model_trace assert "latent1" in guide_trace @@ -215,7 +234,6 @@ def test_block_tutorial_case(self): class QueueHandlerDiscreteTest(TestCase): - def setUp(self): # simple Gaussian-mixture HMM @@ -229,18 +247,26 @@ def model(): for t in range(3): latents.append( - pyro.sample("latent_{}".format(str(t)), - Bernoulli(probs[latents[-1][0].long().data]))) + pyro.sample( + "latent_{}".format(str(t)), + Bernoulli(probs[latents[-1][0].long().data]), + ) + ) observes.append( - pyro.sample("observe_{}".format(str(t)), - Normal(loc[latents[-1][0].long().data], scale), - obs=torch.ones(1))) + pyro.sample( + "observe_{}".format(str(t)), + Normal(loc[latents[-1][0].long().data], scale), + obs=torch.ones(1), + ) + ) return latents - self.sites = ["observe_{}".format(str(t)) for t in range(3)] + \ - ["latent_{}".format(str(t)) for t 
in range(3)] + \ - ["_INPUT", "_RETURN"] + self.sites = ( + ["observe_{}".format(str(t)) for t in range(3)] + + ["latent_{}".format(str(t)) for t in range(3)] + + ["_INPUT", "_RETURN"] + ) self.model = model self.queue = Queue() self.queue.put(poutine.Trace()) @@ -266,9 +292,16 @@ def test_queue_enumerate(self): tr_latents = [] for tr in trs: - tr_latents.append(tuple([int(tr.nodes[name]["value"].view(-1).item()) for name in tr - if tr.nodes[name]["type"] == "sample" and - not tr.nodes[name]["is_observed"]])) + tr_latents.append( + tuple( + [ + int(tr.nodes[name]["value"].view(-1).item()) + for name in tr + if tr.nodes[name]["type"] == "sample" + and not tr.nodes[name]["is_observed"] + ] + ) + ) assert true_latents == set(tr_latents) @@ -288,7 +321,6 @@ def forward(self, x): class LiftHandlerTests(TestCase): - def setUp(self): pyro.clear_param_store() @@ -342,7 +374,12 @@ def dup_param_guide(): self.guide = guide self.dup_param_guide = dup_param_guide self.prior = scale1_prior - self.prior_dict = {"loc1": loc1_prior, "scale1": scale1_prior, "loc2": loc2_prior, "scale2": scale2_prior} + self.prior_dict = { + "loc1": loc1_prior, + "scale1": scale1_prior, + "loc2": loc2_prior, + "scale2": scale2_prior, + } self.partial_dict = {"loc1": loc1_prior, "scale1": scale1_prior} self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior} self.fn = stoch_fn @@ -350,9 +387,11 @@ def dup_param_guide(): def test_splice(self): tr = poutine.trace(self.guide).get_trace() - lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior)).get_trace() + lifted_tr = poutine.trace( + poutine.lift(self.guide, prior=self.prior) + ).get_trace() for name in tr.nodes.keys(): - if name in ('loc1', 'loc2', 'scale1', 'scale2'): + if name in ("loc1", "loc2", "scale1", "scale2"): assert name not in lifted_tr else: assert name in lifted_tr @@ -362,65 +401,72 @@ def test_memoize(self): def test_prior_dict(self): tr = poutine.trace(self.guide).get_trace() - lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior_dict)).get_trace() + lifted_tr = poutine.trace( + poutine.lift(self.guide, prior=self.prior_dict) + ).get_trace() for name in tr.nodes.keys(): assert name in lifted_tr - if name in {'scale1', 'loc1', 'scale2', 'loc2'}: - assert name + "_prior" == lifted_tr.nodes[name]['fn'].__name__ + if name in {"scale1", "loc1", "scale2", "loc2"}: + assert name + "_prior" == lifted_tr.nodes[name]["fn"].__name__ if tr.nodes[name]["type"] == "param": assert lifted_tr.nodes[name]["type"] == "sample" assert not lifted_tr.nodes[name]["is_observed"] def test_unlifted_param(self): tr = poutine.trace(self.guide).get_trace() - lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.partial_dict)).get_trace() + lifted_tr = poutine.trace( + poutine.lift(self.guide, prior=self.partial_dict) + ).get_trace() for name in tr.nodes.keys(): assert name in lifted_tr - if name in ('scale1', 'loc1'): - assert name + "_prior" == lifted_tr.nodes[name]['fn'].__name__ + if name in ("scale1", "loc1"): + assert name + "_prior" == lifted_tr.nodes[name]["fn"].__name__ assert lifted_tr.nodes[name]["type"] == "sample" assert not lifted_tr.nodes[name]["is_observed"] - if name in ('scale2', 'loc2'): + if name in ("scale2", "loc2"): assert lifted_tr.nodes[name]["type"] == "param" - @pytest.mark.filterwarnings('ignore::FutureWarning') + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_random_module(self): pyro.clear_param_store() with pyro.validation_enabled(): - lifted_tr = poutine.trace(pyro.random_module("name", 
self.model, prior=self.prior)).get_trace() + lifted_tr = poutine.trace( + pyro.random_module("name", self.model, prior=self.prior) + ).get_trace() for name in lifted_tr.nodes.keys(): if lifted_tr.nodes[name]["type"] == "param": assert lifted_tr.nodes[name]["type"] == "sample" assert not lifted_tr.nodes[name]["is_observed"] - @pytest.mark.filterwarnings('ignore::FutureWarning') + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_random_module_warn(self): pyro.clear_param_store() - bad_prior = {'foo': None} + bad_prior = {"foo": None} with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") with pyro.validation_enabled(): - poutine.trace(pyro.random_module("name", self.model, prior=bad_prior)).get_trace() - assert len(w), 'No warnings were raised' + poutine.trace( + pyro.random_module("name", self.model, prior=bad_prior) + ).get_trace() + assert len(w), "No warnings were raised" for warning in w: logger.info(warning) - @pytest.mark.filterwarnings('ignore::FutureWarning') + @pytest.mark.filterwarnings("ignore::FutureWarning") def test_random_module_prior_dict(self): pyro.clear_param_store() lifted_nn = pyro.random_module("name", self.model, prior=self.nn_prior) lifted_tr = poutine.trace(lifted_nn).get_trace() for key_name in lifted_tr.nodes.keys(): name = pyro.params.user_param_name(key_name) - if name in {'fc.weight', 'fc.prior'}: + if name in {"fc.weight", "fc.prior"}: dist_name = name[3:] - assert dist_name + "_prior" == lifted_tr.nodes[key_name]['fn'].__name__ + assert dist_name + "_prior" == lifted_tr.nodes[key_name]["fn"].__name__ assert lifted_tr.nodes[key_name]["type"] == "sample" assert not lifted_tr.nodes[key_name]["is_observed"] class QueueHandlerMixedTest(TestCase): - def setUp(self): # Simple model with 1 continuous + 1 discrete + 1 continuous variable. @@ -453,8 +499,11 @@ def test_queue_enumerate(self): assert len(trs) == 2 values = [ - {name: tr.nodes[name]['value'].view(-1).item() for name in tr.nodes.keys() - if tr.nodes[name]['type'] == 'sample'} + { + name: tr.nodes[name]["value"].view(-1).item() + for name in tr.nodes.keys() + if tr.nodes[name]["type"] == "sample" + } for tr in trs ] @@ -470,41 +519,54 @@ def test_queue_enumerate(self): class IndirectLambdaHandlerTests(TestCase): - def setUp(self): - def model(batch_size_outer=2, batch_size_inner=2): data = [[torch.ones(1)] * 2] * 2 - loc_latent = pyro.sample("loc_latent", dist.Normal(torch.zeros(1), torch.ones(1))) + loc_latent = pyro.sample( + "loc_latent", dist.Normal(torch.zeros(1), torch.ones(1)) + ) for i in pyro.plate("plate_outer", 2, batch_size_outer): for j in pyro.plate("plate_inner_%d" % i, 2, batch_size_inner): - pyro.sample("z_%d_%d" % (i, j), dist.Normal(loc_latent + data[i][j], torch.ones(1))) + pyro.sample( + "z_%d_%d" % (i, j), + dist.Normal(loc_latent + data[i][j], torch.ones(1)), + ) self.model = model - self.expected_nodes = set(["z_0_0", "z_0_1", "z_1_0", "z_1_1", "loc_latent", - "_INPUT", "_RETURN"]) - self.expected_edges = set([ - ("loc_latent", "z_0_0"), ("loc_latent", "z_0_1"), - ("loc_latent", "z_1_0"), ("loc_latent", "z_1_1"), - ]) + self.expected_nodes = set( + ["z_0_0", "z_0_1", "z_1_0", "z_1_1", "loc_latent", "_INPUT", "_RETURN"] + ) + self.expected_edges = set( + [ + ("loc_latent", "z_0_0"), + ("loc_latent", "z_0_1"), + ("loc_latent", "z_1_0"), + ("loc_latent", "z_1_1"), + ] + ) def test_graph_structure(self): tracegraph = poutine.trace(self.model, graph_type="dense").get_trace() # Ignore structure on plate_* nodes. 
actual_nodes = set(n for n in tracegraph.nodes if not n.startswith("plate_")) - actual_edges = set((n1, n2) for n1, n2 in tracegraph.edges - if not n1.startswith("plate_") if not n2.startswith("plate_")) + actual_edges = set( + (n1, n2) + for n1, n2 in tracegraph.edges + if not n1.startswith("plate_") + if not n2.startswith("plate_") + ) assert actual_nodes == self.expected_nodes assert actual_edges == self.expected_edges def test_scale_factors(self): def _test_scale_factor(batch_size_outer, batch_size_inner, expected): - trace = poutine.trace(self.model, graph_type="dense").get_trace(batch_size_outer=batch_size_outer, - batch_size_inner=batch_size_inner) + trace = poutine.trace(self.model, graph_type="dense").get_trace( + batch_size_outer=batch_size_outer, batch_size_inner=batch_size_inner + ) scale_factors = [] - for node in ['z_0_0', 'z_0_1', 'z_1_0', 'z_1_1']: + for node in ["z_0_0", "z_0_1", "z_1_0", "z_1_1"]: if node in trace: - scale_factors.append(trace.nodes[node]['scale']) + scale_factors.append(trace.nodes[node]["scale"]) assert scale_factors == expected _test_scale_factor(1, 1, [4.0]) @@ -514,49 +576,56 @@ def _test_scale_factor(batch_size_outer, batch_size_inner, expected): class ConditionHandlerTests(NormalNormalNormalHandlerTestCase): - def test_condition(self): data = {"latent2": torch.randn(2)} tr2 = poutine.trace(poutine.condition(self.model, data=data)).get_trace() assert "latent2" in tr2 - assert tr2.nodes["latent2"]["type"] == "sample" and \ - tr2.nodes["latent2"]["is_observed"] + assert ( + tr2.nodes["latent2"]["type"] == "sample" + and tr2.nodes["latent2"]["is_observed"] + ) assert tr2.nodes["latent2"]["value"] is data["latent2"] def test_trace_data(self): tr1 = poutine.trace( - poutine.block(self.model, expose_types=["sample"])).get_trace() - tr2 = poutine.trace( - poutine.condition(self.model, data=tr1)).get_trace() - assert tr2.nodes["latent2"]["type"] == "sample" and \ - tr2.nodes["latent2"]["is_observed"] + poutine.block(self.model, expose_types=["sample"]) + ).get_trace() + tr2 = poutine.trace(poutine.condition(self.model, data=tr1)).get_trace() + assert ( + tr2.nodes["latent2"]["type"] == "sample" + and tr2.nodes["latent2"]["is_observed"] + ) assert tr2.nodes["latent2"]["value"] is tr1.nodes["latent2"]["value"] def test_stack_overwrite_behavior(self): data1 = {"latent2": torch.randn(2)} data2 = {"latent2": torch.randn(2)} with poutine.trace() as tr: - cm = poutine.condition(poutine.condition(self.model, data=data1), - data=data2) + cm = poutine.condition( + poutine.condition(self.model, data=data1), data=data2 + ) cm() - assert tr.trace.nodes['latent2']['value'] is data2['latent2'] + assert tr.trace.nodes["latent2"]["value"] is data2["latent2"] def test_stack_success(self): data1 = {"latent1": torch.randn(2)} data2 = {"latent2": torch.randn(2)} tr = poutine.trace( - poutine.condition(poutine.condition(self.model, data=data1), - data=data2)).get_trace() - assert tr.nodes["latent1"]["type"] == "sample" and \ - tr.nodes["latent1"]["is_observed"] + poutine.condition(poutine.condition(self.model, data=data1), data=data2) + ).get_trace() + assert ( + tr.nodes["latent1"]["type"] == "sample" + and tr.nodes["latent1"]["is_observed"] + ) assert tr.nodes["latent1"]["value"] is data1["latent1"] - assert tr.nodes["latent2"]["type"] == "sample" and \ - tr.nodes["latent2"]["is_observed"] + assert ( + tr.nodes["latent2"]["type"] == "sample" + and tr.nodes["latent2"]["is_observed"] + ) assert tr.nodes["latent2"]["value"] is data2["latent2"] class 
UnconditionHandlerTests(NormalNormalNormalHandlerTestCase): - def test_uncondition(self): unconditioned_model = poutine.uncondition(self.model) unconditioned_trace = poutine.trace(unconditioned_model).get_trace() @@ -566,13 +635,14 @@ def test_uncondition(self): def test_undo_uncondition(self): unconditioned_model = poutine.uncondition(self.model) - reconditioned_model = pyro.condition(unconditioned_model, {"obs": torch.ones(2)}) + reconditioned_model = pyro.condition( + unconditioned_model, {"obs": torch.ones(2)} + ) reconditioned_trace = poutine.trace(reconditioned_model).get_trace() assert_equal(reconditioned_trace.nodes["obs"]["value"], torch.ones(2)) class EscapeHandlerTests(TestCase): - def setUp(self): # Simple model with 1 continuous + 1 discrete + 1 continuous variable. @@ -591,18 +661,19 @@ def model(): def test_discrete_escape(self): try: - poutine.escape(self.model, - escape_fn=functools.partial(discrete_escape, - poutine.Trace()))() + poutine.escape( + self.model, + escape_fn=functools.partial(discrete_escape, poutine.Trace()), + )() assert False except NonlocalExit as e: assert e.site["name"] == "y" def test_all_escape(self): try: - poutine.escape(self.model, - escape_fn=functools.partial(all_escape, - poutine.Trace()))() + poutine.escape( + self.model, escape_fn=functools.partial(all_escape, poutine.Trace()) + )() assert False except NonlocalExit as e: assert e.site["name"] == "x" @@ -610,17 +681,19 @@ def test_all_escape(self): def test_trace_compose(self): tm = poutine.trace(self.model) try: - poutine.escape(tm, - escape_fn=functools.partial(all_escape, - poutine.Trace()))() + poutine.escape( + tm, escape_fn=functools.partial(all_escape, poutine.Trace()) + )() assert False except NonlocalExit: assert "x" in tm.trace try: tem = poutine.trace( - poutine.escape(self.model, - escape_fn=functools.partial(all_escape, - poutine.Trace()))) + poutine.escape( + self.model, + escape_fn=functools.partial(all_escape, poutine.Trace()), + ) + ) tem() assert False except NonlocalExit: @@ -631,8 +704,9 @@ class InferConfigHandlerTests(TestCase): def setUp(self): def model(): pyro.param("p", torch.zeros(1, requires_grad=True)) - pyro.sample("a", Bernoulli(torch.tensor([0.5])), - infer={"enumerate": "parallel"}) + pyro.sample( + "a", Bernoulli(torch.tensor([0.5])), infer={"enumerate": "parallel"} + ) pyro.sample("b", Bernoulli(torch.tensor([0.5]))) self.model = model @@ -655,15 +729,17 @@ def test_infer_config_sample(self): assert tr.nodes["p"]["infer"] == {} -@pytest.mark.parametrize('first_available_dim', [-1, -2, -3]) -@pytest.mark.parametrize('depth', [0, 1, 2]) +@pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) +@pytest.mark.parametrize("depth", [0, 1, 2]) def test_enumerate_poutine(depth, first_available_dim): num_particles = 2 def model(): pyro.sample("x", Bernoulli(0.5)) for i in range(depth): - pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"}) + pyro.sample( + "a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"} + ) model = poutine.enum(model, first_available_dim=first_available_dim) model = poutine.trace(model) @@ -676,11 +752,11 @@ def model(): expected_shape = (2,) * depth if depth: expected_shape = expected_shape + (1,) * (-1 - first_available_dim) - assert actual_shape == expected_shape, 'error on iteration {}'.format(i) + assert actual_shape == expected_shape, "error on iteration {}".format(i) -@pytest.mark.parametrize('first_available_dim', [-1, -2, -3]) -@pytest.mark.parametrize('depth', [0, 1, 2]) 
+@pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) +@pytest.mark.parametrize("depth", [0, 1, 2]) def test_replay_enumerate_poutine(depth, first_available_dim): num_particles = 2 y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25])) @@ -695,10 +771,14 @@ def guide(): def model(): pyro.sample("x", Bernoulli(0.5)) for i in range(depth): - pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"}) + pyro.sample( + "a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"} + ) pyro.sample("y", y_dist, infer={"enumerate": "parallel"}) for i in range(depth): - pyro.sample("b_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"}) + pyro.sample( + "b_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"} + ) model = poutine.enum(model, first_available_dim=first_available_dim) model = poutine.replay(model, trace=guide_trace) @@ -710,15 +790,17 @@ def model(): tr.compute_log_prob() log_prob = sum(site["log_prob"] for name, site in tr.iter_stochastic_nodes()) actual_shape = log_prob.shape - expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * (-1 - first_available_dim) - assert actual_shape == expected_shape, 'error on iteration {}'.format(i) + expected_shape = ( + (2,) * depth + (3,) + (2,) * depth + (1,) * (-1 - first_available_dim) + ) + assert actual_shape == expected_shape, "error on iteration {}".format(i) @pytest.mark.parametrize("has_rsample", [False, True]) @pytest.mark.parametrize("depth", [0, 1, 2]) def test_plate_preserves_has_rsample(has_rsample, depth): def guide(): - loc = pyro.param("loc", torch.tensor(0.)) + loc = pyro.param("loc", torch.tensor(0.0)) with pyro.plate_stack("plates", (2,) * depth): return pyro.sample("x", dist.Normal(loc, 1).has_rsample_(has_rsample)) @@ -729,22 +811,22 @@ def guide(): def test_plate_error_on_enter(): def model(): - with pyro.plate('foo', 0): + with pyro.plate("foo", 0): pass assert len(_DIM_ALLOCATOR._stack) == 0 with pytest.raises(ZeroDivisionError): poutine.trace(model)() - assert len(_DIM_ALLOCATOR._stack) == 0, 'stack was not cleaned on error' + assert len(_DIM_ALLOCATOR._stack) == 0, "stack was not cleaned on error" def test_decorator_interface_primitives(): - @poutine.trace def model(): pyro.param("p", torch.zeros(1, requires_grad=True)) - pyro.sample("a", Bernoulli(torch.tensor([0.5])), - infer={"enumerate": "parallel"}) + pyro.sample( + "a", Bernoulli(torch.tensor([0.5])), infer={"enumerate": "parallel"} + ) pyro.sample("b", Bernoulli(torch.tensor([0.5]))) tr = model.get_trace() @@ -754,8 +836,9 @@ def model(): @poutine.trace(graph_type="dense") def model(): pyro.param("p", torch.zeros(1, requires_grad=True)) - pyro.sample("a", Bernoulli(torch.tensor([0.5])), - infer={"enumerate": "parallel"}) + pyro.sample( + "a", Bernoulli(torch.tensor([0.5])), infer={"enumerate": "parallel"} + ) pyro.sample("b", Bernoulli(torch.tensor([0.5]))) tr = model.get_trace() @@ -790,10 +873,8 @@ def model(): def test_method_decorator_interface_condition(): - class cls_model: - - @poutine.condition(data={"b": torch.tensor(1.)}) + @poutine.condition(data={"b": torch.tensor(1.0)}) def model(self, p): self._model(p) @@ -804,53 +885,62 @@ def _model(self, p): tr = poutine.trace(cls_model().model).get_trace(0.5) assert isinstance(tr, poutine.Trace) assert tr.graph_type == "flat" - assert tr.nodes["b"]["is_observed"] and tr.nodes["b"]["value"].item() == 1. 
+ assert tr.nodes["b"]["is_observed"] and tr.nodes["b"]["value"].item() == 1.0 def test_trace_log_prob_err_msg(): def model(v): - pyro.sample("test_site", dist.Beta(1., 1.), obs=v) + pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) - tr = poutine.trace(model).get_trace(torch.tensor(2.)) - exp_msg = r"Error while computing log_prob at site 'test_site':\s*" \ - r"The value argument must be within the support" + tr = poutine.trace(model).get_trace(torch.tensor(2.0)) + exp_msg = ( + r"Error while computing log_prob at site 'test_site':\s*" + r"The value argument must be within the support" + ) with pytest.raises(ValueError, match=exp_msg): tr.compute_log_prob() def test_trace_log_prob_sum_err_msg(): def model(v): - pyro.sample("test_site", dist.Beta(1., 1.), obs=v) + pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) - tr = poutine.trace(model).get_trace(torch.tensor(2.)) - exp_msg = r"Error while computing log_prob_sum at site 'test_site':\s*" \ - r"The value argument must be within the support" + tr = poutine.trace(model).get_trace(torch.tensor(2.0)) + exp_msg = ( + r"Error while computing log_prob_sum at site 'test_site':\s*" + r"The value argument must be within the support" + ) with pytest.raises(ValueError, match=exp_msg): tr.log_prob_sum() def test_trace_score_parts_err_msg(): def guide(v): - pyro.sample("test_site", dist.Beta(1., 1.), obs=v) + pyro.sample("test_site", dist.Beta(1.0, 1.0), obs=v) - tr = poutine.trace(guide).get_trace(torch.tensor(2.)) - exp_msg = r"Error while computing score_parts at site 'test_site':\s*" \ - r"The value argument must be within the support" + tr = poutine.trace(guide).get_trace(torch.tensor(2.0)) + exp_msg = ( + r"Error while computing score_parts at site 'test_site':\s*" + r"The value argument must be within the support" + ) with pytest.raises(ValueError, match=exp_msg): tr.compute_score_parts() -def _model(a=torch.tensor(1.), b=torch.tensor(1.)): +def _model(a=torch.tensor(1.0), b=torch.tensor(1.0)): latent = pyro.sample("latent", dist.Beta(a, b)) return pyro.sample("test_site", dist.Bernoulli(latent), obs=torch.tensor(1)) -@pytest.mark.parametrize('wrapper', [ - lambda fn: poutine.block(fn), - lambda fn: poutine.condition(fn, {'latent': 0.9}), - lambda fn: poutine.enum(fn, -1), - lambda fn: poutine.replay(fn, poutine.trace(fn).get_trace()), -]) +@pytest.mark.parametrize( + "wrapper", + [ + lambda fn: poutine.block(fn), + lambda fn: poutine.condition(fn, {"latent": 0.9}), + lambda fn: poutine.enum(fn, -1), + lambda fn: poutine.replay(fn, poutine.trace(fn).get_trace()), + ], +) def test_pickling(wrapper): wrapped = wrapper(_model) buffer = io.BytesIO() @@ -864,16 +954,21 @@ def test_pickling(wrapper): pyro.set_rng_seed(0) expected_trace = poutine.trace(wrapped).get_trace(obs) assert tuple(actual_trace) == tuple(expected_trace.nodes) - assert_close([actual_trace.nodes[site]['value'] for site in actual_trace.stochastic_nodes], - [expected_trace.nodes[site]['value'] for site in expected_trace.stochastic_nodes]) + assert_close( + [actual_trace.nodes[site]["value"] for site in actual_trace.stochastic_nodes], + [ + expected_trace.nodes[site]["value"] + for site in expected_trace.stochastic_nodes + ], + ) def test_arg_kwarg_error(): - def model(): pyro.param("p", torch.zeros(1, requires_grad=True)) - pyro.sample("a", Bernoulli(torch.tensor([0.5])), - infer={"enumerate": "parallel"}) + pyro.sample( + "a", Bernoulli(torch.tensor([0.5])), infer={"enumerate": "parallel"} + ) pyro.sample("b", Bernoulli(torch.tensor([0.5]))) with pytest.raises(ValueError, 
match="not callable"): diff --git a/tests/poutine/test_properties.py b/tests/poutine/test_properties.py index ca0c354df8..7aa4d30316 100644 --- a/tests/poutine/test_properties.py +++ b/tests/poutine/test_properties.py @@ -45,57 +45,66 @@ def register_fn(fn): return register_fn -@register_model(replay={'trace': poutine.Trace()}, - block={}, - condition={'data': {}}, - do={'data': {}}) +@register_model( + replay={"trace": poutine.Trace()}, block={}, condition={"data": {}}, do={"data": {}} +) def trivial_model(): return [] tr_normal = poutine.Trace() -tr_normal.add_node("normal_0", type="sample", is_observed=False, value=torch.zeros(1), infer={}) +tr_normal.add_node( + "normal_0", type="sample", is_observed=False, value=torch.zeros(1), infer={} +) -@register_model(replay={'trace': tr_normal}, - block={'hide': ['normal_0']}, - condition={'data': {'normal_0': torch.zeros(1)}}, - do={'data': {'normal_0': torch.zeros(1)}}) +@register_model( + replay={"trace": tr_normal}, + block={"hide": ["normal_0"]}, + condition={"data": {"normal_0": torch.zeros(1)}}, + do={"data": {"normal_0": torch.zeros(1)}}, +) def normal_model(): - normal_0 = pyro.sample('normal_0', dist.Normal(torch.zeros(1), torch.ones(1))) + normal_0 = pyro.sample("normal_0", dist.Normal(torch.zeros(1), torch.ones(1))) return [normal_0] tr_normal_normal = poutine.Trace() -tr_normal_normal.add_node("normal_0", type="sample", is_observed=False, value=torch.zeros(1), infer={}) +tr_normal_normal.add_node( + "normal_0", type="sample", is_observed=False, value=torch.zeros(1), infer={} +) -@register_model(replay={'trace': tr_normal_normal}, - block={'hide': ['normal_0']}, - condition={'data': {'normal_0': torch.zeros(1)}}, - do={'data': {'normal_0': torch.zeros(1)}}) +@register_model( + replay={"trace": tr_normal_normal}, + block={"hide": ["normal_0"]}, + condition={"data": {"normal_0": torch.zeros(1)}}, + do={"data": {"normal_0": torch.zeros(1)}}, +) def normal_normal_model(): - normal_0 = pyro.sample('normal_0', dist.Normal(torch.zeros(1), torch.ones(1))) + normal_0 = pyro.sample("normal_0", dist.Normal(torch.zeros(1), torch.ones(1))) normal_1 = torch.ones(1) - pyro.sample('normal_1', dist.Normal(normal_0, torch.ones(1)), - obs=normal_1) + pyro.sample("normal_1", dist.Normal(normal_0, torch.ones(1)), obs=normal_1) return [normal_0, normal_1] tr_bernoulli_normal = poutine.Trace() -tr_bernoulli_normal.add_node("bern_0", type="sample", is_observed=False, value=torch.ones(1), infer={}) +tr_bernoulli_normal.add_node( + "bern_0", type="sample", is_observed=False, value=torch.ones(1), infer={} +) -@register_model(replay={'trace': tr_bernoulli_normal}, - block={'hide': ['bern_0']}, - condition={'data': {'bern_0': torch.ones(1)}}, - do={'data': {'bern_0': torch.ones(1)}}) +@register_model( + replay={"trace": tr_bernoulli_normal}, + block={"hide": ["bern_0"]}, + condition={"data": {"bern_0": torch.ones(1)}}, + do={"data": {"bern_0": torch.ones(1)}}, +) def bernoulli_normal_model(): - bern_0 = pyro.sample('bern_0', dist.Bernoulli(torch.zeros(1) * 1e-2)) + bern_0 = pyro.sample("bern_0", dist.Bernoulli(torch.zeros(1) * 1e-2)) loc = torch.ones(1) if bern_0.item() else -torch.ones(1) normal_0 = torch.ones(1) - pyro.sample('normal_0', dist.Normal(loc, torch.ones(1) * 1e-2), - obs=normal_0) + pyro.sample("normal_0", dist.Normal(loc, torch.ones(1) * 1e-2), obs=normal_0) return [bern_0, normal_0] @@ -104,12 +113,15 @@ def get_trace(fn, *args, **kwargs): return poutine.trace(fn).get_trace(*args, **kwargs) -@pytest.mark.parametrize('model', EXAMPLE_MODELS, 
ids=EXAMPLE_MODEL_IDS) -@pytest.mark.parametrize('poutine_name', [ - 'block', - 'replay', - 'trace', -]) +@pytest.mark.parametrize("model", EXAMPLE_MODELS, ids=EXAMPLE_MODEL_IDS) +@pytest.mark.parametrize( + "poutine_name", + [ + "block", + "replay", + "trace", + ], +) def test_idempotent(poutine_name, model): p = model.bind_poutine(poutine_name) expected_trace = get_trace(p(model)) @@ -117,12 +129,15 @@ def test_idempotent(poutine_name, model): assert_equal(actual_trace, expected_trace, prec=0) -@pytest.mark.parametrize('model', EXAMPLE_MODELS, ids=EXAMPLE_MODEL_IDS) -@pytest.mark.parametrize('p1_name,p2_name', [ - ('trace', 'condition'), - ('trace', 'do'), - ('trace', 'replay'), -]) +@pytest.mark.parametrize("model", EXAMPLE_MODELS, ids=EXAMPLE_MODEL_IDS) +@pytest.mark.parametrize( + "p1_name,p2_name", + [ + ("trace", "condition"), + ("trace", "do"), + ("trace", "replay"), + ], +) def test_commutes(p1_name, p2_name, model): p1 = model.bind_poutine(p1_name) p2 = model.bind_poutine(p2_name) diff --git a/tests/poutine/test_trace_struct.py b/tests/poutine/test_trace_struct.py index 4511ccbdf3..9bdd7f2ba6 100644 --- a/tests/poutine/test_trace_struct.py +++ b/tests/poutine/test_trace_struct.py @@ -30,10 +30,9 @@ ] -@pytest.mark.parametrize('edges', [ - perm for edges in EDGE_SETS - for perm in itertools.permutations(edges) -]) +@pytest.mark.parametrize( + "edges", [perm for edges in EDGE_SETS for perm in itertools.permutations(edges)] +) def test_topological_sort(edges): tr = Trace() for n1, n2 in edges: @@ -51,10 +50,9 @@ def test_topological_sort(edges): assert ranks[n1] < ranks[n2] -@pytest.mark.parametrize('edges', [ - perm for edges in EDGE_SETS - for perm in itertools.permutations(edges) -]) +@pytest.mark.parametrize( + "edges", [perm for edges in EDGE_SETS for perm in itertools.permutations(edges)] +) def test_connectivity_on_removal(edges): # check that when nodes are removed in reverse topological order # connectivity of the DAG is maintained, i.e. 
remaining nodes diff --git a/tests/pyroapi/test_pyroapi.py b/tests/pyroapi/test_pyroapi.py index 271c38efab..682c48b94d 100644 --- a/tests/pyroapi/test_pyroapi.py +++ b/tests/pyroapi/test_pyroapi.py @@ -5,7 +5,7 @@ from pyroapi import pyro_backend from pyroapi.tests import * # noqa F401 -pytestmark = pytest.mark.stage('unit') +pytestmark = pytest.mark.stage("unit") @pytest.fixture(params=["pyro", "minipyro"]) diff --git a/tests/test_examples.py b/tests/test_examples.py index 9298a78f11..e4e65bdd4f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -18,301 +18,333 @@ ) logger = logging.getLogger(__name__) -pytestmark = pytest.mark.stage('test_examples') +pytestmark = pytest.mark.stage("test_examples") CPU_EXAMPLES = [ - 'air/main.py --num-steps=1', - 'air/main.py --num-steps=1 --no-baseline', - 'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2', - 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200', - 'capture_recapture/cjs.py --num-steps=1 -m 1', - 'capture_recapture/cjs.py --num-steps=1 -m 2', - 'capture_recapture/cjs.py --num-steps=1 -m 3', - 'capture_recapture/cjs.py --num-steps=1 -m 4', - 'capture_recapture/cjs.py --num-steps=1 -m 5', - 'capture_recapture/cjs.py --num-steps=1 -m 1 --tmc --tmc-num-samples=2', - 'capture_recapture/cjs.py --num-steps=1 -m 2 --tmc --tmc-num-samples=2', - 'capture_recapture/cjs.py --num-steps=1 -m 3 --tmc --tmc-num-samples=2', - 'capture_recapture/cjs.py --num-steps=1 -m 4 --tmc --tmc-num-samples=2', - 'capture_recapture/cjs.py --num-steps=1 -m 5 --tmc --tmc-num-samples=2', - 'contrib/autoname/scoping_mixture.py --num-epochs=1', - 'contrib/autoname/mixture.py --num-epochs=1', - 'contrib/autoname/tree_data.py --num-epochs=1', - 'contrib/cevae/synthetic.py --num-epochs=1', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -c=2', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a', - 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2', - 'contrib/epidemiology/sir.py --nojit -np=128 -ss=2 -n=4 -d=20 -p=1000 -f 2 --svi', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4', - 'contrib/epidemiology/regional.py --nojit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', - 'contrib/forecast/bart.py --num-steps=2 --stride=99999', - 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', - 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save', - 'contrib/mue/FactorMuE.py --test --small -ard -idfac 
--no-substitution-matrix --no-plots --no-save', - 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save', - 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save', - 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', - 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', - 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', - 'dmm.py --num-epochs=1', - 'dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2', - 'dmm.py --num-epochs=1 --num-iafs=1', - 'dmm.py --num-epochs=1 --tmc --tmc-num-samples=2', - 'dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2', - 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100', - 'eight_schools/svi.py --num-epochs=1', - 'einsum.py', - 'hmm.py --num-steps=1 --truncate=10 --model=0', - 'hmm.py --num-steps=1 --truncate=10 --model=1', - 'hmm.py --num-steps=1 --truncate=10 --model=2', - 'hmm.py --num-steps=1 --truncate=10 --model=3', - 'hmm.py --num-steps=1 --truncate=10 --model=4', - 'hmm.py --num-steps=1 --truncate=10 --model=5', - 'hmm.py --num-steps=1 --truncate=10 --model=6', - 'hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization', - 'hmm.py --num-steps=1 --truncate=10 --model=7', - 'hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2', - 'hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2', - 'inclined_plane.py --num-samples=1', - 'lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8', - 'minipyro.py --backend=pyro', - 'minipyro.py', - 'mixed_hmm/experiment.py --timesteps=1', - 'neutra.py -n 10 --num-warmup 10 --num-samples 10', - 'rsa/generics.py --num-samples=10', - 'rsa/hyperbole.py --price=10000', - 'rsa/schelling.py --num-samples=10', - 'rsa/schelling_false.py --num-samples=10', - 'rsa/semantic_parsing.py --num-samples=10', - 'scanvi/scanvi.py --num-epochs 1 --dataset mock', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2', - 'smcfilter.py --num-timesteps=3 --num-particles=10', - 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom', - 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto', - 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy', - 'svi_horovod.py --num-epochs=2 --size=400 --no-horovod', - 'toy_mixture_model_discrete_enumeration.py --num-steps=1', - 'sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11', - 'vae/ss_vae_M2.py --num-epochs=1', - 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential', - 'vae/vae.py --num-epochs=1', - 'vae/vae_comparison.py --num-epochs=1', - 'cvae/main.py --num-quadrant-inputs=1 --num-epochs=1', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 
--model=3 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 -rp', + "air/main.py --num-steps=1", + "air/main.py --num-steps=1 --no-baseline", + "baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2", + "lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200", + "capture_recapture/cjs.py --num-steps=1 -m 1", + "capture_recapture/cjs.py --num-steps=1 -m 2", + "capture_recapture/cjs.py --num-steps=1 -m 3", + "capture_recapture/cjs.py --num-steps=1 -m 4", + "capture_recapture/cjs.py --num-steps=1 -m 5", + "capture_recapture/cjs.py --num-steps=1 -m 1 --tmc --tmc-num-samples=2", + "capture_recapture/cjs.py --num-steps=1 -m 2 --tmc --tmc-num-samples=2", + "capture_recapture/cjs.py --num-steps=1 -m 3 --tmc --tmc-num-samples=2", + "capture_recapture/cjs.py --num-steps=1 -m 4 --tmc --tmc-num-samples=2", + "capture_recapture/cjs.py --num-steps=1 -m 5 --tmc --tmc-num-samples=2", + "contrib/autoname/scoping_mixture.py --num-epochs=1", + "contrib/autoname/mixture.py --num-epochs=1", + "contrib/autoname/tree_data.py --num-epochs=1", + "contrib/cevae/synthetic.py --num-epochs=1", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -c=2", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a", + "contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2", + "contrib/epidemiology/sir.py --nojit -np=128 -ss=2 -n=4 -d=20 -p=1000 -f 2 --svi", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4", + "contrib/epidemiology/regional.py --nojit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi", + "contrib/forecast/bart.py --num-steps=2 --stride=99999", + "contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000", + 
"contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000", + "contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save", + "contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save", + "contrib/mue/ProfileHMM.py --test --small --no-plots --no-save", + "contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save", + "contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2", + "contrib/timeseries/gp_models.py -m imgp --test --num-steps=2", + "contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2", + "dmm.py --num-epochs=1", + "dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2", + "dmm.py --num-epochs=1 --num-iafs=1", + "dmm.py --num-epochs=1 --tmc --tmc-num-samples=2", + "dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2", + "eight_schools/mcmc.py --num-samples=500 --warmup-steps=100", + "eight_schools/svi.py --num-epochs=1", + "einsum.py", + "hmm.py --num-steps=1 --truncate=10 --model=0", + "hmm.py --num-steps=1 --truncate=10 --model=1", + "hmm.py --num-steps=1 --truncate=10 --model=2", + "hmm.py --num-steps=1 --truncate=10 --model=3", + "hmm.py --num-steps=1 --truncate=10 --model=4", + "hmm.py --num-steps=1 --truncate=10 --model=5", + "hmm.py --num-steps=1 --truncate=10 --model=6", + "hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization", + "hmm.py --num-steps=1 --truncate=10 --model=7", + "hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2", + "hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2", + "inclined_plane.py --num-samples=1", + "lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8", + "minipyro.py --backend=pyro", + "minipyro.py", + "mixed_hmm/experiment.py --timesteps=1", + "neutra.py -n 10 --num-warmup 10 --num-samples 10", + "rsa/generics.py --num-samples=10", + "rsa/hyperbole.py --price=10000", + "rsa/schelling.py --num-samples=10", + "rsa/schelling_false.py --num-samples=10", + "rsa/semantic_parsing.py --num-samples=10", + "scanvi/scanvi.py --num-epochs 1 --dataset mock", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential", + "sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2", + "smcfilter.py --num-timesteps=3 --num-particles=10", + "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom", + "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto", + "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy", + "svi_horovod.py --num-epochs=2 --size=400 --no-horovod", + "toy_mixture_model_discrete_enumeration.py --num-steps=1", + "sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11", + "vae/ss_vae_M2.py --num-epochs=1", + "vae/ss_vae_M2.py --num-epochs=1 --aux-loss", + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel", + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential", + "vae/vae.py --num-epochs=1", + "vae/vae_comparison.py --num-epochs=1", + "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", + "contrib/funsor/hmm.py 
--num-steps=1 --truncate=10 --model=0 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 -rp", ] CUDA_EXAMPLES = [ - 'air/main.py --num-steps=1 --cuda', - 'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda', - 'contrib/cevae/synthetic.py --num-epochs=1 --cuda', - 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', - 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', - 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', - 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', - 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', - 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda --cpu-data --pin-mem', - 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda', - 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda --cpu-data --pin-mem', - 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --cuda', - 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', - 'dmm.py --num-epochs=1 --cuda', - 'dmm.py --num-epochs=1 --num-iafs=1 --cuda', - 'dmm.py --num-epochs=1 --tmc --tmc-num-samples=2 --cuda', - 'dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2 --cuda', - 'einsum.py --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=0 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=1 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=3 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=4 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=5 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=6 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=6 --cuda --raftery-parameterization', - 'hmm.py --num-steps=1 --truncate=10 --model=7 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py 
--num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --cuda', - 'hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --cuda', - 'scanvi/scanvi.py --num-epochs 1 --dataset mock --cuda', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda', - 'svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod', - 'vae/vae.py --num-epochs=1 --cuda', - 'vae/ss_vae_M2.py --num-epochs=1 --cuda', - 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --cuda', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', - 'cvae/main.py --num-quadrant-inputs=1 --num-epochs=1 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda --raftery-parameterization ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda--tmc --tmc-num-samples=2 ', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda--tmc --tmc-num-samples=2 -rp', + "air/main.py --num-steps=1 --cuda", + "baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda", + "contrib/cevae/synthetic.py --num-epochs=1 --cuda", + "contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda", + "contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda", + "contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda", + "contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda", + "contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda", + "contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda --cpu-data --pin-mem", + "contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda", + "contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda --cpu-data --pin-mem", + "contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --cuda", + "lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda", + "dmm.py --num-epochs=1 --cuda", + "dmm.py --num-epochs=1 --num-iafs=1 --cuda", + "dmm.py --num-epochs=1 --tmc --tmc-num-samples=2 --cuda", + "dmm.py --num-epochs=1 --tmcelbo --tmc-num-samples=2 --cuda", + "einsum.py --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=0 --cuda", + "hmm.py 
--num-steps=1 --truncate=10 --model=1 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=3 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=4 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=5 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=6 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=6 --cuda --raftery-parameterization", + "hmm.py --num-steps=1 --truncate=10 --model=7 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --cuda", + "hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --cuda", + "scanvi/scanvi.py --num-epochs 1 --dataset mock --cuda", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda", + "sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda", + "svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod", + "vae/vae.py --num-epochs=1 --cuda", + "vae/ss_vae_M2.py --num-epochs=1 --cuda", + "vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda", + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --cuda", + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda", + "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda --raftery-parameterization ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda--tmc --tmc-num-samples=2 ", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --cuda--tmc --tmc-num-samples=2 -rp", ] def xfail_jit(*args, **kwargs): reason = kwargs.pop("reason", "not jittable") - return pytest.param(*args, marks=[pytest.mark.xfail(reason=reason), - pytest.mark.skipif('CI' in os.environ, reason='slow test')]) + return pytest.param( + *args, + marks=[ + pytest.mark.xfail(reason=reason), + pytest.mark.skipif("CI" in os.environ, reason="slow test"), + ] + ) JIT_EXAMPLES = [ - 'air/main.py --num-steps=1 --jit', - xfail_jit('baseball.py --num-samples=200 --warmup-steps=100 --jit', - reason='unreproducible RuntimeError on CI'), - 'contrib/autoname/mixture.py 
--num-epochs=1 --jit', - 'contrib/cevae/synthetic.py --num-epochs=1 --jit', - 'contrib/epidemiology/sir.py --jit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/sir.py --jit -np=128 -ss=2 -n=4 -d=20 -p=1000 -f 2 --svi', - 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', - xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), - 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --jit', - 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --jit', - 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --jit', - 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --jit', - xfail_jit('dmm.py --num-epochs=1 --jit'), - xfail_jit('dmm.py --num-epochs=1 --num-iafs=1 --jit'), - 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit', - 'eight_schools/svi.py --num-epochs=1 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=1 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=2 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=3 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=4 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=5 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=7 --jit', - xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --jit'), - xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --jit'), - xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --jit'), - xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --jit'), - 'lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit', - 'minipyro.py --backend=pyro --jit', - 'minipyro.py --jit', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit', - 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit', - xfail_jit('sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit'), - xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --jit', - 'vae/vae.py --num-epochs=1 --jit', - 'vae/vae_comparison.py --num-epochs=1 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --raftery-parameterization ', + "air/main.py --num-steps=1 --jit", + xfail_jit( + "baseball.py --num-samples=200 --warmup-steps=100 --jit", + reason="unreproducible RuntimeError on CI", + ), + "contrib/autoname/mixture.py --num-epochs=1 --jit", + "contrib/cevae/synthetic.py --num-epochs=1 --jit", + "contrib/epidemiology/sir.py --jit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2", + "contrib/epidemiology/sir.py --jit -np=128 -ss=2 -n=4 -d=20 -p=1000 -f 2 --svi", + "contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2", + "contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi", + 
xfail_jit("contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit"), + "contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --jit", + "contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --jit", + "contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --jit", + "contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --jit", + xfail_jit("dmm.py --num-epochs=1 --jit"), + xfail_jit("dmm.py --num-epochs=1 --num-iafs=1 --jit"), + "eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit", + "eight_schools/svi.py --num-epochs=1 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=1 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=2 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=3 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=4 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=5 --jit", + "hmm.py --num-steps=1 --truncate=10 --model=7 --jit", + xfail_jit( + "hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --jit" + ), + xfail_jit( + "hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --jit" + ), + xfail_jit( + "hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --jit" + ), + xfail_jit( + "hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --jit" + ), + "lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit", + "minipyro.py --backend=pyro --jit", + "minipyro.py --jit", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit", + "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit", + xfail_jit("sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit"), + xfail_jit("vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit"), + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit", + "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit", + "vae/ss_vae_M2.py --num-epochs=1 --jit", + "vae/vae.py --num-epochs=1 --jit", + "vae/vae_comparison.py --num-epochs=1 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --raftery-parameterization ", ] HOROVOD_EXAMPLES = [ - 'svi_horovod.py --num-epochs=2 --size=400', - pytest.param('svi_horovod.py --num-epochs=2 --size=400 --cuda', - marks=[requires_cuda]), + "svi_horovod.py --num-epochs=2 --size=400", + pytest.param( + "svi_horovod.py --num-epochs=2 --size=400 --cuda", marks=[requires_cuda] + ), ] FUNSOR_EXAMPLES = [ - xfail_param('contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --funsor', - reason="unreproducible recursion error on travis?"), - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization 
--funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --raftery-parameterization --funsor', - xfail_param('contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2 --funsor', - reason="unreproducible recursion error on travis?"), - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --funsor -rp', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --tmc --tmc-num-samples=2 --funsor', - 'contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --tmc --tmc-num-samples=2 --funsor -rp', + xfail_param( + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --funsor", + reason="unreproducible recursion error on travis?", + ), + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --raftery-parameterization --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --raftery-parameterization --funsor", + xfail_param( + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 --tmc --tmc-num-samples=2 --funsor", + reason="unreproducible recursion error on travis?", + ), + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=3 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --funsor -rp", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --tmc --tmc-num-samples=2 --funsor", + "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=6 --jit --tmc --tmc-num-samples=2 --funsor -rp", ] def test_coverage(): - cpu_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CPU_EXAMPLES) - cuda_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CUDA_EXAMPLES) - jit_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in 
JIT_EXAMPLES) + cpu_tests = set( + (e if isinstance(e, str) else e.values[0]).split()[0] for e in CPU_EXAMPLES + ) + cuda_tests = set( + (e if isinstance(e, str) else e.values[0]).split()[0] for e in CUDA_EXAMPLES + ) + jit_tests = set( + (e if isinstance(e, str) else e.values[0]).split()[0] for e in JIT_EXAMPLES + ) for root, dirs, files in os.walk(EXAMPLES_DIR): for basename in files: - if not basename.endswith('.py'): + if not basename.endswith(".py"): continue path = os.path.join(root, basename) with open(path) as f: text = f.read() example = os.path.relpath(path, EXAMPLES_DIR) - if '__main__' in text: + if "__main__" in text: if example not in cpu_tests: - pytest.fail('Example: {} not covered in CPU_EXAMPLES.'.format(example)) - if '--cuda' in text and example not in cuda_tests: - pytest.fail('Example: {} not covered by CUDA_EXAMPLES.'.format(example)) - if '--jit' in text and example not in jit_tests: - pytest.fail('Example: {} not covered by JIT_EXAMPLES.'.format(example)) + pytest.fail( + "Example: {} not covered in CPU_EXAMPLES.".format(example) + ) + if "--cuda" in text and example not in cuda_tests: + pytest.fail( + "Example: {} not covered by CUDA_EXAMPLES.".format(example) + ) + if "--jit" in text and example not in jit_tests: + pytest.fail( + "Example: {} not covered by JIT_EXAMPLES.".format(example) + ) -@pytest.mark.parametrize('example', CPU_EXAMPLES) +@pytest.mark.parametrize("example", CPU_EXAMPLES) def test_cpu(example): - logger.info('Running:\npython examples/{}'.format(example)) + logger.info("Running:\npython examples/{}".format(example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) @@ -320,18 +352,18 @@ def test_cpu(example): @requires_cuda -@pytest.mark.parametrize('example', CUDA_EXAMPLES) +@pytest.mark.parametrize("example", CUDA_EXAMPLES) def test_cuda(example): - logger.info('Running:\npython examples/{}'.format(example)) + logger.info("Running:\npython examples/{}".format(example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) check_call([sys.executable, filename] + args) -@pytest.mark.parametrize('example', JIT_EXAMPLES) +@pytest.mark.parametrize("example", JIT_EXAMPLES) def test_jit(example): - logger.info('Running:\npython examples/{}'.format(example)) + logger.info("Running:\npython examples/{}".format(example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) @@ -339,13 +371,13 @@ def test_jit(example): @requires_horovod -@pytest.mark.parametrize('np', [1, 2]) -@pytest.mark.parametrize('example', HOROVOD_EXAMPLES) +@pytest.mark.parametrize("np", [1, 2]) +@pytest.mark.parametrize("example", HOROVOD_EXAMPLES) def test_horovod(np, example): - if 'cuda' in example and np > torch.cuda.device_count(): + if "cuda" in example and np > torch.cuda.device_count(): pytest.skip() - horovodrun = 'horovodrun -np {} --mpi-args=--oversubscribe'.format(np) - logger.info('Running:\n{} python examples/{}'.format(horovodrun, example)) + horovodrun = "horovodrun -np {} --mpi-args=--oversubscribe".format(np) + logger.info("Running:\n{} python examples/{}".format(horovodrun, example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) @@ -353,9 +385,9 @@ def test_horovod(np, example): @requires_funsor -@pytest.mark.parametrize('example', FUNSOR_EXAMPLES) +@pytest.mark.parametrize("example", FUNSOR_EXAMPLES) def 
test_funsor(example): - logger.info('Running:\npython examples/{}'.format(example)) + logger.info("Running:\npython examples/{}".format(example)) example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) diff --git a/tests/test_generic.py b/tests/test_generic.py index 1ca5c77588..1eaa8663c0 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -7,23 +7,27 @@ from pyro.generic import handlers, infer, ops, pyro, pyro_backend from tests.common import xfail_if_not_implemented -pytestmark = pytest.mark.stage('unit') +pytestmark = pytest.mark.stage("unit") @pytest.mark.filterwarnings("ignore", category=UserWarning) -@pytest.mark.parametrize('model', MODELS) -@pytest.mark.parametrize('backend', ['pyro']) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["pyro"]) def test_mcmc_interface(model, backend): with pyro_backend(backend), handlers.seed(rng_seed=20): f = MODELS[model]() - model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) + model, args, kwargs = ( + f["model"], + f.get("model_args", ()), + f.get("model_kwargs", {}), + ) nuts_kernel = infer.NUTS(model=model) mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10) mcmc.run(*args, **kwargs) mcmc.summary() -@pytest.mark.parametrize('backend', ['pyro', 'minipyro']) +@pytest.mark.parametrize("backend", ["pyro", "minipyro"]) def test_not_implemented(backend): with pyro_backend(backend): pyro.sample # should be implemented @@ -32,21 +36,21 @@ def test_not_implemented(backend): pyro.nonexistent_primitive -@pytest.mark.parametrize('model', MODELS) -@pytest.mark.parametrize('backend', ['minipyro', 'pyro']) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["minipyro", "pyro"]) def test_model_sample(model, backend): with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented(): f = MODELS[model]() - model, model_args = f['model'], f.get('model_args', ()) + model, model_args = f["model"], f.get("model_args", ()) model(*model_args) -@pytest.mark.parametrize('model', MODELS) -@pytest.mark.parametrize('backend', ['minipyro', 'pyro']) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["minipyro", "pyro"]) def test_rng_seed(model, backend): with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented(): f = MODELS[model]() - model, model_args = f['model'], f.get('model_args', ()) + model, model_args = f["model"], f.get("model_args", ()) with handlers.seed(rng_seed=0): expected = model(*model_args) if expected is None: @@ -56,12 +60,12 @@ def test_rng_seed(model, backend): assert ops.allclose(actual, expected) -@pytest.mark.parametrize('model', MODELS) -@pytest.mark.parametrize('backend', ['minipyro', 'pyro']) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["minipyro", "pyro"]) def test_rng_state(model, backend): with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented(): f = MODELS[model]() - model, model_args = f['model'], f.get('model_args', ()) + model, model_args = f["model"], f.get("model_args", ()) with handlers.seed(rng_seed=0): model(*model_args) expected = model(*model_args) @@ -75,11 +79,15 @@ def test_rng_state(model, backend): assert ops.allclose(actual, expected) -@pytest.mark.parametrize('model', MODELS) -@pytest.mark.parametrize('backend', ['minipyro', 'pyro']) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", 
["minipyro", "pyro"]) def test_trace_handler(model, backend): with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented(): f = MODELS[model]() - model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) + model, model_args, model_kwargs = ( + f["model"], + f.get("model_args", ()), + f.get("model_kwargs", {}), + ) # should be implemented handlers.trace(model).get_trace(*model_args, **model_kwargs) diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 22f331a450..663e3a67b9 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -7,7 +7,7 @@ import pyro import pyro.distributions as dist -pytestmark = pytest.mark.stage('unit') +pytestmark = pytest.mark.stage("unit") def test_sample_ok(): @@ -18,17 +18,16 @@ def test_sample_ok(): def test_observe_warn(): with pytest.warns(RuntimeWarning): - pyro.sample("x", dist.Normal(0, 1), - obs=torch.tensor(0.)) + pyro.sample("x", dist.Normal(0, 1), obs=torch.tensor(0.0)) def test_param_ok(): - x = pyro.param("x", torch.tensor(0.)) + x = pyro.param("x", torch.tensor(0.0)) assert isinstance(x, torch.Tensor) assert x.shape == () def test_deterministic_ok(): - x = pyro.deterministic("x", torch.tensor(0.)) + x = pyro.deterministic("x", torch.tensor(0.0)) assert isinstance(x, torch.Tensor) assert x.shape == () diff --git a/tests/test_util.py b/tests/test_util.py index 09ec92f4f7..edc14690b0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,20 +8,20 @@ from pyro import util -pytestmark = pytest.mark.stage('unit') +pytestmark = pytest.mark.stage("unit") def test_warn_if_nan(): # scalar case with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - x = float('inf') + x = float("inf") msg = "example message" y = util.warn_if_nan(x, msg) assert y is x assert len(w) == 0 - x = float('nan') + x = float("nan") util.warn_if_nan(x, msg) # Verify some things assert len(w) == 1 @@ -33,7 +33,7 @@ def test_warn_if_nan(): x = torch.ones(2) msg = "example message" util.warn_if_nan(x, msg) - x[1] = float('nan') + x[1] = float("nan") util.warn_if_nan(x, msg) assert len(w) == 1 assert msg in str(w[-1].message) @@ -44,7 +44,7 @@ def test_warn_if_nan(): x = torch.ones(2, requires_grad=True) util.warn_if_nan(x, msg) y = x.sum() - y.backward([torch.tensor(float('nan'))]) + y.backward([torch.tensor(float("nan"))]) assert len(w) == 1 assert msg in str(w[-1].message) @@ -58,7 +58,7 @@ def test_warn_if_inf(): y = util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True) assert y is x assert len(w) == 0 - x = float('inf') + x = float("inf") util.warn_if_inf(x, msg, allow_posinf=True) assert len(w) == 0 util.warn_if_inf(x, msg, allow_neginf=True) @@ -71,7 +71,7 @@ def test_warn_if_inf(): x = torch.ones(2) util.warn_if_inf(x, msg, allow_posinf=True, allow_neginf=True) assert len(w) == 0 - x[0] = float('inf') + x[0] = float("inf") util.warn_if_inf(x, msg, allow_posinf=True) assert len(w) == 0 util.warn_if_inf(x, msg, allow_neginf=True) @@ -84,11 +84,11 @@ def test_warn_if_inf(): x = torch.ones(2, requires_grad=True) util.warn_if_inf(x, msg, allow_posinf=True) y = x.sum() - y.backward([torch.tensor(float('inf'))]) + y.backward([torch.tensor(float("inf"))]) assert len(w) == 0 x.grad = None - y.backward([torch.tensor(-float('inf'))]) + y.backward([torch.tensor(-float("inf"))]) assert len(w) == 1 assert msg in str(w[-1].message) @@ -97,9 +97,9 @@ def test_warn_if_inf(): z = torch.ones(2, requires_grad=True) y = z.sum() util.warn_if_inf(z, msg, allow_neginf=True) 
-        y.backward([torch.tensor(-float('inf'))])
+        y.backward([torch.tensor(-float("inf"))])
         assert len(w) == 0
         z.grad = None
-        y.backward([torch.tensor(float('inf'))])
+        y.backward([torch.tensor(float("inf"))])
         assert len(w) == 1
         assert msg in str(w[-1].message)
diff --git a/tutorial/source/cleannb.py b/tutorial/source/cleannb.py
index 48141d7049..bc98e08d67 100644
--- a/tutorial/source/cleannb.py
+++ b/tutorial/source/cleannb.py
@@ -22,7 +22,9 @@ def cleannb(nbfile):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Clean kernelspec metadata of a notebook")
+    parser = argparse.ArgumentParser(
+        description="Clean kernelspec metadata of a notebook"
+    )
     parser.add_argument("nbfiles", nargs="*", help="Files to clean kernelspec metadata")

     args = parser.parse_args()
diff --git a/tutorial/source/conf.py b/tutorial/source/conf.py
index dab0126ee5..a62bbd184c 100644
--- a/tutorial/source/conf.py
+++ b/tutorial/source/conf.py
@@ -37,36 +37,37 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.intersphinx',
-              'sphinx.ext.todo',
-              'sphinx.ext.mathjax',
-              'sphinx.ext.githubpages',
-              'nbsphinx',
-              'sphinx.ext.autodoc'
-              ]
+extensions = [
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.githubpages",
+    "nbsphinx",
+    "sphinx.ext.autodoc",
+]

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 # source_suffix = ['.rst', '.md']
-source_suffix = ['.rst', '.ipynb']
+source_suffix = [".rst", ".ipynb"]

 # do not execute cells
-nbsphinx_execute = 'never'
+nbsphinx_execute = "never"

 # allow errors because not all tutorials build
 nbsphinx_allow_errors = True

 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"

 # General information about the project.
-project = u'Pyro Tutorials'
-copyright = u'Pyro Contributors'
-author = u'Uber AI Labs'
+project = u"Pyro Tutorials"
+copyright = u"Pyro Contributors"
+author = u"Uber AI Labs"

 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -87,10 +88,10 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['.ipynb_checkpoints']
+exclude_patterns = [".ipynb_checkpoints"]

 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"

 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -106,29 +107,27 @@
 html_theme = "sphinx_rtd_theme"
 html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
 # logo
-html_logo = '_static/img/pyro_logo_wide.png'
+html_logo = "_static/img/pyro_logo_wide.png"

 # Theme options are theme-specific and customize the look and feel of a theme
 # further.  For a list of options available for each theme, see the
 # documentation.
 #
-html_theme_options = {
-    'logo_only': True
-}
+html_theme_options = {"logo_only": True}

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
-html_style = 'css/pyro.css'
+html_static_path = ["_static"]
+html_style = "css/pyro.css"

-html_favicon = '_static/img/favicon/favicon.ico'
+html_favicon = "_static/img/favicon/favicon.ico"


 # -- Options for HTMLHelp output ------------------------------------------

 # Output file base name for HTML help builder.
-htmlhelp_basename = 'PyroTutorialsdoc'
+htmlhelp_basename = "PyroTutorialsdoc"


 # -- Options for LaTeX output ---------------------------------------------
@@ -137,15 +136,12 @@
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
-
     # The font size ('10pt', '11pt' or '12pt').
    #
     # 'pointsize': '10pt',
-
     # Additional stuff for the LaTeX preamble.
     #
     # 'preamble': '',
-
     # Latex figure (float) alignment
     #
     # 'figure_align': 'htbp',
@@ -155,8 +151,13 @@
 #  (source start file, target name, title,
 #   author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'PyroTutorials.tex', u'Pyro Examples and Tutorials',
-     u'Uber AI Labs', 'manual'),
+    (
+        master_doc,
+        "PyroTutorials.tex",
+        u"Pyro Examples and Tutorials",
+        u"Uber AI Labs",
+        "manual",
+    ),
 ]


@@ -164,10 +165,7 @@
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'pyrotutorials', u'Pyro Examples and Tutorials',
-     [author], 1)
-]
+man_pages = [(master_doc, "pyrotutorials", u"Pyro Examples and Tutorials", [author], 1)]


 # -- Options for Texinfo output -------------------------------------------
@@ -176,7 +174,13 @@
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'PyroTutorials', u'Pyro Examples and Tutorials',
-     author, 'PyroTutorials', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        "PyroTutorials",
+        u"Pyro Examples and Tutorials",
+        author,
+        "PyroTutorials",
+        "One line description of project.",
+        "Miscellaneous",
+    ),
 ]
diff --git a/tutorial/source/search_inference.py b/tutorial/source/search_inference.py
index 0c3dada49b..9dc8b6a880 100644
--- a/tutorial/source/search_inference.py
+++ b/tutorial/source/search_inference.py
@@ -34,15 +34,16 @@ class HashingMarginal(dist.Distribution):
     Turns a TracePosterior object into a Distribution over the return values
     of the TracePosterior's model.
""" + def __init__(self, trace_dist, sites=None): - assert isinstance(trace_dist, TracePosterior), \ - "trace_dist must be trace posterior distribution object" + assert isinstance( + trace_dist, TracePosterior + ), "trace_dist must be trace posterior distribution object" if sites is None: sites = "_RETURN" - assert isinstance(sites, (str, list)), \ - "sites must be either '_RETURN' or list" + assert isinstance(sites, (str, list)), "sites must be either '_RETURN' or list" self.sites = sites super().__init__() @@ -54,8 +55,7 @@ def __init__(self, trace_dist, sites=None): def _dist_and_values(self): # XXX currently this whole object is very inefficient values_map, logits = collections.OrderedDict(), collections.OrderedDict() - for tr, logit in zip(self.trace_dist.exec_traces, - self.trace_dist.log_weights): + for tr, logit in zip(self.trace_dist.exec_traces, self.trace_dist.log_weights): if isinstance(self.sites, str): value = tr.nodes[self.sites]["value"] else: @@ -71,7 +71,9 @@ def _dist_and_values(self): value_hash = hash(value) if value_hash in logits: # Value has already been seen. - logits[value_hash] = logsumexp(torch.stack([logits[value_hash], logit]), dim=-1) + logits[value_hash] = logsumexp( + torch.stack([logits[value_hash], logit]), dim=-1 + ) else: logits[value_hash] = logit values_map[value_hash] = value @@ -133,10 +135,12 @@ def variance(self): # Exact Search inference ######################## + class Search(TracePosterior): """ Exact inference by enumerating over all possible executions """ + def __init__(self, model, max_tries=int(1e6), **kwargs): self.model = model self.max_tries = max_tries @@ -145,8 +149,7 @@ def __init__(self, model, max_tries=int(1e6), **kwargs): def _traces(self, *args, **kwargs): q = queue.Queue() q.put(poutine.Trace()) - p = poutine.trace( - poutine.queue(self.model, queue=q, max_tries=self.max_tries)) + p = poutine.trace(poutine.queue(self.model, queue=q, max_tries=self.max_tries)) while not q.empty(): tr = p.get_trace(*args, **kwargs) yield tr, tr.log_prob_sum() @@ -158,30 +161,38 @@ def _traces(self, *args, **kwargs): def pqueue(fn, queue): - def sample_escape(tr, site): - return (site["name"] not in tr) and \ - (site["type"] == "sample") and \ - (not site["is_observed"]) + return ( + (site["name"] not in tr) + and (site["type"] == "sample") + and (not site["is_observed"]) + ) def _fn(*args, **kwargs): for i in range(int(1e6)): - assert not queue.empty(), \ - "trying to get() from an empty queue will deadlock" + assert ( + not queue.empty() + ), "trying to get() from an empty queue will deadlock" priority, next_trace = queue.get() try: - ftr = poutine.trace(poutine.escape(poutine.replay(fn, next_trace), - functools.partial(sample_escape, - next_trace))) + ftr = poutine.trace( + poutine.escape( + poutine.replay(fn, next_trace), + functools.partial(sample_escape, next_trace), + ) + ) return ftr(*args, **kwargs) except NonlocalExit as site_container: site_container.reset_stack() - for tr in poutine.util.enum_extend(ftr.trace.copy(), - site_container.site): + for tr in poutine.util.enum_extend( + ftr.trace.copy(), site_container.site + ): # add a little bit of noise to the priority to break ties... - queue.put((tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr)) + queue.put( + (tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr) + ) raise ValueError("max tries ({}) exceeded".format(str(1e6))) @@ -193,6 +204,7 @@ class BestFirstSearch(TracePosterior): Inference by enumerating executions ordered by their probabilities. 
     Exact (and results equivalent to Search) if all executions are enumerated.
     """
+
     def __init__(self, model, num_samples=None, **kwargs):
         if num_samples is None:
             num_samples = 100