Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render deterministic functions #3266

Merged
merged 7 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 75 additions & 23 deletions pyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union, Collection
r3v1 marked this conversation as resolved.
Show resolved Hide resolved

import torch

Expand All @@ -23,29 +23,42 @@
graphviz = SimpleNamespace(Digraph=object) # for type hints


def is_sample_site(msg):
def is_sample_site(msg, *, include_deterministic=False):
if msg["type"] != "sample":
return False
if site_is_subsample(msg):
return False

# Ignore masked observations.
if msg["is_observed"] and msg["mask"] is False:
return False
if not include_deterministic:
# Ignore masked observations.
if msg["is_observed"] and msg["mask"] is False:
return False

# Exclude deterministic sites.
fn = msg["fn"]
while hasattr(fn, "base_dist"):
fn = fn.base_dist
if type(fn).__name__ == "Delta":
return False
# Exclude deterministic sites.
fn = msg["fn"]
while hasattr(fn, "base_dist"):
fn = fn.base_dist
if type(fn).__name__ == "Delta":
return False

return True


def site_is_deterministic(msg: dict) -> bool:
return msg["type"] == "sample" and msg["infer"].get("_deterministic", False)
r3v1 marked this conversation as resolved.
Show resolved Hide resolved


class TrackProvenance(Messenger):
def __init__(self, *, include_deterministic=False):
self.include_deterministic = include_deterministic

def _pyro_post_sample(self, msg):
if is_sample_site(msg):
if self.include_deterministic and site_is_deterministic(msg):
provenance = frozenset({msg["name"]}) # track only direct dependencies
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)

elif is_sample_site(msg, include_deterministic=self.include_deterministic):
provenance = frozenset({msg["name"]}) # track only direct dependencies
value = detach_provenance(msg["value"])
msg["value"] = ProvenanceTensor(value, provenance)
Expand All @@ -62,6 +75,7 @@ def get_dependencies(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
) -> Dict[str, object]:
r"""
Infers dependency structure about a conditioned model.
Expand Down Expand Up @@ -169,6 +183,7 @@ def model_3():
:param callable model: A model.
:param tuple model_args: Optional tuple of model args.
:param dict model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:returns: A dictionary of metadata (see above).
:rtype: dict
"""
Expand All @@ -179,7 +194,7 @@ def model_3():

# Collect sites with tracked provenance.
with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance():
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)]

Expand Down Expand Up @@ -238,6 +253,7 @@ def get_model_relations(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
):
"""
Infer relations of RVs and plates from given model and optionally data.
Expand Down Expand Up @@ -271,6 +287,7 @@ def model(data):
:param callable model: A model to inspect.
:param model_args: Optional tuple of model args.
:param model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:rtype: dict
"""
if model_args is None:
Expand All @@ -281,7 +298,7 @@ def model(data):
assert isinstance(model_kwargs, dict)

with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance():
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

sample_sample = {}
Expand All @@ -301,9 +318,14 @@ def _get_type_from_frozenname(frozen_name):
if site["type"] != "sample" or site_is_subsample(site):
continue

provenance = get_provenance(
site["fn"].log_prob(site["value"])
if not site_is_deterministic(site)
else site["fn"].base_dist.log_prob(site["value"])
)
sample_sample[name] = [
upstream
for upstream in get_provenance(site["fn"].log_prob(site["value"]))
for upstream in provenance
if upstream != name and _get_type_from_frozenname(upstream) == "sample"
]

Expand All @@ -313,7 +335,11 @@ def _get_type_from_frozenname(frozen_name):
if upstream != name and _get_type_from_frozenname(upstream) == "param"
]

sample_dist[name] = _get_dist_name(site["fn"])
sample_dist[name] = (
_get_dist_name(site["fn"])
if not site_is_deterministic(site)
else "Deterministic"
)
for frame in site["cond_indep_stack"]:
plate_sample[frame.name].append(name)
if site["is_observed"]:
Expand All @@ -332,10 +358,15 @@ def _resolve_plate_samples(plate_samples):
return plate_samples

plate_sample = _resolve_plate_samples(plate_sample)
# convert set to list to keep order of variables
plate_sample = {
k: [name for name in trace.nodes if name in v] for k, v in plate_sample.items()
}

# Normalize order of variables.
def sort_by_time(names: Collection[str]) -> List[str]:
return [name for name in trace.nodes if name in names]

sample_sample = {k: sort_by_time(v) for k, v in sample_sample.items()}
sample_param = {k: sort_by_time(v) for k, v in sample_param.items()}
plate_sample = {k: sort_by_time(v) for k, v in plate_sample.items()}
observed = sort_by_time(observed)

