From c18c5fd383c206cfe236e93163d69895691e7697 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Wed, 1 Jun 2022 22:49:47 +0700 Subject: [PATCH] Use : for param domain when render --- pyro/infer/inspect.py | 2 +- tutorial/source/model_rendering.ipynb | 390 +++++++++++++------------- 2 files changed, 193 insertions(+), 199 deletions(-) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index 9bb5ca2520..851e31fd39 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -516,7 +516,7 @@ def render_graph( dist_label += rf"{rv} ~ {rv_dist}\l" if "constraint" in data and data["constraint"]: - dist_label += rf"{rv} ∈ {data['constraint']}\l" + dist_label += rf"{rv} : {data['constraint']}\l" graph.node("distribution_description_node", label=dist_label, shape="plaintext") diff --git a/tutorial/source/model_rendering.ipynb b/tutorial/source/model_rendering.ipynb index 63eb278bb3..861a947469 100644 --- a/tutorial/source/model_rendering.ipynb +++ b/tutorial/source/model_rendering.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "id": "8a068eb0", "metadata": {}, "outputs": [], @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "id": "855a7d8f", "metadata": {}, "outputs": [], @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 3, "id": "e1e9628e", "metadata": {}, "outputs": [ @@ -64,63 +64,62 @@ "\n", "\n", - "\n", - "\n", + "\n", "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_N\n", - "\n", - "N\n", + "\n", + "N\n", "\n", "\n", "\n", "m\n", - "\n", - "m\n", + "\n", + "m\n", "\n", "\n", "\n", "sd\n", - "\n", - "sd\n", + "\n", + "sd\n", "\n", "\n", "\n", "m->sd\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "obs\n", - "\n", - "obs\n", + "\n", + "obs\n", "\n", "\n", - "\n", + "\n", "m->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "sd->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 24, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -141,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "id": "700a0917", "metadata": {}, "outputs": [], @@ -162,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 5, "id": "01a4b74b", "metadata": {}, "outputs": [], @@ -228,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 6, "id": "311896af", "metadata": {}, "outputs": [ @@ -238,91 +237,90 @@ "\n", "\n", - "\n", - "\n", + "\n", "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_annotator\n", - "\n", - "annotator\n", + "\n", + "annotator\n", "\n", "\n", "cluster_item\n", - "\n", - "item\n", + "\n", + "item\n", "\n", "\n", "cluster_position\n", - "\n", - "position\n", + "\n", + "position\n", "\n", "\n", "\n", "ε\n", - "\n", - "ε\n", + "\n", + "ε\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", + "\n", + "y\n", "\n", "\n", "\n", "ε->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "θ\n", - "\n", - "θ\n", + "\n", + "θ\n", "\n", "\n", "\n", "s\n", - "\n", - "s\n", + "\n", + "s\n", "\n", "\n", "\n", "θ->s\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "c\n", - "\n", - "c\n", + "\n", + "c\n", "\n", "\n", "\n", "c->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "s->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 27, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -334,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 7, "id": "8babebb4", "metadata": {}, "outputs": [ @@ -344,91 +342,90 @@ "\n", "\n", - "\n", - "\n", + "\n", "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_annotator\n", - "\n", - "annotator\n", + "\n", + "annotator\n", "\n", "\n", "cluster_item\n", - "\n", - "item\n", + "\n", + "item\n", "\n", "\n", "cluster_position\n", - "\n", - "position\n", + "\n", + "position\n", "\n", "\n", "\n", "ε\n", - "\n", - "ε\n", + "\n", + "ε\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", + "\n", + "y\n", "\n", "\n", "\n", "ε->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "θ\n", - "\n", - "θ\n", + "\n", + "θ\n", "\n", "\n", "\n", "s\n", - "\n", - "s\n", + "\n", + "s\n", "\n", "\n", "\n", "θ->s\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "s->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "c\n", - "\n", - "c\n", + "\n", + "c\n", "\n", "\n", "\n", "c->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 28, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -456,7 +453,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 8, "id": "645df936", "metadata": {}, "outputs": [], @@ -472,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 9, "id": "66fc9f55", "metadata": {}, "outputs": [ @@ -482,87 +479,86 @@ "\n", "\n", - "\n", - "\n", + "\n", "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_N\n", - "\n", - "N\n", + "\n", + "N\n", "\n", "\n", "\n", "x\n", - "\n", - "x\n", + "\n", + "x\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", + "\n", + "y\n", "\n", "\n", "\n", "x->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "z\n", - "\n", - "z\n", + "\n", + "z\n", "\n", "\n", "\n", "x->z\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "y->z\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "mu\n", - "\n", - "mu\n", + "sigma\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "mu->x\n", - "\n", - "\n", + "sigma->x\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "sigma\n", - "\n", - "sigma\n", + "mu\n", + "\n", + "mu\n", "\n", - "\n", + "\n", "\n", - "sigma->x\n", - "\n", - "\n", + "mu->x\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 30, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -584,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 10, "id": "8130359c", "metadata": {}, "outputs": [ @@ -594,96 +590,95 @@ "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_N\n", - "\n", - "N\n", + "\n", + "N\n", "\n", "\n", "\n", "x\n", - "\n", - "x\n", + "\n", + "x\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", + "\n", + "y\n", "\n", "\n", "\n", "x->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "z\n", - "\n", - "z\n", + "\n", + "z\n", "\n", "\n", "\n", "x->z\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "y->z\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "mu\n", - "\n", - "mu\n", + "sigma\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "mu->x\n", - "\n", - "\n", + "sigma->x\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "sigma\n", - "\n", - "sigma\n", + "mu\n", + "\n", + "mu\n", "\n", - "\n", + "\n", "\n", - "sigma->x\n", - "\n", - "\n", + "mu->x\n", + "\n", + "\n", "\n", "\n", "\n", "distribution_description_node\n", - "x ~ Normal\n", - "y ~ LogNormal\n", - "z ~ Normal\n", - "sigma ∈ GreaterThan(lower_bound=0.0)\n", - "mu ∈ Real()\n", + "x ~ Normal\n", + "y ~ LogNormal\n", + "z ~ Normal\n", + "sigma : GreaterThan(lower_bound=0.0)\n", + "mu : Real()\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 31, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -713,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 11, "id": "9f1a4ebf", "metadata": {}, "outputs": [], @@ -731,7 +726,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 12, "id": "8514d7df", "metadata": {}, "outputs": [ @@ -741,67 +736,66 @@ "\n", "\n", - "\n", - "\n", + "\n", "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "cluster_plate1\n", - "\n", - "plate1\n", + "\n", + "plate1\n", "\n", "\n", "cluster_plate2\n", - "\n", - "plate2\n", + "\n", + "plate2\n", "\n", "\n", "cluster_plate2__CLONE\n", - "\n", - "plate2\n", + "\n", + "plate2\n", "\n", "\n", "\n", "x\n", - "\n", - "x\n", + "\n", + "x\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", + "\n", + "y\n", "\n", "\n", "\n", "x->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "z\n", - "\n", - "z\n", + "\n", + "z\n", "\n", "\n", "\n", "y->z\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 33, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -813,9 +807,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:root] *", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-root-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -827,7 +821,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.4" } }, "nbformat": 4,