-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type annotate Trace
, TraceMessenger
, & pyro.poutine.guide
#3299
Conversation
self, name: str, prior: Distribution | ||
) -> Union[Distribution, torch.Tensor]: | ||
self, name: str, prior: TorchDistributionMixin | ||
) -> Union[TorchDistributionMixin, torch.Tensor]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In _pyro_sample
above it is expected that posterior
has batch_shape
. So I changed it here to TorchDistributionMixin
because Distribution
doesn't have batch_shape
. My rule of thumb is that if only sample
and log_prob
methods are expected is to use Distribution
and if additionally batch_shape
and event_shape
attributes are expected then to use TorchDistributionMixin
Trace
, TraceMessenger
, & pyro.poutine.guide
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! I have just one comment on GuideMessenger.__call__()
pyro/poutine/guide.py
Outdated
return state | ||
|
||
def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override] | ||
def __call__(self, *args, **kwargs) -> Dict[str, Optional[torch.Tensor]]: # type: ignore[override] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm surprised this is Optional[Tensor]
rather than just Tensor
. It might make more sense to assert or filter in the samples loop below, so callers know they are getting tensors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added assert statements in the for loop (a bit longer code but probably better than filtering with if statement in the dict comprehension)
No description provided.