diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index a8216985ea..88a722fd5d 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -6,7 +6,7 @@ from collections import defaultdict from pathlib import Path from types import SimpleNamespace -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Collection, Dict, List, Optional, Union import torch @@ -23,29 +23,42 @@ graphviz = SimpleNamespace(Digraph=object) # for type hints -def is_sample_site(msg): +def is_sample_site(msg, *, include_deterministic=False): if msg["type"] != "sample": return False if site_is_subsample(msg): return False - # Ignore masked observations. - if msg["is_observed"] and msg["mask"] is False: - return False + if not include_deterministic: + # Ignore masked observations. + if msg["is_observed"] and msg["mask"] is False: + return False - # Exclude deterministic sites. - fn = msg["fn"] - while hasattr(fn, "base_dist"): - fn = fn.base_dist - if type(fn).__name__ == "Delta": - return False + # Exclude deterministic sites. + fn = msg["fn"] + while hasattr(fn, "base_dist"): + fn = fn.base_dist + if type(fn).__name__ == "Delta": + return False return True +def site_is_deterministic(msg: dict) -> bool: + return msg["type"] == "sample" and msg["infer"].get("_deterministic", False) + + class TrackProvenance(Messenger): + def __init__(self, *, include_deterministic=False): + self.include_deterministic = include_deterministic + def _pyro_post_sample(self, msg): - if is_sample_site(msg): + if self.include_deterministic and site_is_deterministic(msg): + provenance = frozenset({msg["name"]}) # track only direct dependencies + value = detach_provenance(msg["value"]) + msg["value"] = ProvenanceTensor(value, provenance) + + elif is_sample_site(msg, include_deterministic=self.include_deterministic): provenance = frozenset({msg["name"]}) # track only direct dependencies value = detach_provenance(msg["value"]) msg["value"] = ProvenanceTensor(value, provenance) @@ -62,6 +75,7 @@ def get_dependencies( model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, + include_deterministic: bool = False, ) -> Dict[str, object]: r""" Infers dependency structure about a conditioned model. @@ -169,6 +183,7 @@ def model_3(): :param callable model: A model. :param tuple model_args: Optional tuple of model args. :param dict model_kwargs: Optional dict of model kwargs. + :param bool include_deterministic: Whether to include deterministic sites. :returns: A dictionary of metadata (see above). :rtype: dict """ @@ -179,7 +194,7 @@ def model_3(): # Collect sites with tracked provenance. with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False): - with TrackProvenance(): + with TrackProvenance(include_deterministic=include_deterministic): trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)] @@ -238,6 +253,7 @@ def get_model_relations( model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, + include_deterministic: bool = False, ): """ Infer relations of RVs and plates from given model and optionally data. @@ -271,6 +287,7 @@ def model(data): :param callable model: A model to inspect. :param model_args: Optional tuple of model args. :param model_kwargs: Optional dict of model kwargs. + :param bool include_deterministic: Whether to include deterministic sites. :rtype: dict """ if model_args is None: @@ -281,7 +298,7 @@ def model(data): assert isinstance(model_kwargs, dict) with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False): - with TrackProvenance(): + with TrackProvenance(include_deterministic=include_deterministic): trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) sample_sample = {} @@ -301,9 +318,14 @@ def _get_type_from_frozenname(frozen_name): if site["type"] != "sample" or site_is_subsample(site): continue + provenance = get_provenance( + site["fn"].log_prob(site["value"]) + if not site_is_deterministic(site) + else site["fn"].base_dist.log_prob(site["value"]) + ) sample_sample[name] = [ upstream - for upstream in get_provenance(site["fn"].log_prob(site["value"])) + for upstream in provenance if upstream != name and _get_type_from_frozenname(upstream) == "sample" ] @@ -313,7 +335,11 @@ def _get_type_from_frozenname(frozen_name): if upstream != name and _get_type_from_frozenname(upstream) == "param" ] - sample_dist[name] = _get_dist_name(site["fn"]) + sample_dist[name] = ( + _get_dist_name(site["fn"]) + if not site_is_deterministic(site) + else "Deterministic" + ) for frame in site["cond_indep_stack"]: plate_sample[frame.name].append(name) if site["is_observed"]: @@ -332,10 +358,15 @@ def _resolve_plate_samples(plate_samples): return plate_samples plate_sample = _resolve_plate_samples(plate_sample) - # convert set to list to keep order of variables - plate_sample = { - k: [name for name in trace.nodes if name in v] for k, v in plate_sample.items() - } + + # Normalize order of variables. + def sort_by_time(names: Collection[str]) -> List[str]: + return [name for name in trace.nodes if name in names] + + sample_sample = {k: sort_by_time(v) for k, v in sample_sample.items()} + sample_param = {k: sort_by_time(v) for k, v in sample_param.items()} + plate_sample = {k: sort_by_time(v) for k, v in plate_sample.items()} + observed = sort_by_time(observed) return { "sample_sample": sample_sample, @@ -510,11 +541,22 @@ def render_graph( # For sample_nodes - ellipse if node_data[rv]["distribution"]: shape = "ellipse" + rv_label = rv # For param_nodes - No shape else: shape = "plain" - cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color) + rv_label = rv.replace("$params", "") + + # use different symbol for Deterministic site + node_style = ( + "filled,dashed" + if node_data[rv]["distribution"] == "Deterministic" + else "filled" + ) + cur_graph.node( + rv, label=rv_label, shape=shape, style=node_style, fillcolor=color + ) # add leaf nodes first while len(plate_data) >= 1: @@ -560,6 +602,7 @@ def render_model( filename: Optional[str] = None, render_distributions: bool = False, render_params: bool = False, + render_deterministic: bool = False, ) -> "graphviz.Digraph": """ Renders a model using `graphviz `_ . @@ -577,12 +620,20 @@ def render_model( :param bool render_distributions: Whether to include RV distribution annotations (and param constraints) in the plot. :param bool render_params: Whether to show params inthe plot. + :param bool render_deterministic: Whether to include deterministic sites. :returns: A model graph. :rtype: graphviz.Digraph """ # Get model relations. if not isinstance(model_args, list) and not isinstance(model_kwargs, list): - relations = [get_model_relations(model, model_args, model_kwargs)] + relations = [ + get_model_relations( + model, + model_args, + model_kwargs, + include_deterministic=render_deterministic, + ) + ] else: # semisupervised if isinstance(model_args, list): if not isinstance(model_kwargs, list): @@ -591,7 +642,9 @@ def render_model( model_args = [model_args] * len(model_kwargs) assert len(model_args) == len(model_kwargs) relations = [ - get_model_relations(model, args, kwargs) + get_model_relations( + model, args, kwargs, include_deterministic=render_deterministic + ) for args, kwargs in zip(model_args, model_kwargs) ] diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index 9173ad825d..8b72c621ca 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -7,7 +7,7 @@ import pyro import pyro.distributions as dist from pyro.distributions.testing.fakes import NonreparameterizedNormal -from pyro.infer.inspect import _deep_merge, get_dependencies +from pyro.infer.inspect import _deep_merge, get_dependencies, get_model_relations @pytest.mark.parametrize("grad_enabled", [True, False]) @@ -450,3 +450,109 @@ def model(): def test_deep_merge(things, expected): actual = _deep_merge(things) assert actual == expected + + +@pytest.mark.parametrize("include_deterministic", [True, False]) +def test_get_model_relations(include_deterministic): + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a, 1)) + c = pyro.sample("c", dist.Normal(a, b.exp())) + d = pyro.sample("d", dist.Bernoulli(logits=c), obs=torch.tensor(0.0)) + + with pyro.plate("p", len(data)): + e = pyro.sample("e", dist.Normal(a, b.exp())) + f = pyro.deterministic("f", e + 1) + g = pyro.sample("g", dist.Delta(e + 1), obs=e + 1) + h = pyro.sample("h", dist.Delta(e + 1)) + i = pyro.sample("i", dist.Normal(e, (f + g + h).exp()), obs=data) + + return [a, b, c, d, e, f, g, h, i] + + data = torch.randn(3) + actual = get_model_relations( + model, + (data,), + include_deterministic=include_deterministic, + ) + + if include_deterministic: + expected = { + "observed": ["d", "f", "g", "i"], + "param_constraint": {}, + "plate_sample": {"p": ["e", "f", "g", "h", "i"]}, + "sample_dist": { + "a": "Normal", + "b": "Normal", + "c": "Normal", + "d": "Bernoulli", + "e": "Normal", + "f": "Deterministic", + "g": "Delta", + "h": "Delta", + "i": "Normal", + }, + "sample_param": { + "a": [], + "b": [], + "c": [], + "d": [], + "e": [], + "f": [], + "g": [], + "h": [], + "i": [], + }, + "sample_sample": { + "a": [], + "b": ["a"], + "c": ["a", "b"], + "d": ["c"], + "e": ["a", "b"], + "f": ["e"], + "g": ["e"], + "h": ["e"], + "i": ["e", "f", "g", "h"], + }, + } + else: + expected = { + "sample_sample": { + "a": [], + "b": ["a"], + "c": ["a", "b"], + "d": ["c"], + "e": ["a", "b"], + "f": ["e"], + "g": ["e"], + "h": ["e"], + "i": ["e"], + }, + "sample_param": { + "a": [], + "b": [], + "c": [], + "d": [], + "e": [], + "f": [], + "g": [], + "h": [], + "i": [], + }, + "sample_dist": { + "a": "Normal", + "b": "Normal", + "c": "Normal", + "d": "Bernoulli", + "e": "Normal", + "f": "Deterministic", + "g": "Delta", + "h": "Delta", + "i": "Normal", + }, + "param_constraint": {}, + "plate_sample": {"p": ["e", "f", "g", "h", "i"]}, + "observed": ["d", "f", "g", "i"], + } + + assert actual == expected diff --git a/tutorial/source/model_rendering.ipynb b/tutorial/source/model_rendering.ipynb index d23125e0a9..71ec181b42 100644 --- a/tutorial/source/model_rendering.ipynb +++ b/tutorial/source/model_rendering.ipynb @@ -60,63 +60,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "m\n", - "\n", - "m\n", - "\n", - "\n", - "\n", - "sd\n", - "\n", - "sd\n", - "\n", - "\n", - "\n", - "m->sd\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "obs\n", - "\n", - "obs\n", - "\n", - "\n", - "\n", - "m->obs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sd->obs\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nm\n\nm\n\n\n\nsd\n\nsd\n\n\n\nm->sd\n\n\n\n\n\nobs\n\nobs\n\n\n\nm->obs\n\n\n\n\n\nsd->obs\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -233,91 +179,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_annotator\n", - "\n", - "annotator\n", - "\n", - "\n", - "cluster_item\n", - "\n", - "item\n", - "\n", - "\n", - "cluster_position\n", - "\n", - "position\n", - "\n", - "\n", - "\n", - "ε\n", - "\n", - "ε\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "ε->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "θ\n", - "\n", - "θ\n", - "\n", - "\n", - "\n", - "s\n", - "\n", - "s\n", - "\n", - "\n", - "\n", - "θ->s\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "c\n", - "\n", - "c\n", - "\n", - "\n", - "\n", - "c->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "s->y\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\ns->y\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -338,91 +202,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_annotator\n", - "\n", - "annotator\n", - "\n", - "\n", - "cluster_item\n", - "\n", - "item\n", - "\n", - "\n", - "cluster_position\n", - "\n", - "position\n", - "\n", - "\n", - "\n", - "ε\n", - "\n", - "ε\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "ε->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "θ\n", - "\n", - "θ\n", - "\n", - "\n", - "\n", - "s\n", - "\n", - "s\n", - "\n", - "\n", - "\n", - "θ->s\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "s->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "c\n", - "\n", - "c\n", - "\n", - "\n", - "\n", - "c->y\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\ns->y\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -475,87 +257,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "\n", - "\n", - "\n", - "sigma->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "\n", - "\n", - "\n", - "mu->x\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 9, @@ -586,96 +290,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "\n", - "\n", - "\n", - "sigma->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "\n", - "\n", - "\n", - "mu->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "distribution_description_node\n", - "x ~ Normal\n", - "y ~ LogNormal\n", - "z ~ Normal\n", - "sigma : GreaterThan(lower_bound=0.0)\n", - "mu : Real()\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ LogNormal\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -732,67 +349,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_plate1\n", - "\n", - "plate1\n", - "\n", - "\n", - "cluster_plate2\n", - "\n", - "plate2\n", - "\n", - "\n", - "cluster_plate2__CLONE\n", - "\n", - "plate2\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "x->y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "y->z\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_plate1\n\nplate1\n\n\ncluster_plate2\n\nplate2\n\n\ncluster_plate2__CLONE\n\nplate2\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\ny->z\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 12, @@ -838,63 +397,9 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "cluster_N\n", - "\n", - "N\n", - "\n", - "\n", - "\n", - "z\n", - "\n", - "z\n", - "\n", - "\n", - "\n", - "x\n", - "\n", - "x\n", - "\n", - "\n", - "\n", - "z->x\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "y\n", - "\n", - "\n", - "\n", - "y->x\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nz\n\nz\n\n\n\nx\n\nx\n\n\n\nz->x\n\n\n\n\n\ny\n\n\n\n\n\n\n\ny\n\n\n\ny->x\n\n\n\n\n\n", "text/plain": [ - "" + "" ] }, "execution_count": 14, @@ -913,12 +418,63 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "attachments": {}, + "cell_type": "markdown", "id": "837047a8", "metadata": {}, + "source": [ + "# Rendering deterministic variables\n", + "\n", + "Pyro allows deterministic variables to be defined using `pyro.deterministic`. These variables can be rendered by setting `render_deterministic=True` in `pyro.render_model` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d90dc8d7", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "def model_deterministic(data):\n", + " sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n", + " mu = pyro.param(\"mu\", torch.tensor([0.]))\n", + " x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n", + " log_y = pyro.sample(\"y\", dist.Normal(x, 1))\n", + " y = pyro.deterministic(\"y_deterministic\", log_y.exp())\n", + " with pyro.plate(\"N\", len(data)):\n", + " eps_z_loc = pyro.sample(\"eps_z_loc\", dist.Normal(0, 1))\n", + " z_loc = pyro.deterministic(\"z_loc\", eps_z_loc + x, event_dim=0)\n", + " pyro.sample(\"z\", dist.Normal(z_loc, y), obs=data)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6fcc43d8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz_loc\n\nz_loc\n\n\n\nx->z_loc\n\n\n\n\n\ny_deterministic\n\ny_deterministic\n\n\n\ny->y_deterministic\n\n\n\n\n\nz\n\nz\n\n\n\ny_deterministic->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\neps_z_loc\n\neps_z_loc\n\n\n\neps_z_loc->z_loc\n\n\n\n\n\nz_loc->z\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ Normal\ny_deterministic ~ Deterministic\neps_z_loc ~ Normal\nz_loc ~ Deterministic\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = torch.ones(10)\n", + "pyro.render_model(\n", + " model_deterministic,\n", + " model_args=(data,),\n", + " render_params=True,\n", + " render_distributions=True,\n", + " render_deterministic=True\n", + ")" + ] } ], "metadata": { @@ -937,7 +493,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.11.5" } }, "nbformat": 4,