Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.poutine.runtime #3288

Merged
merged 8 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyro/poutine/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from pyro.distributions import Categorical
from pyro.distributions import Categorical # type: ignore[attr-defined]
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_posterior(
"""
raise NotImplementedError

def upstream_value(self, name: str) -> torch.Tensor:
def upstream_value(self, name: str):
Copy link
Member Author

@ordabayevy ordabayevy Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily removed the return type (changed to Any) so that mypy wouldn't complain. Will be fixed in a later PR.

"""
For use in :meth:`get_posterior` .

Expand Down
82 changes: 54 additions & 28 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,48 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import functools
from typing import Dict
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from typing_extensions import TypedDict

from pyro.params.param_store import ( # noqa: F401
_MODULE_NAMESPACE_DIVIDER,
ParamStoreDict,
)

if TYPE_CHECKING:
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger

# the global pyro stack
_PYRO_STACK = []
_PYRO_STACK: List[Messenger] = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great!


# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()


class Message(TypedDict, total=False):
type: Optional[str]
name: str
fn: Callable
is_observed: bool
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
scale: float
mask: Union[bool, torch.Tensor, None]
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
done: bool
stop: bool
continuation: Optional[Callable[[Message], None]]
infer: Optional[Dict[str, Union[str, bool]]]
obs: Optional[torch.Tensor]


class _DimAllocator:
"""
Dimension allocator for internal use by :class:`plate`.
Expand All @@ -24,26 +51,25 @@ class _DimAllocator:
Note that dimensions are indexed from the right, e.g. -1, -2.
"""

def __init__(self):
self._stack = [] # in reverse orientation of log_prob.shape
def __init__(self) -> None:
# in reverse orientation of log_prob.shape
self._stack: List[Optional[str]] = []

def allocate(self, name, dim):
def allocate(self, name: str, dim: Optional[int]) -> int:
"""
Allocate a dimension to an :class:`plate` with given name.
Dim should be either None for automatic allocation or a negative
integer for manual allocation.
"""
if name in self._stack:
raise ValueError('duplicate plate "{}"'.format(name))
raise ValueError(f"duplicate plate '{name}'")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change, but I've been using the notation f"duplicate plate {name!r}" which I find easier to read.

if dim is None:
# Automatically designate the rightmost available dim for allocation.
dim = -1
while -dim <= len(self._stack) and self._stack[-1 - dim] is not None:
dim -= 1
elif dim >= 0:
raise ValueError(
"Expected dim < 0 to index from the right, actual {}".format(dim)
)
raise ValueError(f"Expected dim < 0 to index from the right, actual {dim}")

# Allocate the requested dimension.
while dim < -len(self._stack):
Expand All @@ -64,7 +90,7 @@ def allocate(self, name, dim):
self._stack[-1 - dim] = name
return dim

def free(self, name, dim):
def free(self, name: str, dim: int) -> None:
"""
Free a dimension.
"""
Expand All @@ -88,7 +114,7 @@ class _EnumAllocator:
Note that ids are simply nonnegative integers here.
"""

def set_first_available_dim(self, first_available_dim):
def set_first_available_dim(self, first_available_dim: int) -> None:
"""
Set the first available dim, which should be to the left of all
:class:`plate` dimensions, e.g. ``-1 - max_plate_nesting``. This should
Expand All @@ -98,9 +124,9 @@ def set_first_available_dim(self, first_available_dim):
assert first_available_dim < 0, first_available_dim
self.next_available_dim = first_available_dim
self.next_available_id = 0
self.dim_to_id = {} # only the global ids
self.dim_to_id: Dict[int, int] = {} # only the global ids

def allocate(self, scope_dims=None):
def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int, int]:
"""
Allocate a new recyclable dim and a unique id.

Expand Down Expand Up @@ -146,26 +172,28 @@ class NonlocalExit(Exception):
Used by poutine.EscapeMessenger to return site information.
"""

