From 8e43fc722151bc7c9104ede96e0ef84657f4f750 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 2 Jul 2021 13:24:04 -0700 Subject: [PATCH 1/2] Update enumeration tutorial --- tutorial/source/enumeration.ipynb | 260 +++++++++++++++++++--------- tutorial/source/tensor_shapes.ipynb | 7 +- 2 files changed, 181 insertions(+), 86 deletions(-) diff --git a/tutorial/source/enumeration.ipynb b/tutorial/source/enumeration.ipynb index cf0927c1ce..5a77ac670b 100644 --- a/tutorial/source/enumeration.ipynb +++ b/tutorial/source/enumeration.ipynb @@ -46,7 +46,7 @@ "from torch.distributions import constraints\n", "from pyro import poutine\n", "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete\n", - "from pyro.infer.autoguide import AutoDiagonalNormal\n", + "from pyro.infer.autoguide import AutoNormal\n", "from pyro.ops.indexing import Vindex\n", "\n", "smoke_test = ('CI' in os.environ)\n", @@ -60,7 +60,7 @@ "source": [ "## Overview \n", "\n", - "Pyro's enumeration strategy encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-out, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference ([SVI](http://docs.pyro.ai/en/dev/inference_algos.html)) and monte carlo ([HMC](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC) and [NUTS](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS)). Enumeration can be used either as a stand-alone strategy via [infer_discrete](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.discrete.infer_discrete), or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides." + "Pyro's enumeration strategy ([Obermeyer et al. 2019](https://arxiv.org/abs/1902.03210)) encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-out, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference ([SVI](http://docs.pyro.ai/en/dev/inference_algos.html)) and monte carlo ([HMC](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC) and [NUTS](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS)). Enumeration can be used either as a stand-alone strategy via [infer_discrete](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.discrete.infer_discrete), or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides." ] }, { @@ -89,11 +89,11 @@ "source": [ "def model():\n", " z = pyro.sample(\"z\", dist.Categorical(torch.ones(5)))\n", - " print('model z = {}'.format(z))\n", + " print(f\"model z = {z}\")\n", "\n", "def guide():\n", " z = pyro.sample(\"z\", dist.Categorical(torch.ones(5)))\n", - " print('guide z = {}'.format(z))\n", + " print(f\"guide z = {z}\")\n", "\n", "elbo = Trace_ELBO()\n", "elbo.loss(model, guide);" @@ -181,9 +181,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "model x.shape = torch.Size([3])\n", - "model y.shape = torch.Size([3, 1])\n", - "model z.shape = torch.Size([3, 1, 1])\n" + "Sampling:\n", + " model x.shape = torch.Size([])\n", + " model y.shape = torch.Size([])\n", + " model z.shape = torch.Size([])\n", + "Enumerated Inference:\n", + " model x.shape = torch.Size([3])\n", + " model y.shape = torch.Size([3, 1])\n", + " model z.shape = torch.Size([3, 1, 1])\n" ] } ], @@ -194,15 +199,18 @@ " x = pyro.sample(\"x\", dist.Categorical(p[0]))\n", " y = pyro.sample(\"y\", dist.Categorical(p[x]))\n", " z = pyro.sample(\"z\", dist.Categorical(p[y]))\n", - " print('model x.shape = {}'.format(x.shape))\n", - " print('model y.shape = {}'.format(y.shape))\n", - " print('model z.shape = {}'.format(z.shape))\n", + " print(f\" model x.shape = {x.shape}\")\n", + " print(f\" model y.shape = {y.shape}\")\n", + " print(f\" model z.shape = {z.shape}\")\n", " return x, y, z\n", " \n", "def guide():\n", " pass\n", "\n", "pyro.clear_param_store()\n", + "print(\"Sampling:\")\n", + "model()\n", + "print(\"Enumerated Inference:\")\n", "elbo = TraceEnum_ELBO(max_plate_nesting=0)\n", "elbo.loss(model, guide);" ] @@ -228,24 +236,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "model x.shape = torch.Size([3])\n", - "model y.shape = torch.Size([3, 1])\n", - "model z.shape = torch.Size([3, 1, 1])\n", - "model x.shape = torch.Size([])\n", - "model y.shape = torch.Size([])\n", - "model z.shape = torch.Size([])\n", - "x = 0\n", - "y = 2\n", - "z = 1\n" + " model x.shape = torch.Size([3])\n", + " model y.shape = torch.Size([3, 1])\n", + " model z.shape = torch.Size([3, 1, 1])\n", + " model x.shape = torch.Size([])\n", + " model y.shape = torch.Size([])\n", + " model z.shape = torch.Size([])\n", + "x = 2\n", + "y = 1\n", + "z = 0\n" ] } ], "source": [ "serving_model = infer_discrete(model, first_available_dim=-1)\n", "x, y, z = serving_model() # takes the same args as model(), here no args\n", - "print(\"x = {}\".format(x))\n", - "print(\"y = {}\".format(y))\n", - "print(\"z = {}\".format(z))" + "print(f\"x = {x}\")\n", + "print(f\"y = {y}\")\n", + "print(f\"z = {z}\")" ] }, { @@ -261,7 +269,7 @@ "source": [ "### Indexing with enumerated variables\n", "\n", - "It can be tricky to use [advanced indexing](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html) to select an element of a tensor using one or more enumerated variables. For example, suppose a plated random variable `z` depends on two different random variables:\n", + "It can be tricky to use [advanced indexing](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html) to select an element of a tensor using one or more enumerated variables. This is especially true in Pyro models where your model's indexing operations need to work in multiple interpretations: both sampling from the model (to generate data) and during enumerated inference. For example, suppose a plated random variable `z` depends on two different random variables:\n", "```py\n", "p = pyro.param(\"p\", torch.randn(5, 4, 3, 2).exp(),\n", " constraint=constraints.simplex)\n", @@ -279,7 +287,7 @@ " y.unsqueeze(-1),\n", " torch.arange(2, device=p.device)]\n", "```\n", - "Pyro provides a helper [Vindex()[]](http://docs.pyro.ai/en/dev/ops.html#pyro.ops.indexing.Vindex) to use enumeration-compatible advanced indexing semantics rather than PyTorch/NumPy semantics. `Vindex()[]` makes the `.__getitem__()` operator broadcast like other familiar operators `+`, `*` etc. Using `Vindex()[]` we can write the same expression as if `x` and `y` were numbers (i.e. not enumerated):\n", + "Pyro provides a helper [Vindex()[]](http://docs.pyro.ai/en/dev/ops.html#pyro.ops.indexing.Vindex) to use enumeration-compatible advanced indexing semantics rather than standard PyTorch/NumPy semantics. (Note the `Vindex` name and semantics follow the Numpy Enhancement Proposal [NEP 21](https://numpy.org/neps/nep-0021-advanced-indexing.html)). `Vindex()[]` makes the `.__getitem__()` operator broadcast like other familiar operators `+`, `*` etc. Using `Vindex()[]` we can write the same expression as if `x` and `y` were numbers (i.e. not enumerated):\n", "```py\n", "# Recommended syntax compatible with enumeration:\n", "p_xy = Vindex(p)[..., x, y, :]\n", @@ -296,11 +304,18 @@ "name": "stdout", "output_type": "stream", "text": [ - " p.shape = torch.Size([5, 4, 3, 2])\n", - " x.shape = torch.Size([4, 1])\n", - " y.shape = torch.Size([3, 1, 1])\n", - "p_xy.shape = torch.Size([3, 4, 5, 2])\n", - " z.shape = torch.Size([2, 1, 1, 1])\n" + "Sampling:\n", + " p.shape = torch.Size([5, 4, 3, 2])\n", + " x.shape = torch.Size([])\n", + " y.shape = torch.Size([])\n", + " p_xy.shape = torch.Size([5, 2])\n", + " z.shape = torch.Size([5])\n", + "Enumerated Inference:\n", + " p.shape = torch.Size([5, 4, 3, 2])\n", + " x.shape = torch.Size([4, 1])\n", + " y.shape = torch.Size([3, 1, 1])\n", + " p_xy.shape = torch.Size([3, 4, 5, 2])\n", + " z.shape = torch.Size([2, 1, 1, 1])\n" ] } ], @@ -313,21 +328,90 @@ " with pyro.plate(\"z_plate\", 5):\n", " p_xy = Vindex(p)[..., x, y, :]\n", " z = pyro.sample(\"z\", dist.Categorical(p_xy))\n", - " print(' p.shape = {}'.format(p.shape))\n", - " print(' x.shape = {}'.format(x.shape))\n", - " print(' y.shape = {}'.format(y.shape))\n", - " print('p_xy.shape = {}'.format(p_xy.shape))\n", - " print(' z.shape = {}'.format(z.shape))\n", + " print(f\" p.shape = {p.shape}\")\n", + " print(f\" x.shape = {x.shape}\")\n", + " print(f\" y.shape = {y.shape}\")\n", + " print(f\" p_xy.shape = {p_xy.shape}\")\n", + " print(f\" z.shape = {z.shape}\")\n", " return x, y, z\n", " \n", "def guide():\n", " pass\n", "\n", "pyro.clear_param_store()\n", + "print(\"Sampling:\")\n", + "model()\n", + "print(\"Enumerated Inference:\")\n", "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n", "elbo.loss(model, guide);" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When enumering within a plate (as described in the next section) ``Vindex`` can also be used together with capturing the plate index via ``with pyro.plate(...) as i`` to index into batch dimensions. Here's an example with nontrivial event dimensions due to the ``Dirichlet`` distribution." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampling:\n", + " p.shape = torch.Size([5, 4, 3])\n", + " c.shape = torch.Size([6])\n", + " vdx.shape = torch.Size([5])\n", + " pc.shape = torch.Size([5, 6, 3])\n", + " x.shape = torch.Size([5, 6])\n", + "Enumerated Inference:\n", + " p.shape = torch.Size([5, 4, 3])\n", + " c.shape = torch.Size([4, 1, 1])\n", + " vdx.shape = torch.Size([5])\n", + " pc.shape = torch.Size([4, 5, 1, 3])\n", + " x.shape = torch.Size([5, 6])\n" + ] + } + ], + "source": [ + "@config_enumerate\n", + "def model():\n", + " data_plate = pyro.plate(\"data_plate\", 6, dim=-1)\n", + " feature_plate = pyro.plate(\"feature_plate\", 5, dim=-2)\n", + " component_plate = pyro.plate(\"component_plate\", 4, dim=-1)\n", + " with feature_plate: \n", + " with component_plate:\n", + " p = pyro.sample(\"p\", dist.Dirichlet(torch.ones(3)))\n", + " with data_plate:\n", + " c = pyro.sample(\"c\", dist.Categorical(torch.ones(4)))\n", + " with feature_plate as vdx: # Capture plate index.\n", + " pc = Vindex(p)[vdx[..., None], c, :] # Reshape it and use in Vindex.\n", + " x = pyro.sample(\"x\", dist.Categorical(pc),\n", + " obs=torch.zeros(5, 6, dtype=torch.long))\n", + " print(f\" p.shape = {p.shape}\")\n", + " print(f\" c.shape = {c.shape}\")\n", + " print(f\" vdx.shape = {vdx.shape}\")\n", + " print(f\" pc.shape = {pc.shape}\")\n", + " print(f\" x.shape = {x.shape}\")\n", + "\n", + "def guide():\n", + " feature_plate = pyro.plate(\"feature_plate\", 5, dim=-2)\n", + " component_plate = pyro.plate(\"component_plate\", 4, dim=-1)\n", + " with feature_plate, component_plate:\n", + " pyro.sample(\"p\", dist.Dirichlet(torch.ones(3)))\n", + " \n", + "pyro.clear_param_store()\n", + "print(\"Sampling:\")\n", + "model()\n", + "print(\"Enumerated Inference:\")\n", + "elbo = TraceEnum_ELBO(max_plate_nesting=2)\n", + "elbo.loss(model, guide);" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -341,42 +425,50 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running model with 10 data points\n", - "x.shape = torch.Size([10])\n", - "dist.Normal(loc[x], scale).batch_shape = torch.Size([10])\n", - "Running model with 10 data points\n", - "x.shape = torch.Size([3, 1])\n", - "dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])\n" + "Sampling:\n", + " Running model with 10 data points\n", + " x.shape = torch.Size([10])\n", + " dist.Normal(loc[x], scale).batch_shape = torch.Size([10])\n", + "Enumerated Inference:\n", + " Running model with 10 data points\n", + " x.shape = torch.Size([10])\n", + " dist.Normal(loc[x], scale).batch_shape = torch.Size([10])\n", + " Running model with 10 data points\n", + " x.shape = torch.Size([3, 1])\n", + " dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])\n" ] } ], "source": [ "@config_enumerate\n", "def model(data, num_components=3):\n", - " print('Running model with {} data points'.format(len(data)))\n", + " print(f\" Running model with {len(data)} data points\")\n", " p = pyro.sample(\"p\", dist.Dirichlet(0.5 * torch.ones(num_components)))\n", " scale = pyro.sample(\"scale\", dist.LogNormal(0, num_components))\n", " with pyro.plate(\"components\", num_components):\n", " loc = pyro.sample(\"loc\", dist.Normal(0, 10))\n", " with pyro.plate(\"data\", len(data)):\n", " x = pyro.sample(\"x\", dist.Categorical(p))\n", - " print(\"x.shape = {}\".format(x.shape))\n", + " print(\" x.shape = {}\".format(x.shape))\n", " pyro.sample(\"obs\", dist.Normal(loc[x], scale), obs=data)\n", - " print(\"dist.Normal(loc[x], scale).batch_shape = {}\".format(\n", + " print(\" dist.Normal(loc[x], scale).batch_shape = {}\".format(\n", " dist.Normal(loc[x], scale).batch_shape))\n", " \n", - "guide = AutoDiagonalNormal(poutine.block(model, hide=[\"x\", \"data\"]))\n", + "guide = AutoNormal(poutine.block(model, hide=[\"x\", \"data\"]))\n", "\n", "data = torch.randn(10)\n", " \n", "pyro.clear_param_store()\n", + "print(\"Sampling:\")\n", + "model(data)\n", + "print(\"Enumerated Inference:\")\n", "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n", "elbo.loss(model, guide, data);" ] @@ -385,13 +477,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Observe that the model is run twice, first by the `AutoDiagonalNormal` to trace sample sites, and second by `elbo` to compute loss. In the first run, `x` has the standard interpretation of one sample per datum, hence shape `(10,)`. In the second run enumeration can use the same three values `(3,1)` for all data points, and relies on broadcasting for any dependent sample or observe sites that depend on data. For example, in the `pyro.sample(\"obs\",...)` statement, the distribution has shape `(3,1)`, the data has shape`(10,)`, and the broadcasted log probability tensor has shape `(3,10)`.\n", + "Observe that during inference the model is run twice, first by the `AutoNormal` to trace sample sites, and second by `elbo` to compute loss. In the first run, `x` has the standard interpretation of one sample per datum, hence shape `(10,)`. In the second run enumeration can use the same three values `(3,1)` for all data points, and relies on broadcasting for any dependent sample or observe sites that depend on data. For example, in the `pyro.sample(\"obs\",...)` statement, the distribution has shape `(3,1)`, the data has shape`(10,)`, and the broadcasted log probability tensor has shape `(3,10)`.\n", "\n", "For a more in-depth treatment of enumeration in mixture models, see the [Gaussian Mixture Model Tutorial](http://pyro.ai/examples/gmm.html) and the [HMM Example](http://pyro.ai/examples/hmm.html).\n", "\n", "### Dependencies among plates \n", "\n", - "The computational savings of enumerating in vectorized plates comes with restrictions on the dependency structure of models. These restrictions are in addition to the usual restrictions of conditional independence. The enumeration restrictions are checked by `TraceEnum_ELBO` and will result in an error if violated (however the usual conditional independence restriction cannot be generally verified by Pyro). For completeness we list all three restrictions:\n", + "The computational savings of enumerating in vectorized plates comes with restrictions on the dependency structure of models (as described in ([Obermeyer et al. 2019](https://arxiv.org/abs/1902.03210))). These restrictions are in addition to the usual restrictions of conditional independence. The enumeration restrictions are checked by `TraceEnum_ELBO` and will result in an error if violated (however the usual conditional independence restriction cannot be generally verified by Pyro). For completeness we list all three restrictions:\n", "\n", "#### Restriction 1: conditional independence\n", "Variables within a plate may not depend on each other (along the plate dimension). This applies to any variable, whether or not it is enumerated. This applies to both sequential plates and vectorized plates. For example the following model is invalid:\n", @@ -399,7 +491,7 @@ "def invalid_model():\n", " x = 0\n", " for i in pyro.plate(\"invalid\", 10):\n", - " x = pyro.sample(\"x_{}\".format(i), dist.Normal(x, 1.))\n", + " x = pyro.sample(f\"x_{i}\", dist.Normal(x, 1.))\n", "```\n", "\n", "#### Restriction 2: no downstream coupling\n", @@ -418,7 +510,7 @@ "def valid_model(data):\n", " x = []\n", " for i in pyro.plate(\"plate\", 10): # <--- valid sequential plate\n", - " x.append(pyro.sample(\"x_{}\".format(i), dist.Bernoulli(0.5)))\n", + " x.append(pyro.sample(f\"x_{i}\", dist.Bernoulli(0.5)))\n", " assert len(x) == 10\n", " pyro.sample(\"obs\", dist.Normal(sum(x), 1.), data)\n", "```\n", @@ -453,9 +545,9 @@ " with plate_1:\n", " x = pyro.sample(\"y\", dist.Bernoulli(0.5))\n", " for i in plate_2:\n", - " y = pyro.sample(\"x_{}\".format(i), dist.Bernoulli(0.5))\n", + " y = pyro.sample(f\"x_{i}\", dist.Bernoulli(0.5))\n", " with plate_1:\n", - " z = pyro.sample(\"z_{}\".format(i), dist.Bernoulli((1. + x + y) / 4.))\n", + " z = pyro.sample(f\"z_{i}\", dist.Bernoulli((1. + x + y) / 4.))\n", " ...\n", "```\n", "but beware that this increases the computational complexity, which may be exponential in the size of the sequential plate." @@ -472,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -481,7 +573,7 @@ "data = dist.Categorical(torch.ones(num_steps, data_dim)).sample()\n", "\n", "def hmm_model(data, data_dim, hidden_dim=10):\n", - " print('Running for {} time steps'.format(len(data)))\n", + " print(f\"Running for {len(data)} time steps\")\n", " # Sample global matrices wrt a Jeffreys prior.\n", " with pyro.plate(\"hidden_state\", hidden_dim):\n", " transition = pyro.sample(\"transition\", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))\n", @@ -489,10 +581,10 @@ "\n", " x = 0 # initial state\n", " for t, y in enumerate(data):\n", - " x = pyro.sample(\"x_{}\".format(t), dist.Categorical(transition[x]),\n", + " x = pyro.sample(f\"x_{t}\", dist.Categorical(transition[x]),\n", " infer={\"enumerate\": \"parallel\"})\n", - " pyro.sample(\"y_{}\".format(t), dist.Categorical(emission[x]), obs=y)\n", - " print(\"x_{}.shape = {}\".format(t, x.shape))" + " pyro.sample(f\" y_{t}\", dist.Categorical(emission[x]), obs=y)\n", + " print(f\" x_{t}.shape = {x.shape}\")" ] }, { @@ -504,7 +596,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -512,32 +604,32 @@ "output_type": "stream", "text": [ "Running for 10 time steps\n", - "x_0.shape = torch.Size([])\n", - "x_1.shape = torch.Size([])\n", - "x_2.shape = torch.Size([])\n", - "x_3.shape = torch.Size([])\n", - "x_4.shape = torch.Size([])\n", - "x_5.shape = torch.Size([])\n", - "x_6.shape = torch.Size([])\n", - "x_7.shape = torch.Size([])\n", - "x_8.shape = torch.Size([])\n", - "x_9.shape = torch.Size([])\n", + " x_0.shape = torch.Size([])\n", + " x_1.shape = torch.Size([])\n", + " x_2.shape = torch.Size([])\n", + " x_3.shape = torch.Size([])\n", + " x_4.shape = torch.Size([])\n", + " x_5.shape = torch.Size([])\n", + " x_6.shape = torch.Size([])\n", + " x_7.shape = torch.Size([])\n", + " x_8.shape = torch.Size([])\n", + " x_9.shape = torch.Size([])\n", "Running for 10 time steps\n", - "x_0.shape = torch.Size([10, 1])\n", - "x_1.shape = torch.Size([10, 1, 1])\n", - "x_2.shape = torch.Size([10, 1, 1, 1])\n", - "x_3.shape = torch.Size([10, 1, 1, 1, 1])\n", - "x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])\n", - "x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])\n", - "x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])\n", - "x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])\n", - "x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n", - "x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n" + " x_0.shape = torch.Size([10, 1])\n", + " x_1.shape = torch.Size([10, 1, 1])\n", + " x_2.shape = torch.Size([10, 1, 1, 1])\n", + " x_3.shape = torch.Size([10, 1, 1, 1, 1])\n", + " x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])\n", + " x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])\n", + " x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])\n", + " x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])\n", + " x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n", + " x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n" ] } ], "source": [ - "hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=[\"transition\", \"emission\"]))\n", + "hmm_guide = AutoNormal(poutine.block(hmm_model, expose=[\"transition\", \"emission\"]))\n", "\n", "pyro.clear_param_store()\n", "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n", @@ -548,7 +640,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Notice that the model was run twice here: first it was run without enumeration by `AutoDiagonalNormal`, so that the autoguide can record all sample sites; then second it is run by `TraceEnum_ELBO` with enumeration enabled. We see in the first run that samples have the standard interpretation, whereas in the second run samples have the enumeration interpretation.\n", + "Notice that the model was run twice here: first it was run without enumeration by `AutoNormal`, so that the autoguide can record all sample sites; then second it is run by `TraceEnum_ELBO` with enumeration enabled. We see in the first run that samples have the standard interpretation, whereas in the second run samples have the enumeration interpretation.\n", "\n", "For more complex examples, including minibatching and multiple plates, see the [HMM tutorial](https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py)." ] @@ -568,7 +660,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -596,10 +688,10 @@ "\n", " x = 0 # initial state\n", " for t, y in pyro.markov(enumerate(data)):\n", - " x = pyro.sample(\"x_{}\".format(t), dist.Categorical(transition[x]),\n", + " x = pyro.sample(f\"x_{t}\", dist.Categorical(transition[x]),\n", " infer={\"enumerate\": \"parallel\"})\n", - " pyro.sample(\"y_{}\".format(t), dist.Categorical(emission[x]), obs=y)\n", - " print(\"x_{}.shape = {}\".format(t, x.shape))\n", + " pyro.sample(f\"y_{t}\", dist.Categorical(emission[x]), obs=y)\n", + " print(f\"x_{t}.shape = {x.shape}\")\n", "\n", "# We'll reuse the same guide and elbo.\n", "elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);" @@ -636,7 +728,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.0" } }, "nbformat": 4, diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index 546a9679d8..154fd2b43f 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -6,7 +6,7 @@ "source": [ "# Tensor shapes in Pyro\n", "\n", - "This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html).\n", + "This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html). After this tutorial, you may want to also read about [enumeration](http://pyro.ai/examples/tensor_shapes.html).\n", "\n", "#### Summary:\n", "- Tensors broadcast by aligning on the right: `torch.ones(3,4,5) + torch.ones(5)`.\n", @@ -20,6 +20,9 @@ " - use negative indices like `x.sum(-1)` rather than `x.sum(2)`\n", " - use ellipsis notation like `pixel = image[..., i, j]`\n", " - use [Vindex](http://docs.pyro.ai/en/dev/ops.html#pyro.ops.indexing.Vindex) if `i,j` are enumerated, `pixel = Vindex(image)[..., i, j]`\n", + "- When using `pyro.plate`'s automatic subsampling, be sure to subsample your data:\n", + " - Either manually subample by capturing the index `with pyro.plate(...) as i: ...`\n", + " - or automatically subsample via `batch = pyro.subsample(data, event_dim=...)`.\n", "- When debugging, examine all shapes in a trace using [Trace.format_shapes()](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.Trace.format_shapes).\n", " \n", "#### Table of Contents\n", @@ -759,7 +762,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.10" + "version": "3.7.0" } }, "nbformat": 4, From ec315a3034b18e9c848b7be581cd2c114e73d84f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 2 Jul 2021 13:29:24 -0700 Subject: [PATCH 2/2] Fix link --- tutorial/source/tensor_shapes.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index 154fd2b43f..a74f48c7eb 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -6,7 +6,7 @@ "source": [ "# Tensor shapes in Pyro\n", "\n", - "This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html). After this tutorial, you may want to also read about [enumeration](http://pyro.ai/examples/tensor_shapes.html).\n", + "This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html). After this tutorial, you may want to also read about [enumeration](http://pyro.ai/examples/enumeration.html).\n", "\n", "#### Summary:\n", "- Tensors broadcast by aligning on the right: `torch.ones(3,4,5) + torch.ones(5)`.\n",