Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.primitives & poutine.block_messenger #3292

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TorchDistributionMixin(Distribution):
from :class:`TorchDistributionMixin`.
"""

def __call__(self, sample_shape=torch.Size()):
def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
Samples a random value.

Expand All @@ -51,7 +51,7 @@ def __call__(self, sample_shape=torch.Size()):
)

@property
def event_dim(self):
def event_dim(self) -> int:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file I type annotated only methods needed for pyro.primitives

"""
:return: Number of dimensions of individual events.
:rtype: int
Expand Down
49 changes: 36 additions & 13 deletions pyro/poutine/block_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, List, Optional

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
def _block_fn(
expose: List[str],
expose_types: List[str],
hide: List[str],
hide_types: List[str],
hide_all: bool,
msg: Message,
) -> bool:
# handle observes
if msg["type"] == "sample" and msg["is_observed"]:
msg_type = "observe"
Expand All @@ -27,7 +36,14 @@ def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
return False


def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types):
def _make_default_hide_fn(
hide_all: bool,
expose_all: bool,
hide: Optional[List[str]],
expose: Optional[List[str]],
hide_types: Optional[List[str]],
expose_types: Optional[List[str]],
) -> Callable[[Message], bool]:
# first, some sanity checks:
# hide_all and expose_all intersect?
assert (hide_all is False and expose_all is False) or (
Expand Down Expand Up @@ -65,6 +81,14 @@ def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose
return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: Message) -> bool:
return not fn(msg)

return negated_fn


class BlockMessenger(Messenger):
"""
This handler selectively hides Pyro primitive sites from the outside world.
Expand Down Expand Up @@ -116,27 +140,26 @@ class BlockMessenger(Messenger):

def __init__(
self,
hide_fn=None,
expose_fn=None,
hide_all=True,
expose_all=False,
hide=None,
expose=None,
hide_types=None,
expose_types=None,
hide_fn: Optional[Callable[[Message], Optional[bool]]] = None,
expose_fn: Optional[Callable[[Message], Optional[bool]]] = None,
hide_all: bool = True,
expose_all: bool = False,
hide: Optional[List[str]] = None,
expose: Optional[List[str]] = None,
hide_types: Optional[List[str]] = None,
expose_types: Optional[List[str]] = None,
):
super().__init__()
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
if hide_fn is not None:
self.hide_fn = hide_fn
elif expose_fn is not None:
self.hide_fn = lambda msg: not expose_fn(msg)
self.hide_fn = _negate_fn(expose_fn)
else:
self.hide_fn = _make_default_hide_fn(
hide_all, expose_all, hide, expose, hide_types, expose_types
)

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
msg["stop"] = bool(self.hide_fn(msg))
return None
120 changes: 76 additions & 44 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,30 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)

import torch
from typing_extensions import TypedDict
from typing_extensions import ParamSpec, TypedDict

from pyro.params.param_store import ( # noqa: F401
_MODULE_NAMESPACE_DIVIDER,
ParamStoreDict,
)

P = ParamSpec("P")
T = TypeVar("T")

if TYPE_CHECKING:
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger
Expand All @@ -26,8 +40,8 @@


class Message(TypedDict, total=False):
type: Optional[str]
name: str
type: str
name: Optional[str]
fn: Callable
is_observed: bool
args: Tuple
Expand Down Expand Up @@ -252,15 +266,31 @@ def apply_stack(initial_msg: Message) -> None:
cont(msg)


def am_i_wrapped():
def am_i_wrapped() -> bool:
"""
Checks whether the current computation is wrapped in a poutine.
:returns: bool
"""
return len(_PYRO_STACK) > 0


def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable:
@overload
def effectful(
fn: None = ..., type: Optional[str] = ...
) -> Callable[[Callable[P, T]], Callable[..., Union[T, torch.Tensor, None]]]:
...


@overload
def effectful(
fn: Callable[P, T] = ..., type: Optional[str] = ...
) -> Callable[..., Union[T, torch.Tensor, None]]:
...


def effectful(
fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the best I could do at being specific with callable types.

"""
:param fn: function or callable that performs an effectful computation
:param str type: the type label of the operation, e.g. `"sample"`
Expand All @@ -277,32 +307,34 @@ def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Call
assert type != "message", "cannot use 'message' as keyword"

@functools.wraps(fn)
def _fn(*args, **kwargs):
name = kwargs.pop("name", None)
infer = kwargs.pop("infer", {})

value = kwargs.pop("obs", None)
is_observed = value is not None
def _fn(
*args: P.args,
name: Optional[str] = None,
infer: Optional[Dict] = None,
obs: Optional[torch.Tensor] = None,
**kwargs: P.kwargs,
) -> Union[T, torch.Tensor, None]:
is_observed = obs is not None

if not am_i_wrapped():
return fn(*args, **kwargs)
else:
msg = {
"type": type,
"name": name,
"fn": fn,
"is_observed": is_observed,
"args": args,
"kwargs": kwargs,
"value": value,
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
"infer": infer,
}
msg = Message(
type=type,
name=name,
fn=fn,
is_observed=is_observed,
args=args,
kwargs=kwargs,
value=obs,
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
infer=infer if infer is not None else {},
)
# apply the stack and return its return value
apply_stack(msg)
return msg["value"]
Expand All @@ -321,22 +353,22 @@ def _inspect() -> Message:
:returns: A message with all effects applied.
:rtype: dict
"""
msg: Message = {
"type": "inspect",
"name": "_pyro_inspect",
"fn": lambda: True,
"is_observed": False,
"args": (),
"kwargs": {},
"value": None,
"infer": {"_do_not_trace": True},
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
}
msg = Message(
type="inspect",
name="_pyro_inspect",
fn=lambda: True,
is_observed=False,
args=(),
kwargs={},
value=None,
infer={"_do_not_trace": True},
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
)
apply_stack(msg)
return msg

Expand Down
Loading
Loading