def __init__(self, site, *args, **kwargs):
def __init__(self, site: Message, *args, **kwargs) -> None:
"""
:param site: message at a pyro site constructor.
Just stores the input site.
"""
super().__init__(*args, **kwargs)
self.site = site

def reset_stack(self):
def reset_stack(self) -> None:
"""
Reset the state of the frames remaining in the stack.
Necessary for multiple re-executions in poutine.queue.
"""
from pyro.poutine.block_messenger import BlockMessenger

for frame in reversed(_PYRO_STACK):
frame._reset()
if type(frame).__name__ == "BlockMessenger" and frame.hide_fn(self.site):
if isinstance(frame, BlockMessenger) and frame.hide_fn(self.site):
break


def default_process_message(msg):
def default_process_message(msg: Message) -> None:
"""
Default method for processing messages in inference.

Expand All @@ -174,15 +202,15 @@ def default_process_message(msg):
"""
if msg["done"] or msg["is_observed"] or msg["value"] is not None:
msg["done"] = True
return msg
return

msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

# after fn has been called, update msg to prevent it from being called again.
msg["done"] = True


def apply_stack(initial_msg):
def apply_stack(initial_msg: Message) -> None:
"""
Execute the effect stack at a single site according to the following scheme:

Expand Down Expand Up @@ -223,8 +251,6 @@ def apply_stack(initial_msg):
if cont is not None:
cont(msg)

return None


def am_i_wrapped():
"""
Expand All @@ -234,7 +260,7 @@ def am_i_wrapped():
return len(_PYRO_STACK) > 0


def effectful(fn=None, type=None):
def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable:
"""
:param fn: function or callable that performs an effectful computation
:param str type: the type label of the operation, e.g. `"sample"`
Expand All @@ -247,7 +273,7 @@ def effectful(fn=None, type=None):
if getattr(fn, "_is_effectful", None):
return fn

assert type is not None, "must provide a type label for operation {}".format(fn)
assert type is not None, f"must provide a type label for operation {fn}"
assert type != "message", "cannot use 'message' as keyword"

@functools.wraps(fn)
Expand Down Expand Up @@ -281,11 +307,11 @@ def _fn(*args, **kwargs):
apply_stack(msg)
return msg["value"]

_fn._is_effectful = True
_fn._is_effectful = True # type: ignore[attr-defined]
return _fn


def _inspect() -> Dict[str, object]:
def _inspect() -> Message:
"""
EXPERIMENTAL Inspect the Pyro stack.

Expand All @@ -295,7 +321,7 @@ def _inspect() -> Dict[str, object]:
:returns: A message with all effects applied.
:rtype: dict
"""
msg = {
msg: Message = {
"type": "inspect",
"name": "_pyro_inspect",
"fn": lambda: True,
Expand All @@ -315,7 +341,7 @@ def _inspect() -> Dict[str, object]:
return msg


def get_mask():
def get_mask() -> Union[bool, torch.Tensor, None]:
"""
Records the effects of enclosing ``poutine.mask`` handlers.

Expand All @@ -335,7 +361,7 @@ def model():
return _inspect()["mask"]


def get_plates() -> tuple:
def get_plates() -> Tuple[CondIndepStackFrame, ...]:
"""
Records the effects of enclosing ``pyro.plate`` contexts.

Expand Down
7 changes: 3 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ column_limit = 120

[mypy]
python_version = 3.8
explicit_package_bases = True
warn_return_any = True
warn_unused_configs = True
warn_incomplete_stub = True
Expand Down Expand Up @@ -77,11 +78,9 @@ warn_unused_ignores = True
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.optm.*]
warn_unused_ignores = True

[mypy-pyro.poutine.*]
[mypy-pyro.optim.*]
ignore_errors = True
warn_unused_ignores = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add warn_unused_ignores to the top level [mypy] config? I've found that some # type: ignores mask real errors on the same line, so I've started using warn_unused_ignores globally, but maybe this practice is too pedantic for a user base with diverse mypy versions.


[mypy-pyro.util.*]
ignore_errors = True
Expand Down
Loading