-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type annotate pyro.poutine.runtime
#3288
Changes from all commits
df1e471
a6538af
f8faddf
4284574
13690ce
7b958f1
b475541
51244e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,48 @@ | ||
# Copyright (c) 2017-2019 Uber Technologies, Inc. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import functools | ||
from typing import Dict | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union | ||
|
||
import torch | ||
from typing_extensions import TypedDict | ||
|
||
from pyro.params.param_store import ( # noqa: F401 | ||
_MODULE_NAMESPACE_DIVIDER, | ||
ParamStoreDict, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyro.poutine.indep_messenger import CondIndepStackFrame | ||
from pyro.poutine.messenger import Messenger | ||
|
||
# the global pyro stack | ||
_PYRO_STACK = [] | ||
_PYRO_STACK: List[Messenger] = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is great! |
||
|
||
# the global ParamStore | ||
_PYRO_PARAM_STORE = ParamStoreDict() | ||
|
||
|
||
class Message(TypedDict, total=False): | ||
type: Optional[str] | ||
name: str | ||
fn: Callable | ||
is_observed: bool | ||
args: Tuple | ||
kwargs: Dict | ||
value: Optional[torch.Tensor] | ||
scale: float | ||
mask: Union[bool, torch.Tensor, None] | ||
cond_indep_stack: Tuple[CondIndepStackFrame, ...] | ||
done: bool | ||
stop: bool | ||
continuation: Optional[Callable[[Message], None]] | ||
infer: Optional[Dict[str, Union[str, bool]]] | ||
obs: Optional[torch.Tensor] | ||
|
||
|
||
class _DimAllocator: | ||
""" | ||
Dimension allocator for internal use by :class:`plate`. | ||
|
@@ -24,26 +51,25 @@ class _DimAllocator: | |
Note that dimensions are indexed from the right, e.g. -1, -2. | ||
""" | ||
|
||
def __init__(self): | ||
self._stack = [] # in reverse orientation of log_prob.shape | ||
def __init__(self) -> None: | ||
# in reverse orientation of log_prob.shape | ||
self._stack: List[Optional[str]] = [] | ||
|
||
def allocate(self, name, dim): | ||
def allocate(self, name: str, dim: Optional[int]) -> int: | ||
""" | ||
Allocate a dimension to an :class:`plate` with given name. | ||
Dim should be either None for automatic allocation or a negative | ||
integer for manual allocation. | ||
""" | ||
if name in self._stack: | ||
raise ValueError('duplicate plate "{}"'.format(name)) | ||
raise ValueError(f"duplicate plate '{name}'") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to change, but I've been using the notation |
||
if dim is None: | ||
# Automatically designate the rightmost available dim for allocation. | ||
dim = -1 | ||
while -dim <= len(self._stack) and self._stack[-1 - dim] is not None: | ||
dim -= 1 | ||
elif dim >= 0: | ||
raise ValueError( | ||
"Expected dim < 0 to index from the right, actual {}".format(dim) | ||
) | ||
raise ValueError(f"Expected dim < 0 to index from the right, actual {dim}") | ||
|
||
# Allocate the requested dimension. | ||
while dim < -len(self._stack): | ||
|
@@ -64,7 +90,7 @@ def allocate(self, name, dim): | |
self._stack[-1 - dim] = name | ||
return dim | ||
|
||
def free(self, name, dim): | ||
def free(self, name: str, dim: int) -> None: | ||
""" | ||
Free a dimension. | ||
""" | ||
|
@@ -88,7 +114,7 @@ class _EnumAllocator: | |
Note that ids are simply nonnegative integers here. | ||
""" | ||
|
||
def set_first_available_dim(self, first_available_dim): | ||
def set_first_available_dim(self, first_available_dim: int) -> None: | ||
""" | ||
Set the first available dim, which should be to the left of all | ||
:class:`plate` dimensions, e.g. ``-1 - max_plate_nesting``. This should | ||
|
@@ -98,9 +124,9 @@ def set_first_available_dim(self, first_available_dim): | |
assert first_available_dim < 0, first_available_dim | ||
self.next_available_dim = first_available_dim | ||
self.next_available_id = 0 | ||
self.dim_to_id = {} # only the global ids | ||
self.dim_to_id: Dict[int, int] = {} # only the global ids | ||
|
||
def allocate(self, scope_dims=None): | ||
def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int, int]: | ||
""" | ||
Allocate a new recyclable dim and a unique id. | ||
|
||
|
@@ -146,26 +172,28 @@ class NonlocalExit(Exception): | |
Used by poutine.EscapeMessenger to return site information. | ||
""" | ||
|
||
def __init__(self, site, *args, **kwargs): | ||
def __init__(self, site: Message, *args, **kwargs) -> None: | ||
""" | ||
:param site: message at a pyro site constructor. | ||
Just stores the input site. | ||
""" | ||
super().__init__(*args, **kwargs) | ||
self.site = site | ||
|
||
def reset_stack(self): | ||
def reset_stack(self) -> None: | ||
""" | ||
Reset the state of the frames remaining in the stack. | ||
Necessary for multiple re-executions in poutine.queue. | ||
""" | ||
from pyro.poutine.block_messenger import BlockMessenger | ||
|
||
for frame in reversed(_PYRO_STACK): | ||
frame._reset() | ||
if type(frame).__name__ == "BlockMessenger" and frame.hide_fn(self.site): | ||
if isinstance(frame, BlockMessenger) and frame.hide_fn(self.site): | ||
break | ||
|
||
|
||
def default_process_message(msg): | ||
def default_process_message(msg: Message) -> None: | ||
""" | ||
Default method for processing messages in inference. | ||
|
||
|
@@ -174,15 +202,15 @@ def default_process_message(msg): | |
""" | ||
if msg["done"] or msg["is_observed"] or msg["value"] is not None: | ||
msg["done"] = True | ||
return msg | ||
return | ||
|
||
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) | ||
|
||
# after fn has been called, update msg to prevent it from being called again. | ||
msg["done"] = True | ||
|
||
|
||
def apply_stack(initial_msg): | ||
def apply_stack(initial_msg: Message) -> None: | ||
""" | ||
Execute the effect stack at a single site according to the following scheme: | ||
|
||
|
@@ -223,8 +251,6 @@ def apply_stack(initial_msg): | |
if cont is not None: | ||
cont(msg) | ||
|
||
return None | ||
|
||
|
||
def am_i_wrapped(): | ||
""" | ||
|
@@ -234,7 +260,7 @@ def am_i_wrapped(): | |
return len(_PYRO_STACK) > 0 | ||
|
||
|
||
def effectful(fn=None, type=None): | ||
def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable: | ||
""" | ||
:param fn: function or callable that performs an effectful computation | ||
:param str type: the type label of the operation, e.g. `"sample"` | ||
|
@@ -247,7 +273,7 @@ def effectful(fn=None, type=None): | |
if getattr(fn, "_is_effectful", None): | ||
return fn | ||
|
||
assert type is not None, "must provide a type label for operation {}".format(fn) | ||
assert type is not None, f"must provide a type label for operation {fn}" | ||
assert type != "message", "cannot use 'message' as keyword" | ||
|
||
@functools.wraps(fn) | ||
|
@@ -281,11 +307,11 @@ def _fn(*args, **kwargs): | |
apply_stack(msg) | ||
return msg["value"] | ||
|
||
_fn._is_effectful = True | ||
_fn._is_effectful = True # type: ignore[attr-defined] | ||
return _fn | ||
|
||
|
||
def _inspect() -> Dict[str, object]: | ||
def _inspect() -> Message: | ||
""" | ||
EXPERIMENTAL Inspect the Pyro stack. | ||
|
||
|
@@ -295,7 +321,7 @@ def _inspect() -> Dict[str, object]: | |
:returns: A message with all effects applied. | ||
:rtype: dict | ||
""" | ||
msg = { | ||
msg: Message = { | ||
"type": "inspect", | ||
"name": "_pyro_inspect", | ||
"fn": lambda: True, | ||
|
@@ -315,7 +341,7 @@ def _inspect() -> Dict[str, object]: | |
return msg | ||
|
||
|
||
def get_mask(): | ||
def get_mask() -> Union[bool, torch.Tensor, None]: | ||
""" | ||
Records the effects of enclosing ``poutine.mask`` handlers. | ||
|
||
|
@@ -335,7 +361,7 @@ def model(): | |
return _inspect()["mask"] | ||
|
||
|
||
def get_plates() -> tuple: | ||
def get_plates() -> Tuple[CondIndepStackFrame, ...]: | ||
""" | ||
Records the effects of enclosing ``pyro.plate`` contexts. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ column_limit = 120 | |
|
||
[mypy] | ||
python_version = 3.8 | ||
explicit_package_bases = True | ||
warn_return_any = True | ||
warn_unused_configs = True | ||
warn_incomplete_stub = True | ||
|
@@ -77,11 +78,9 @@ warn_unused_ignores = True | |
ignore_errors = True | ||
warn_unused_ignores = True | ||
|
||
[mypy-pyro.optm.*] | ||
warn_unused_ignores = True | ||
|
||
[mypy-pyro.poutine.*] | ||
[mypy-pyro.optim.*] | ||
ignore_errors = True | ||
warn_unused_ignores = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add |
||
|
||
[mypy-pyro.util.*] | ||
ignore_errors = True | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Temporarily removed the return type (changed to Any) so that mypy wouldn't complain. Will be fixed in a later PR.