Type annotate Trace, TraceMessenger, & pyro.poutine.guide (#3299)

Changes from 4 commits
pyro/poutine/guide.py:

@@ -2,16 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC, abstractmethod
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union

 import torch

 import pyro.distributions as dist
-from pyro.distributions.distribution import Distribution
-
-from .trace_messenger import TraceMessenger
-from .trace_struct import Trace
-from .util import prune_subsample_sites, site_is_subsample
+from pyro.distributions.torch_distribution import TorchDistributionMixin
+from pyro.poutine.runtime import Message
+from pyro.poutine.trace_messenger import TraceMessenger
+from pyro.poutine.trace_struct import Trace
+from pyro.poutine.util import prune_subsample_sites, site_is_subsample


 class GuideMessenger(TraceMessenger, ABC):
@@ -21,22 +21,22 @@ class GuideMessenger(TraceMessenger, ABC):
     Derived classes must implement the :meth:`get_posterior` method.
     """

-    def __init__(self, model: Callable):
+    def __init__(self, model: Callable) -> None:
         super().__init__()
         # Do not register model as submodule
         self._model = (model,)

     @property
-    def model(self):
+    def model(self) -> Callable:
         return self._model[0]

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, object]:
         # Avoid pickling the trace.
-        state = super().__getstate__()
-        state.pop("trace")
+        state = self.__dict__.copy()
+        del state["trace"]
         return state

-    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[override]
+    def __call__(self, *args, **kwargs) -> Dict[str, Optional[torch.Tensor]]:  # type: ignore[override]
         """
         Draws posterior samples from the guide and replays the model against
         those samples.
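The `__getstate__` rewrite keeps the same behavior, dropping the per-run trace before pickling, but builds the state dict from `self.__dict__` directly rather than calling `super().__getstate__()`, which the base class does not define for the type checker to see. A standalone sketch of the pattern, with an invented class and attributes:

```python
import pickle


class Sketch:
    """Hypothetical messenger-like object with an unpicklable cache."""

    def __init__(self) -> None:
        self.weights = [1.0, 2.0]
        self.trace = lambda: None  # stand-in: lambdas cannot be pickled

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()  # shallow copy; don't mutate self
        del state["trace"]  # drop the per-run cache before pickling
        return state


restored = pickle.loads(pickle.dumps(Sketch()))
assert restored.weights == [1.0, 2.0]
assert not hasattr(restored, "trace")  # the cache was not serialized
```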
@@ -60,9 +60,12 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[override]
         }
         return samples

-    def _pyro_sample(self, msg):
+    def _pyro_sample(self, msg: Message) -> None:
         if msg["is_observed"] or site_is_subsample(msg):
             return
+        assert isinstance(msg["name"], str)
+        assert isinstance(msg["fn"], TorchDistributionMixin)
+        assert msg["infer"] is not None
         prior = msg["fn"]
         msg["infer"]["prior"] = prior
         posterior = self.get_posterior(msg["name"], prior)
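The three new asserts serve the type checker as much as runtime safety: the fields of the `Message` mapping are typed more loosely than this method needs (hence the isinstance/None checks), and each assert narrows the field to the concrete type the following lines rely on. A tiny self-contained illustration of assert-based narrowing:

```python
from typing import Optional, Union

import torch


def double(value: Optional[torch.Tensor]) -> torch.Tensor:
    # mypy sees Optional[torch.Tensor]; the assert narrows it to
    # torch.Tensor, so the arithmetic below type-checks.
    assert value is not None
    return value * 2


def name_length(name: Union[str, int]) -> int:
    # isinstance asserts narrow unions the same way.
    assert isinstance(name, str)
    return len(name)


print(double(torch.ones(3)))  # tensor([2., 2., 2.])
print(name_length("scale"))  # 5
```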
@@ -72,17 +75,20 @@
             posterior = posterior.expand(prior.batch_shape)
         msg["fn"] = posterior

-    def _pyro_post_sample(self, msg):
+    def _pyro_post_sample(self, msg: Message) -> None:
         # Manually apply outer plates.
+        assert msg["infer"] is not None
         prior = msg["infer"].get("prior")
-        if prior is not None and prior.batch_shape != msg["fn"].batch_shape:
-            msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
+        if prior is not None:
+            assert isinstance(msg["fn"], TorchDistributionMixin)
+            if prior.batch_shape != msg["fn"].batch_shape:
+                msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
         return super()._pyro_post_sample(msg)

     @abstractmethod
     def get_posterior(
-        self, name: str, prior: Distribution
-    ) -> Union[Distribution, torch.Tensor]:
+        self, name: str, prior: TorchDistributionMixin
+    ) -> Union[TorchDistributionMixin, torch.Tensor]:
         """
         Abstract method to compute a posterior distribution or sample a
         posterior value given a prior distribution conditioned on upstream
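Since `get_posterior` is the abstract method subclasses implement, the signature change from `Distribution` to `TorchDistributionMixin` is the user-visible part of this diff. For illustration, a minimal hypothetical mean-field subclass (names and initialization are invented; this is a sketch, not code from the PR):

```python
import torch

import pyro
import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.guide import GuideMessenger


class MeanFieldGuide(GuideMessenger):
    """Hypothetical guide: an independent learnable Normal per site."""

    def get_posterior(
        self, name: str, prior: TorchDistributionMixin
    ) -> TorchDistributionMixin:
        # Learnable parameters shaped like the prior's samples.
        loc = pyro.param(name + "_loc", lambda: torch.zeros(prior.shape()))
        scale = pyro.param(
            name + "_scale",
            lambda: torch.ones(prior.shape()),
            constraint=dist.constraints.positive,
        )
        # Match the prior's event dimensions.
        return dist.Normal(loc, scale).to_event(prior.event_dim)
```

A guide like this would be used as `MeanFieldGuide(model)(*args, **kwargs)`, which per the `__call__` docstring draws posterior samples and replays the model against them.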
@@ -112,7 +118,7 @@ def get_posterior(
         """
         raise NotImplementedError

-    def upstream_value(self, name: str):
+    def upstream_value(self, name: str) -> Optional[torch.Tensor]:
         """
         For use in :meth:`get_posterior` .
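The `Optional[torch.Tensor]` return here propagates into subclasses: anything read via `upstream_value` inside `get_posterior` needs a None check (or assert) before use. A hypothetical sketch where one site's posterior is centered on an upstream sample (model structure invented for illustration):

```python
import torch

import pyro
import pyro.distributions as dist
from pyro.poutine.guide import GuideMessenger


class ChainGuide(GuideMessenger):
    """Hypothetical guide where y's posterior depends on x's sample."""

    def get_posterior(self, name, prior):
        if name == "x":
            loc = pyro.param("x_loc", lambda: torch.zeros(()))
            return dist.Normal(loc, 1.0)
        if name == "y":
            x = self.upstream_value("x")
            assert x is not None  # narrow Optional[torch.Tensor]
            return dist.Normal(x, 1.0)
        return prior  # default to the prior at any other site
```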
Review comment:
I'm surprised this is Optional[Tensor] rather than just Tensor. It might make more sense to assert or filter in the samples loop below, so callers know they are getting tensors.

Reply:
Added assert statements in the for loop (a bit longer code, but probably better than filtering with an if statement in the dict comprehension).
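The updated loop itself is outside the displayed hunks, but the fix described in the reply, asserting inside a for loop instead of filtering in the dict comprehension, would look roughly like this sketch (the helper name and the node-dict typing are assumptions):

```python
from typing import Dict

import torch


def collect_samples(nodes: Dict[str, Dict[str, object]]) -> Dict[str, torch.Tensor]:
    """Sketch: gather sample-site values, asserting each is a Tensor."""
    samples: Dict[str, torch.Tensor] = {}
    for name, site in nodes.items():
        if site["type"] == "sample":
            value = site["value"]
            # The assert both guards at runtime and narrows the type
            # for mypy, so callers get Dict[str, torch.Tensor].
            assert isinstance(value, torch.Tensor)
            samples[name] = value
    return samples
```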