diff --git a/tutorial/source/model_rendering.ipynb b/tutorial/source/model_rendering.ipynb index d23125e0a9..c76a73e052 100644 --- a/tutorial/source/model_rendering.ipynb +++ b/tutorial/source/model_rendering.ipynb @@ -60,63 +60,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "m\n", - "\n", - "m\n", - "\n", - "\n", - "\n", - "sd\n", - "\n", - "sd\n", - "\n", - "\n", - "\n", - "m->sd\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "obs\n", - "\n", - "obs\n", - "\n", - "\n", - "\n", - "m->obs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sd->obs\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nm\n\nm\n\n\n\nsd\n\nsd\n\n\n\nm->sd\n\n\n\n\n\nobs\n\nobs\n\n\n\nm->obs\n\n\n\n\n\nsd->obs\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -233,91 +179,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_annotator\n", - "\n", - "annotator\n", - "\n", - "\n", - "cluster_item\n", - "\n", - "item\n", - "\n", - "\n", - "cluster_position\n", - "\n", - "position\n", - "\n", - "\n", - "\n", - "ε\n", - "\n", - "ε\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "ε->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "θ\n", - "\n", - "θ\n", - "\n", - "\n", - "\n", - "s\n", - "\n", - "s\n", - "\n", - "\n", - "\n", - "θ->s\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "c\n", - "\n", - "c\n", - "\n", - "\n", - "\n", - "c->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "s->y\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\ns->y\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -338,91 +202,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_annotator\n", - "\n", - "annotator\n", - "\n", - "\n", - "cluster_item\n", - "\n", - "item\n", - "\n", - "\n", - "cluster_position\n", - "\n", - "position\n", - "\n", - "\n", - "\n", - "ε\n", - "\n", - "ε\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "ε->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "θ\n", - "\n", - "θ\n", - "\n", - "\n", - "\n", - "s\n", - "\n", - "s\n", - "\n", - "\n", - "\n", - "θ->s\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "s->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "c\n", - "\n", - "c\n", - "\n", - "\n", - "\n", - "c->y\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\ns->y\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -475,87 +257,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "\n", - "\n", - "\n", - "sigma->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "\n", - "\n", - "\n", - "mu->x\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 9, @@ -586,96 +290,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "\n", - "\n", - "\n", - "sigma->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "\n", - "\n", - "\n", - "mu->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "distribution_description_node\n", - "x ~ Normal\n", - "y ~ LogNormal\n", - "z ~ Normal\n", - "sigma : GreaterThan(lower_bound=0.0)\n", - "mu : Real()\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ LogNormal\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -732,67 +349,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_plate1\n", - "\n", - "plate1\n", - "\n", - "\n", - "cluster_plate2\n", - "\n", - "plate2\n", - "\n", - "\n", - "cluster_plate2__CLONE\n", - "\n", - "plate2\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_plate1\n\nplate1\n\n\ncluster_plate2\n\nplate2\n\n\ncluster_plate2__CLONE\n\nplate2\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\ny->z\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 12, @@ -838,63 +397,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "z->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "y->x\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nz\n\nz\n\n\n\nx\n\nx\n\n\n\nz->x\n\n\n\n\n\ny\n\n\n\n\n\n\n\ny\n\n\n\ny->x\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 14, @@ -913,12 +418,122 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "attachments": {}, + "cell_type": "markdown", "id": "837047a8", "metadata": {}, + "source": [ + "# Rendering deterministic variables\n", + "\n", + "Pyro allows deterministic variables to be defined using `pyro.deterministic`. These variables can be rendered by setting `render_deterministic=True` in `pyro.render_model` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d90dc8d7", + "metadata": {}, + "outputs": [], + "source": [ + "def model_deterministic(data):\n", + " sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n", + " mu = pyro.param(\"mu\", torch.tensor([0.]))\n", + " x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n", + " log_y = pyro.sample(\"y\", dist.Normal(x, 1))\n", + " y = pyro.deterministic(\"y_deterministic\", log_y.exp())\n", + " with pyro.plate(\"N\", len(data)):\n", + " eps_z_loc = pyro.sample(\"eps_z_loc\", dist.Normal(0, 1))\n", + " z_loc = pyro.deterministic(\"z_loc\", eps_z_loc + x, event_dim=0)\n", + " pyro.sample(\"z\", dist.Normal(z_loc, y), obs=data)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6fcc43d8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz_loc\n\nz_loc\n\n\n\nx->z_loc\n\n\n\n\n\ny_deterministic\n\ny_deterministic\n\n\n\ny->y_deterministic\n\n\n\n\n\nz\n\nz\n\n\n\ny_deterministic->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\neps_z_loc\n\neps_z_loc\n\n\n\neps_z_loc->z_loc\n\n\n\n\n\nz_loc->z\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ Normal\ny_deterministic ~ Deterministic\neps_z_loc ~ Normal\nz_loc ~ Deterministic\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = torch.ones(10)\n", + "pyro.render_model(\n", + " model_deterministic,\n", + " model_args=(data,),\n", + " render_params=True,\n", + " render_distributions=True,\n", + " render_deterministic=True\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9dff1ccc", + "metadata": {}, + "source": [ + "Another example:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "279f6417", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "def model(data):\n", + " a = pyro.sample(\"a\", dist.Normal(0, 1))\n", + " b = pyro.sample(\"b\", dist.Normal(a, 1))\n", + " c = pyro.sample(\"c\", dist.Normal(a, b.exp()))\n", + " d = pyro.sample(\"d\", dist.Bernoulli(logits=c), obs=torch.tensor(0.0))\n", + "\n", + " with pyro.plate(\"p\", len(data)):\n", + " e = pyro.sample(\"e\", dist.Normal(a, b.exp()))\n", + " f = pyro.deterministic(\"f\", e + 1)\n", + " g = pyro.sample(\"g\", dist.Delta(e + 1), obs=e + 1)\n", + " h = pyro.sample(\"h\", dist.Delta(e + 1))\n", + " i = pyro.sample(\"i\", dist.Normal(e, (f + g + h).exp()), obs=data)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "591b6e0e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_p\n\np\n\n\n\na\n\na\n\n\n\nb\n\nb\n\n\n\na->b\n\n\n\n\n\nc\n\nc\n\n\n\na->c\n\n\n\n\n\ne\n\ne\n\n\n\na->e\n\n\n\n\n\nb->c\n\n\n\n\n\nb->e\n\n\n\n\n\nd\n\nd\n\n\n\nc->d\n\n\n\n\n\nf\n\nf\n\n\n\ne->f\n\n\n\n\n\ng\n\ng\n\n\n\ne->g\n\n\n\n\n\nh\n\nh\n\n\n\ne->h\n\n\n\n\n\ni\n\ni\n\n\n\ne->i\n\n\n\n\n\nf->i\n\n\n\n\n\ng->i\n\n\n\n\n\nh->i\n\n\n\n\n\ndistribution_description_node\na ~ Normal\nb ~ Normal\nc ~ Normal\nd ~ Bernoulli\ne ~ Normal\nf ~ Deterministic\ng ~ Delta\nh ~ Delta\ni ~ Normal\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs = torch.ones(10)\n", + "pyro.render_model(\n", + " model,\n", + " model_args=(obs,),\n", + " render_distributions=True,\n", + " render_params=True,\n", + " render_deterministic=True,\n", + ")" + ] } ], "metadata": {