Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate poutines #3306

Merged
merged 7 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyro/infer/autoguide/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
Adjusts plates for generating initial values of parameters.
"""
for f in get_plates():
full_size = getattr(f, "full_size", f.size)
full_size = f.full_size or f.size
dim = f.dim - event_dim
if f in self._outer_plates or f.name in self.amortized_plates:
if -value.dim() <= dim:
Expand Down
6 changes: 3 additions & 3 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _create_plates(self, *args, **kwargs):
self.plates = {p.name: p for p in plates}
for name, frame in sorted(self._prototype_frames.items()):
if name not in self.plates:
full_size = getattr(frame, "full_size", frame.size)
full_size = frame.full_size or frame.size
self.plates[name] = pyro.plate(
name, full_size, dim=frame.dim, subsample_size=frame.size
)
Expand Down Expand Up @@ -363,7 +363,7 @@ def _setup_prototype(self, *args, **kwargs):

# If subsampling, repeat init_value to full size.
for frame in site["cond_indep_stack"]:
full_size = getattr(frame, "full_size", frame.size)
full_size = frame.full_size or frame.size
if full_size != frame.size:
dim = frame.dim - event_dim
value = periodic_repeat(value, full_size, dim).contiguous()
Expand Down Expand Up @@ -475,7 +475,7 @@ def _setup_prototype(self, *args, **kwargs):

# If subsampling, repeat init_value to full size.
for frame in site["cond_indep_stack"]:
full_size = getattr(frame, "full_size", frame.size)
full_size = frame.full_size or frame.size
if full_size != frame.size:
dim = frame.dim - event_dim
init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _tmc_mixture_sample(msg: Message) -> torch.Tensor:
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.vectorized:
assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

Expand Down Expand Up @@ -72,6 +73,7 @@ def _tmc_diagonal_sample(msg: Message) -> torch.Tensor:
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.vectorized:
assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

Expand Down
11 changes: 6 additions & 5 deletions pyro/poutine/escape_messenger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .messenger import Messenger
from .runtime import NonlocalExit
from typing import Callable

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, NonlocalExit


class EscapeMessenger(Messenger):
"""
Messenger that does a nonlocal exit by raising a util.NonlocalExit exception
"""

def __init__(self, escape_fn):
def __init__(self, escape_fn: Callable[[Message], bool]) -> None:
"""
:param escape_fn: function that takes a msg as input and returns True
if the poutine should perform a nonlocal exit at that site.
Expand All @@ -20,7 +22,7 @@ def __init__(self, escape_fn):
super().__init__()
self.escape_fn = escape_fn

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
"""
:param msg: current message at a trace site
:returns: a sample from the stochastic function at the site.
Expand All @@ -38,4 +40,3 @@ def cont(m):
raise NonlocalExit(m)

msg["continuation"] = cont
return None
60 changes: 37 additions & 23 deletions pyro/poutine/indep_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,46 @@
# SPDX-License-Identifier: Apache-2.0

import numbers
from collections import namedtuple
from typing import Iterator, NamedTuple, Optional, Tuple

import torch
from typing_extensions import Self

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import _DIM_ALLOCATOR, Message
from pyro.util import ignore_jit_warnings

from .messenger import Messenger
from .runtime import _DIM_ALLOCATOR

class CondIndepStackFrame(NamedTuple):
name: str
dim: Optional[int]
size: int
counter: int
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I tried to use class CondIndepStackFrame(NamedTuple); however, there is code in other places that tries to assign the frame.full_size attribute, which throws an error because NamedTuple is immutable(?).

Copy link
Member

@fritzo fritzo Dec 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, immutability is a great property to reason about when reading & maintaining a codebase. It'd be a shame to lose that property even as we get type hints. Is there some way we could keep this class clean and immutable, even if that hack gets uglier? Even something like this:

@dataclass(frozen=True, slots=True)  # frozen + slots is close to namedtuple
class CondIndepStackFrame:
    name: str
    dim: Optional[int]
    size: int
    counter: int
    _full_size: Optional[int] = None  # hack: this is actually mutable

    @property
    def full_size(self) -> int:
        return self.size if self._full_size is None else self._full_size

    @full_size.setter
    def full_size(self, value: int) -> None:
        object.__setattr__(self, "_full_size", value)

    # ... __eq__ etc. that ignore .full_size

Copy link
Member Author

@ordabayevy ordabayevy Dec 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried this approach; however, Python complained that you cannot use setattr with frame.full_size = size if frozen=True. Also, slots is new in Python 3.10. I'll think more about other ways of keeping immutability; I don't have any ideas at the moment.

Copy link
Member

@fritzo fritzo Dec 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we kept CondIndepStackFrame as NamedTuple but moved FullSize to a global WeakKeyDictionary? There are only 4 code locations that use the hacky .full_size attribute. We could either change those to access the WeakKeyDictionary directly, or make a property to mediate access:

class CondIndepStackFrame(NamedTuple):
    name: str
    dim: Optional[int]
    size: int
    counter: int

    @property
    def full_size(self) -> Optional[int]:
        return COND_INDEP_FULL_SIZE.get(self)

    @full_size.getter
    def full_size(self, value: int) -> None:
        COND_INDEP_FULL_SIZE[self] = value

COND_INDEP_FULL_SIZE: WeakKeyDictionary[CondIndepStackFrame, int] = WeakKeyDictionary()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And apologies for that full_size hack 😊

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you meant @full_size.setter. Using NamedTuple and WeakKeyDictionary I get this error:

  File "/mnt/disks/dev/repos/pyro/pyro/poutine/subsample_messenger.py", line 154, in _process_message
    frame.full_size = self.size  # Used for param initialization.
  File "/mnt/disks/dev/repos/pyro/pyro/poutine/indep_messenger.py", line 28, in full_size
    COND_INDEP_FULL_SIZE[self] = value
  File "/home/yordabay/anaconda3/envs/pyro/lib/python3.8/weakref.py", line 396, in __setitem__
    self.data[ref(key, self._remove)] = value
TypeError: cannot create weak reference to 'CondIndepStackFrame' object

It seems like you cannot create weakrefs to tuples (https://stackoverflow.com/questions/58312618/is-there-a-way-to-support-weakrefs-with-collections-namedtuple)

I also tried it with a dataclass using frozen=True, but then it doesn't allow using the setter:

  File "/mnt/disks/dev/repos/pyro/pyro/poutine/subsample_messenger.py", line 154, in _process_message
    frame.full_size = self.size  # Used for param initialization.
  File "<string>", line 4, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'full_size'

Copy link
Member Author

@ordabayevy ordabayevy Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like Python really doesn't want to allow setting attributes on immutable objects. It is also strange that an old namedtuple had no problems with that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind if I push a fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! That would be great. I don't have any solution on my end

full_size: Optional[int] = None

class CondIndepStackFrame(
namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])
):
@property
def vectorized(self):
def vectorized(self) -> bool:
return self.dim is not None

def _key(self):
def _key(self) -> Tuple[str, Optional[int], int, int]:
with ignore_jit_warnings(["Converting a tensor to a Python number"]):
size = (
self.size.item() if isinstance(self.size, torch.Tensor) else self.size
self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined]
)
return self.name, self.dim, size, self.counter

def __eq__(self, other):
return type(self) == type(other) and self._key() == other._key()
def __eq__(self, other: object) -> bool:
if not isinstance(other, CondIndepStackFrame):
return False
return self._key() == other._key()

def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __hash__(self):
def __hash__(self) -> int:
return hash(self._key())

def __str__(self):
def __str__(self) -> str:
return self.name


Expand All @@ -59,7 +65,13 @@ class IndepMessenger(Messenger):
"""

