Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate Trace, TraceMessenger, & pyro.poutine.guide #3299

Merged
merged 6 commits into from
Dec 3, 2023
Merged

Conversation

ordabayevy
Copy link
Member

No description provided.

self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
self, name: str, prior: TorchDistributionMixin
) -> Union[TorchDistributionMixin, torch.Tensor]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _pyro_sample above it is expected that posterior has batch_shape. So I changed it here to TorchDistributionMixin because Distribution doesn't have batch_shape. My rule of thumb is that if only sample and log_prob methods are expected is to use Distribution and if additionally batch_shape and event_shape attributes are expected then to use TorchDistributionMixin

@ordabayevy ordabayevy changed the title Type annotate TraceMessenger Type annotate Trace, TraceMessenger, & pyro.poutine.guide Dec 3, 2023
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! I have just one comment on GuideMessenger.__call__()

return state

def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override]
def __call__(self, *args, **kwargs) -> Dict[str, Optional[torch.Tensor]]: # type: ignore[override]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised this is Optional[Tensor] rather than just Tensor. It might make more sense to assert or filter in the samples loop below, so callers know they are getting tensors.

Copy link
Member Author

@ordabayevy ordabayevy Dec 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added assert statements in the for loop (a bit longer code but probably better than filtering with if statement in the dict comprehension)

@fritzo fritzo merged commit 7233cf9 into dev Dec 3, 2023
9 checks passed
@ordabayevy ordabayevy deleted the type-trace branch December 3, 2023 21:36
@ordabayevy ordabayevy mentioned this pull request Feb 5, 2024
23 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants