Skip to content

Commit

Permalink
Render deterministic functions (#3266)
Browse files Browse the repository at this point in the history
* Render deterministic functions

* Fixes to include and render deterministic variables

* Added unit test and fixed errors

* Added rendering tutorial

* Required rv_label to render correctly

* Removed one rendering example

* Removed test.py
  • Loading branch information
r3v1 authored Sep 20, 2023
1 parent cc8e545 commit 99633ae
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 537 deletions.
99 changes: 76 additions & 23 deletions pyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Collection, Dict, List, Optional, Union

import torch

Expand All @@ -23,29 +23,42 @@
graphviz = SimpleNamespace(Digraph=object) # for type hints


def is_sample_site(msg):
def is_sample_site(msg, *, include_deterministic=False):
if msg["type"] != "sample":
return False
if site_is_subsample(msg):
return False

# Ignore masked observations.
if msg["is_observed"] and msg["mask"] is False:
return False
if not include_deterministic:
# Ignore masked observations.
if msg["is_observed"] and msg["mask"] is False:
return False

# Exclude deterministic sites.
fn = msg["fn"]
while hasattr(fn, "base_dist"):
fn = fn.base_dist
if type(fn).__name__ == "Delta":
return False
# Exclude deterministic sites.
fn = msg["fn"]
while hasattr(fn, "base_dist"):
fn = fn.base_dist
if type(fn).__name__ == "Delta":
return False

return True


def site_is_deterministic(msg: dict) -> bool:
return msg["type"] == "sample" and msg["infer"].get("_deterministic", False)


class TrackProvenance(Messenger):
def __init__(self, *, include_deterministic=False):
self.include_deterministic = include_deterministic

def _pyro_post_sample(self, msg):
if is_sample_site(msg):
if self.include_deterministic and site_is_deterministic(msg):
provenance = frozenset({msg["name"]}) # track only direct dependencies
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)

elif is_sample_site(msg, include_deterministic=self.include_deterministic):
provenance = frozenset({msg["name"]}) # track only direct dependencies
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)
Expand All @@ -62,6 +75,7 @@ def get_dependencies(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
) -> Dict[str, object]:
r"""
Infers dependency structure about a conditioned model.
Expand Down Expand Up @@ -169,6 +183,7 @@ def model_3():
:param callable model: A model.
:param tuple model_args: Optional tuple of model args.
:param dict model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:returns: A dictionary of metadata (see above).
:rtype: dict
"""
Expand All @@ -179,7 +194,7 @@ def model_3():

# Collect sites with tracked provenance.
with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance():
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)]

Expand Down Expand Up @@ -238,6 +253,7 @@ def get_model_relations(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
):
"""
Infer relations of RVs and plates from given model and optionally data.
Expand Down Expand Up @@ -271,6 +287,7 @@ def model(data):
:param callable model: A model to inspect.
:param model_args: Optional tuple of model args.
:param model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:rtype: dict
"""
if model_args is None:
Expand All @@ -281,7 +298,7 @@ def model(data):
assert isinstance(model_kwargs, dict)

with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance():
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

sample_sample = {}
Expand All @@ -301,9 +318,14 @@ def _get_type_from_frozenname(frozen_name):
if site["type"] != "sample" or site_is_subsample(site):
continue

provenance = get_provenance(
site["fn"].log_prob(site["value"])
if not site_is_deterministic(site)
else site["fn"].base_dist.log_prob(site["value"])
)
sample_sample[name] = [
upstream
for upstream in get_provenance(site["fn"].log_prob(site["value"]))
for upstream in provenance
if upstream != name and _get_type_from_frozenname(upstream) == "sample"
]

Expand All @@ -313,7 +335,11 @@ def _get_type_from_frozenname(frozen_name):
if upstream != name and _get_type_from_frozenname(upstream) == "param"
]

sample_dist[name] = _get_dist_name(site["fn"])
sample_dist[name] = (
_get_dist_name(site["fn"])
if not site_is_deterministic(site)
else "Deterministic"
)
for frame in site["cond_indep_stack"]:
plate_sample[frame.name].append(name)
if site["is_observed"]:
Expand All @@ -332,10 +358,15 @@ def _resolve_plate_samples(plate_samples):
return plate_samples

plate_sample = _resolve_plate_samples(plate_sample)
# convert set to list to keep order of variables
plate_sample = {
k: [name for name in trace.nodes if name in v] for k, v in plate_sample.items()
}

# Normalize order of variables.
def sort_by_time(names: Collection[str]) -> List[str]:
return [name for name in trace.nodes if name in names]

sample_sample = {k: sort_by_time(v) for k, v in sample_sample.items()}
sample_param = {k: sort_by_time(v) for k, v in sample_param.items()}
plate_sample = {k: sort_by_time(v) for k, v in plate_sample.items()}
observed = sort_by_time(observed)