return {
"sample_sample": sample_sample,
Expand Down Expand Up @@ -510,11 +541,21 @@ def render_graph(
# For sample_nodes - ellipse
if node_data[rv]["distribution"]:
shape = "ellipse"
rv_label = rv

# For param_nodes - No shape
else:
shape = "plain"
cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color)

# use different symbol for Deterministic site
node_style = (
"filled,dashed"
if node_data[rv]["distribution"] == "Deterministic"
else "filled"
)
cur_graph.node(
rv, label=rv_label, shape=shape, style=node_style, fillcolor=color
)

# add leaf nodes first
while len(plate_data) >= 1:
Expand Down Expand Up @@ -560,6 +601,7 @@ def render_model(
filename: Optional[str] = None,
render_distributions: bool = False,
render_params: bool = False,
render_deterministic: bool = False,
) -> "graphviz.Digraph":
"""
Renders a model using `graphviz <https://graphviz.org>`_ .
Expand All @@ -577,12 +619,20 @@ def render_model(
:param bool render_distributions: Whether to include RV distribution
annotations (and param constraints) in the plot.
:param bool render_params: Whether to show params inthe plot.
:param bool render_deterministic: Whether to include deterministic sites.
:returns: A model graph.
:rtype: graphviz.Digraph
"""
# Get model relations.
if not isinstance(model_args, list) and not isinstance(model_kwargs, list):
relations = [get_model_relations(model, model_args, model_kwargs)]
relations = [
get_model_relations(
model,
model_args,
model_kwargs,
include_deterministic=render_deterministic,
)
]
else: # semisupervised
if isinstance(model_args, list):
if not isinstance(model_kwargs, list):
Expand All @@ -591,7 +641,9 @@ def render_model(
model_args = [model_args] * len(model_kwargs)
assert len(model_args) == len(model_kwargs)
relations = [
get_model_relations(model, args, kwargs)
get_model_relations(
model, args, kwargs, include_deterministic=render_deterministic
)
for args, kwargs in zip(model_args, model_kwargs)
]

Expand Down
33 changes: 33 additions & 0 deletions pyro/test.py
r3v1 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
from pyro import distributions as dist


def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(a, 1))
c = pyro.sample("c", dist.Normal(a, b.exp()))
d = pyro.sample("d", dist.Bernoulli(logits=c), obs=torch.tensor(0.0))

with pyro.plate("p", len(data)):
e = pyro.sample("e", dist.Normal(a, b.exp()))
f = pyro.deterministic("f", e + 1)
g = pyro.sample("g", dist.Delta(e + 1), obs=e + 1)
h = pyro.sample("h", dist.Delta(e + 1))
i = pyro.sample("i", dist.Normal(e, (f + g + h).exp()), obs=data)


obs = torch.ones(10)
g = pyro.render_model(
model,
model_args=(obs,),
render_distributions=True,
render_params=True,
render_deterministic=True,
)
g.format = "png"
g.render("model", view=False)
108 changes: 107 additions & 1 deletion tests/infer/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyro
import pyro.distributions as dist
from pyro.distributions.testing.fakes import NonreparameterizedNormal
from pyro.infer.inspect import _deep_merge, get_dependencies
from pyro.infer.inspect import _deep_merge, get_dependencies, get_model_relations


@pytest.mark.parametrize("grad_enabled", [True, False])
Expand Down Expand Up @@ -450,3 +450,109 @@ def model():
def test_deep_merge(things, expected):
actual = _deep_merge(things)
assert actual == expected


@pytest.mark.parametrize("include_deterministic", [True, False])
def test_get_model_relations(include_deterministic):
def model(data):
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(a, 1))
c = pyro.sample("c", dist.Normal(a, b.exp()))
d = pyro.sample("d", dist.Bernoulli(logits=c), obs=torch.tensor(0.0))

with pyro.plate("p", len(data)):
e = pyro.sample("e", dist.Normal(a, b.exp()))
f = pyro.deterministic("f", e + 1)
g = pyro.sample("g", dist.Delta(e + 1), obs=e + 1)
h = pyro.sample("h", dist.Delta(e + 1))
i = pyro.sample("i", dist.Normal(e, (f + g + h).exp()), obs=data)

return [a, b, c, d, e, f, g, h, i]

data = torch.randn(3)
actual = get_model_relations(
model,
(data,),
include_deterministic=include_deterministic,
)

if include_deterministic:
expected = {
"observed": ["d", "f", "g", "i"],
"param_constraint": {},
"plate_sample": {"p": ["e", "f", "g", "h", "i"]},
"sample_dist": {
"a": "Normal",
"b": "Normal",
"c": "Normal",
"d": "Bernoulli",
"e": "Normal",
"f": "Deterministic",
"g": "Delta",
"h": "Delta",
"i": "Normal",
},
"sample_param": {
"a": [],
"b": [],
"c": [],
"d": [],
"e": [],
"f": [],
"g": [],
"h": [],
"i": [],
},
"sample_sample": {
"a": [],
"b": ["a"],
"c": ["a", "b"],
"d": ["c"],
"e": ["a", "b"],
"f": ["e"],
"g": ["e"],
"h": ["e"],
"i": ["e", "f", "g", "h"],
},
}
else:
expected = {
"sample_sample": {
"a": [],
"b": ["a"],
"c": ["a", "b"],
"d": ["c"],
"e": ["a", "b"],
"f": ["e"],
"g": ["e"],
"h": ["e"],
"i": ["e"],
},
"sample_param": {
"a": [],
"b": [],
"c": [],
"d": [],
"e": [],
"f": [],
"g": [],
"h": [],
"i": [],
},
"sample_dist": {
"a": "Normal",
"b": "Normal",
"c": "Normal",
"d": "Bernoulli",
"e": "Normal",
"f": "Deterministic",
"g": "Delta",
"h": "Delta",
"i": "Normal",
},
"param_constraint": {},
"plate_sample": {"p": ["e", "f", "g", "h", "i"]},
"observed": ["d", "f", "g", "i"],
}

assert actual == expected
Loading
Loading