diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 1bc79a5521..31e524a651 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -68,7 +68,7 @@ def __init__( self, sites: Union[str, List[str]], type: Optional[str] = "sample", - keep_dist: bool = False, + keep_dist: Optional[bool] = False, ) -> None: super().__init__() self.sites = [sites] if isinstance(sites, str) else sites