You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
When using module_local_params=True, calling pyro.render_model on a PyroModule with constrained parameters can fail with a KeyError, as shown by the test case in this stack trace:
_________________________________________________________________________________________________ test_render_constrained_param[True] __________________________________________________________________________________________________
use_module_local_params = True
@pytest.mark.parametrize("use_module_local_params", [True, False])
def test_render_constrained_param(use_module_local_params):
class Model(PyroModule):
@PyroParam(constraint=constraints.positive)
def x(self):
return torch.tensor(1.234)
@PyroParam(constraint=constraints.real)
def y(self):
return torch.tensor(0.456)
def forward(self):
return self.x + self.y
with pyro.settings.context(module_local_params=use_module_local_params):
model = Model()
> pyro.render_model(model)
tests/nn/test_module.py:1068:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/inspect.py:630: in render_model
get_model_relations(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
model = Model(), model_args = (), model_kwargs = {}, include_deterministic = False
def get_model_relations(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
):
"""
Infer relations of RVs and plates from given model and optionally data.
See https://github.com/pyro-ppl/pyro/issues/949 for more details.
This returns a dictionary with keys:
- "sample_sample" map each downstream sample site to a list of the upstream
sample sites on which it depend;
- "sample_dist" maps each sample site to the name of the distribution at
that site;
- "plate_sample" maps each plate name to a list of the sample sites within
that plate; and
- "observe" is a list of observed sample sites.
For example for the model::
def model(data):
m = pyro.sample('m', dist.Normal(0, 1))
sd = pyro.sample('sd', dist.LogNormal(m, 1))
with pyro.plate('N', len(data)):
pyro.sample('obs', dist.Normal(m, sd), obs=data)
the relation is::
{'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
'plate_sample': {'N': ['obs']},
'observed': ['obs']}
:param callable model: A model to inspect.
:param model_args: Optional tuple of model args.
:param model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:rtype: dict
"""
if model_args is None:
model_args = ()
if model_kwargs is None:
model_kwargs = {}
assert isinstance(model_args, tuple)
assert isinstance(model_kwargs, dict)
with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
sample_sample = {}
sample_param = {}
sample_dist = {}
param_constraint = {}
plate_sample = defaultdict(list)
observed = []
def _get_type_from_frozenname(frozen_name):
return trace.nodes[frozen_name]["type"]
for name, site in trace.nodes.items():
if site["type"] == "param":
> param_constraint[name] = str(site["kwargs"]["constraint"])
E KeyError: 'constraint'
pyro/infer/inspect.py:316: KeyError
The text was updated successfully, but these errors were encountered:
When using `module_local_params=True`, calling `pyro.render_model` on a `PyroModule` with constrained parameters can fail with a `KeyError`, as shown by the test case in this stack trace:
The text was updated successfully, but these errors were encountered: