diff --git a/tutorial/source/model_rendering.ipynb b/tutorial/source/model_rendering.ipynb
index d23125e0a9..c76a73e052 100644
--- a/tutorial/source/model_rendering.ipynb
+++ b/tutorial/source/model_rendering.ipynb
@@ -60,63 +60,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 3,
@@ -233,91 +179,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 6,
@@ -338,91 +202,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 7,
@@ -475,87 +257,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 9,
@@ -586,96 +290,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 10,
@@ -732,67 +349,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 12,
@@ -838,63 +397,9 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
+ "image/svg+xml": "\n\n\n\n\n",
"text/plain": [
- ""
+ ""
]
},
"execution_count": 14,
@@ -913,12 +418,122 @@
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "attachments": {},
+ "cell_type": "markdown",
"id": "837047a8",
"metadata": {},
+ "source": [
+ "# Rendering deterministic variables\n",
+ "\n",
+ "Pyro allows deterministic variables to be defined using `pyro.deterministic`. These variables can be rendered by setting `render_deterministic=True` in `pyro.render_model` as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "d90dc8d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def model_deterministic(data):\n",
+ " sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n",
+ " mu = pyro.param(\"mu\", torch.tensor([0.]))\n",
+ " x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n",
+ " log_y = pyro.sample(\"y\", dist.Normal(x, 1))\n",
+ " y = pyro.deterministic(\"y_deterministic\", log_y.exp())\n",
+ " with pyro.plate(\"N\", len(data)):\n",
+ " eps_z_loc = pyro.sample(\"eps_z_loc\", dist.Normal(0, 1))\n",
+ " z_loc = pyro.deterministic(\"z_loc\", eps_z_loc + x, event_dim=0)\n",
+ " pyro.sample(\"z\", dist.Normal(z_loc, y), obs=data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "6fcc43d8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data = torch.ones(10)\n",
+ "pyro.render_model(\n",
+ " model_deterministic,\n",
+ " model_args=(data,),\n",
+ " render_params=True,\n",
+ " render_distributions=True,\n",
+ " render_deterministic=True\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "9dff1ccc",
+ "metadata": {},
+ "source": [
+ "Another example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "279f6417",
+ "metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "def model(data):\n",
+ " a = pyro.sample(\"a\", dist.Normal(0, 1))\n",
+ " b = pyro.sample(\"b\", dist.Normal(a, 1))\n",
+ " c = pyro.sample(\"c\", dist.Normal(a, b.exp()))\n",
+ " d = pyro.sample(\"d\", dist.Bernoulli(logits=c), obs=torch.tensor(0.0))\n",
+ "\n",
+ " with pyro.plate(\"p\", len(data)):\n",
+ " e = pyro.sample(\"e\", dist.Normal(a, b.exp()))\n",
+ " f = pyro.deterministic(\"f\", e + 1)\n",
+ " g = pyro.sample(\"g\", dist.Delta(e + 1), obs=e + 1)\n",
+ " h = pyro.sample(\"h\", dist.Delta(e + 1))\n",
+ " i = pyro.sample(\"i\", dist.Normal(e, (f + g + h).exp()), obs=data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "591b6e0e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "obs = torch.ones(10)\n",
+ "pyro.render_model(\n",
+ " model,\n",
+ " model_args=(obs,),\n",
+ " render_distributions=True,\n",
+ " render_params=True,\n",
+ " render_deterministic=True,\n",
+ ")"
+ ]
}
],
"metadata": {