return {
"sample_sample": sample_sample,
Expand Down Expand Up @@ -510,11 +541,22 @@ def render_graph(
# For sample_nodes - ellipse
if node_data[rv]["distribution"]:
shape = "ellipse"
rv_label = rv

# For param_nodes - No shape
else:
shape = "plain"
cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color)
rv_label = rv.replace("$params", "")

# use different symbol for Deterministic site
node_style = (
"filled,dashed"
if node_data[rv]["distribution"] == "Deterministic"
else "filled"
)
cur_graph.node(
rv, label=rv_label, shape=shape, style=node_style, fillcolor=color
)

# add leaf nodes first
while len(plate_data) >= 1:
Expand Down Expand Up @@ -560,6 +602,7 @@ def render_model(
filename: Optional[str] = None,
render_distributions: bool = False,
render_params: bool = False,
render_deterministic: bool = False,
) -> "graphviz.Digraph":
"""
Renders a model using `graphviz <https://graphviz.org>`_ .
Expand All @@ -577,12 +620,20 @@ def render_model(
:param bool render_distributions: Whether to include RV distribution
annotations (and param constraints) in the plot.
:param bool render_params: Whether to show params inthe plot.
:param bool render_deterministic: Whether to include deterministic sites.
:returns: A model graph.
:rtype: graphviz.Digraph
"""
# Get model relations.
if not isinstance(model_args, list) and not isinstance(model_kwargs, list):
relations = [get_model_relations(model, model_args, model_kwargs)]
relations = [
get_model_relations(
model,
model_args,
model_kwargs,
include_deterministic=render_deterministic,
)
]
else: # semisupervised
if isinstance(model_args, list):
if not isinstance(model_kwargs, list):
Expand All @@ -591,7 +642,9 @@ def render_model(
model_args = [model_args] * len(model_kwargs)
assert len(model_args) == len(model_kwargs)
relations = [
get_model_relations(model, args, kwargs)
get_model_relations(
model, args, kwargs, include_deterministic=render_deterministic
)
for args, kwargs in zip(model_args, model_kwargs)
]

Expand Down
108 changes: 107 additions & 1 deletion tests/infer/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyro
import pyro.distributions as dist
from pyro.distributions.testing.fakes import NonreparameterizedNormal
from pyro.infer.inspect import _deep_merge, get_dependencies
from pyro.infer.inspect import _deep_merge, get_dependencies, get_model_relations


@pytest.mark.parametrize("grad_enabled", [True, False])
Expand Down Expand Up @@ -450,3 +450,109 @@ def model():
def test_deep_merge(things, expected):
actual = _deep_merge(things)
assert actual == expected


@pytest.mark.parametrize("include_deterministic", [True, False])
def test_get_model_relations(include_deterministic):
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(a, 1))
c = pyro.sample("c", dist.Normal(a, b.exp()))
d = pyro.sample("d", dist.Bernoulli(logits=c), obs=torch.tensor(0.0))

with pyro.plate("p", len(data)):
e = pyro.sample("e", dist.Normal(a, b.exp()))
f = pyro.deterministic("f", e + 1)
g = pyro.sample("g", dist.Delta(e + 1), obs=e + 1)
h = pyro.sample("h", dist.Delta(e + 1))
i = pyro.sample("i", dist.Normal(e, (f + g + h).exp()), obs=data)

return [a, b, c, d, e, f, g, h, i]

data = torch.randn(3)
actual = get_model_relations(
model,
(data,),
include_deterministic=include_deterministic,
)

if include_deterministic:
expected = {
"observed": ["d", "f", "g", "i"],
"param_constraint": {},
"plate_sample": {"p": ["e", "f", "g", "h", "i"]},
"sample_dist": {
"a": "Normal",
"b": "Normal",
"c": "Normal",
"d": "Bernoulli",
"e": "Normal",
"f": "Deterministic",
"g": "Delta",
"h": "Delta",
"i": "Normal",
},
"sample_param": {
"a": [],
"b": [],
"c": [],
"d": [],
"e": [],
"f": [],
"g": [],
"h": [],
"i": [],
},
"sample_sample": {
"a": [],
"b": ["a"],
"c": ["a", "b"],
"d": ["c"],
"e": ["a", "b"],
"f": ["e"],
"g": ["e"],
"h": ["e"],
"i": ["e", "f", "g", "h"],
},
}
else:
expected = {
"sample_sample": {
"a": [],
"b": ["a"],
"c": ["a", "b"],
"d": ["c"],
"e": ["a", "b"],
"f": ["e"],
"g": ["e"],
"h": ["e"],
"i": ["e"],
},
"sample_param": {
"a": [],
"b": [],
"c": [],
"d": [],
"e": [],
"f": [],
"g": [],
"h": [],
"i": [],
},
"sample_dist": {
"a": "Normal",
"b": "Normal",
"c": "Normal",
"d": "Bernoulli",
"e": "Normal",
"f": "Deterministic",
"g": "Delta",
"h": "Delta",
"i": "Normal",
},
"param_constraint": {},
"plate_sample": {"p": ["e", "f", "g", "h", "i"]},
"observed": ["d", "f", "g", "i"],
}

assert actual == expected
Loading

0 comments on commit 99633ae

Please sign in to comment.