Rendering PyroModules can fail with local parameter mode enabled #3365

Closed
eb8680 opened this issue May 7, 2024 · 0 comments · Fixed by #3366

eb8680 commented May 7, 2024

When using module_local_params=True, calling pyro.render_model on a PyroModule with constrained parameters can fail with a KeyError, as shown by the test case in this stack trace:

_________________________________________________________________________________________________ test_render_constrained_param[True] __________________________________________________________________________________________________

use_module_local_params = True

    @pytest.mark.parametrize("use_module_local_params", [True, False])
    def test_render_constrained_param(use_module_local_params):
    
        class Model(PyroModule):
    
            @PyroParam(constraint=constraints.positive)
            def x(self):
                return torch.tensor(1.234)
    
            @PyroParam(constraint=constraints.real)
            def y(self):
                return torch.tensor(0.456)
    
            def forward(self):
                return self.x + self.y
    
        with pyro.settings.context(module_local_params=use_module_local_params):
            model = Model()
>           pyro.render_model(model)

tests/nn/test_module.py:1068: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
pyro/infer/inspect.py:630: in render_model
    get_model_relations(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

model = Model(), model_args = (), model_kwargs = {}, include_deterministic = False

    def get_model_relations(
        model: Callable,
        model_args: Optional[tuple] = None,
        model_kwargs: Optional[dict] = None,
        include_deterministic: bool = False,
    ):
        """
        Infer relations of RVs and plates from given model and optionally data.
        See https://github.com/pyro-ppl/pyro/issues/949 for more details.
    
        This returns a dictionary with keys:
    
        -  "sample_sample" map each downstream sample site to a list of the upstream
           sample sites on which it depend;
        -  "sample_dist" maps each sample site to the name of the distribution at
           that site;
        -  "plate_sample" maps each plate name to a list of the sample sites within
           that plate; and
        -  "observe" is a list of observed sample sites.
    
        For example for the model::
    
            def model(data):
                m = pyro.sample('m', dist.Normal(0, 1))
                sd = pyro.sample('sd', dist.LogNormal(m, 1))
                with pyro.plate('N', len(data)):
                    pyro.sample('obs', dist.Normal(m, sd), obs=data)
    
        the relation is::
    
            {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
             'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
             'plate_sample': {'N': ['obs']},
             'observed': ['obs']}
    
        :param callable model: A model to inspect.
        :param model_args: Optional tuple of model args.
        :param model_kwargs: Optional dict of model kwargs.
        :param bool include_deterministic: Whether to include deterministic sites.
        :rtype: dict
        """
        if model_args is None:
            model_args = ()
        if model_kwargs is None:
            model_kwargs = {}
        assert isinstance(model_args, tuple)
        assert isinstance(model_kwargs, dict)
    
        with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
            with TrackProvenance(include_deterministic=include_deterministic):
                trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    
        sample_sample = {}
        sample_param = {}
        sample_dist = {}
        param_constraint = {}
        plate_sample = defaultdict(list)
        observed = []
    
        def _get_type_from_frozenname(frozen_name):
            return trace.nodes[frozen_name]["type"]
    
        for name, site in trace.nodes.items():
            if site["type"] == "param":
>               param_constraint[name] = str(site["kwargs"]["constraint"])
E               KeyError: 'constraint'

pyro/infer/inspect.py:316: KeyError
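
The traceback shows a direct `site["kwargs"]["constraint"]` lookup failing for a param site. As a rough illustration only, and not necessarily what the fix in #3366 does, the loop could read the constraint defensively. The sketch below uses plain dictionaries in place of real Pyro trace nodes. It assumes that param sites recorded with module_local_params=True may lack a "constraint" entry in their kwargs; the helper name `collect_param_constraints` and the "unknown" placeholder are hypothetical.

```python
def collect_param_constraints(nodes, default="unknown"):
    """Map each param site name to a string describing its constraint.

    `nodes` stands in for `trace.nodes`: a dict of site name -> site dict.
    Sites whose kwargs carry no "constraint" entry get `default` instead of
    raising KeyError.
    """
    param_constraint = {}
    for name, site in nodes.items():
        if site["type"] != "param":
            continue  # only param sites carry constraints
        constraint = site.get("kwargs", {}).get("constraint", default)
        param_constraint[name] = str(constraint)
    return param_constraint


# Hypothetical stand-ins for trace nodes; real sites hold constraint objects.
nodes = {
    "x": {"type": "param", "kwargs": {}},  # mimics the failing site
    "y": {"type": "param", "kwargs": {"constraint": "Real()"}},
    "z": {"type": "sample", "kwargs": {}},
}
result = collect_param_constraints(nodes)
```

A fallback like this only avoids the crash: the rendered graph would no longer show that `x` is constrained to be positive, so making sure the constraint reaches the param site in the first place is likely the better fix.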