def __init__(self, name=None, size=None, dim=None, device=None):
def __init__(
self,
name: str,
size: int,
dim: Optional[int] = None,
device: Optional[str] = None,
):
ordabayevy marked this conversation as resolved.
Show resolved Hide resolved
if not torch._C._get_tracing_state() and size == 0:
raise ZeroDivisionError("size cannot be zero")

Expand All @@ -68,20 +80,20 @@ def __init__(self, name=None, size=None, dim=None, device=None):
if dim is not None:
self._vectorized = True

self._indices = None
self._indices: Optional[torch.Tensor] = None
self.name = name
self.dim = dim
self.size = size
self.device = device
self.counter = 0

def next_context(self):
def next_context(self) -> None:
"""
Increments the counter.
"""
self.counter += 1

def __enter__(self):
def __enter__(self) -> Self:
if self._vectorized is not False:
self._vectorized = True

Expand All @@ -90,12 +102,13 @@ def __enter__(self):

return super().__enter__()

def __exit__(self, *args):
def __exit__(self, *args) -> None:
if self._vectorized is True:
assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
return super().__exit__(*args)

def __iter__(self):
def __iter__(self) -> Iterator[int]:
if self._vectorized is True or self.dim is not None:
raise ValueError(
"cannot use plate {} as both vectorized and non-vectorized"
Expand All @@ -110,18 +123,19 @@ def __iter__(self):
with self:
yield i if isinstance(i, numbers.Number) else i.item()

def _reset(self):
def _reset(self) -> None:
if self._vectorized:
assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
self._vectorized = None
self.counter = 0

@property
def indices(self):
def indices(self) -> torch.Tensor:
if self._indices is None:
self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
return self._indices

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
15 changes: 9 additions & 6 deletions pyro/poutine/infer_config_messenger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .messenger import Messenger
from typing import Callable

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import InferDict, Message


class InferConfigMessenger(Messenger):
Expand All @@ -15,7 +18,7 @@ class InferConfigMessenger(Messenger):
:returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger`
"""

def __init__(self, config_fn):
def __init__(self, config_fn: Callable[[Message], InferDict]):
"""
:param config_fn: a callable taking a site and returning an infer dict
Expand All @@ -25,7 +28,7 @@ def __init__(self, config_fn):
super().__init__()
self.config_fn = config_fn

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
"""
:param msg: current message at a trace site.
Expand All @@ -35,10 +38,10 @@ def _pyro_sample(self, msg):
Otherwise, implements default sampling behavior
with no additional effects.
"""
assert msg["infer"] is not None
msg["infer"].update(self.config_fn(msg))
return None

def _pyro_param(self, msg):
def _pyro_param(self, msg: Message) -> None:
"""
:param msg: current message at a trace site.
Expand All @@ -48,5 +51,5 @@ def _pyro_param(self, msg):
Otherwise, implements default param behavior
with no additional effects.
"""
assert msg["infer"] is not None
msg["infer"].update(self.config_fn(msg))
return None
27 changes: 17 additions & 10 deletions pyro/poutine/lift_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Callable, Dict, Set, Union

from typing_extensions import Self

from pyro import params
from pyro.distributions.distribution import Distribution
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled

from .messenger import Messenger


class LiftMessenger(Messenger):
"""
Expand Down Expand Up @@ -40,7 +43,10 @@ class LiftMessenger(Messenger):
:returns: ``fn`` decorated with a :class:`~pyro.poutine.lift_messenger.LiftMessenger`
"""

def __init__(self, prior):
def __init__(
self,
prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]],
) -> None:
"""
:param prior: prior used to lift parameters. Prior can be of type
dict, pyro.distributions, or a python stochastic fn
Expand All @@ -49,16 +55,16 @@ def __init__(self, prior):
"""
super().__init__()
self.prior = prior
self._samples_cache = {}
self._samples_cache: Dict[str, Message] = {}

def __enter__(self):
def __enter__(self) -> Self:
self._samples_cache = {}
if is_validation_enabled() and isinstance(self.prior, dict):
self._param_hits = set()
self._param_misses = set()
self._param_hits: Set[str] = set()
self._param_misses: Set[str] = set()
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, *args, **kwargs) -> None:
self._samples_cache = {}
if is_validation_enabled() and isinstance(self.prior, dict):
extra = set(self.prior) - self._param_hits
Expand All @@ -71,17 +77,18 @@ def __exit__(self, *args, **kwargs):
)
return super().__exit__(*args, **kwargs)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
return None

def _pyro_param(self, msg):
def _pyro_param(self, msg: Message) -> None:
"""
Overrides the `pyro.param` call with samples sampled from the
distribution specified in the prior. The prior can be a
pyro.distributions object or a dict of distributions keyed
on the param names. If the param name does not match any of the
keys in the prior, that param is left unchanged.
"""
assert msg["name"] is not None
name = msg["name"]
param_name = params.user_param_name(name)
if isinstance(self.prior, dict):
Expand Down
10 changes: 6 additions & 4 deletions pyro/poutine/mask_messenger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch

from .messenger import Messenger
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class MaskMessenger(Messenger):
Expand All @@ -17,7 +20,7 @@ class MaskMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.mask_messenger.MaskMessenger`
"""

def __init__(self, mask):
def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None:
if isinstance(mask, torch.Tensor):
if mask.dtype != torch.bool:
raise ValueError(
Expand All @@ -31,6 +34,5 @@ def __init__(self, mask):
super().__init__()
self.mask = mask

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask
return None
2 changes: 2 additions & 0 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ class InferDict(TypedDict, total=False):
is_auxiliary: bool
is_observed: bool
num_samples: int
obs: Optional[torch.Tensor]
prior: TorchDistributionMixin
tmc: Literal["diagonal", "mixture"]
was_observed: bool
_deterministic: bool
_dim_to_symbol: Dict[int, str]
_do_not_trace: bool
Expand Down
Loading
Loading