From d4fc2cd2cca9163d9ab4c3f65e827f68e06eed29 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 16 Dec 2023 00:00:08 +0000 Subject: [PATCH 1/7] wip --- pyro/poutine/enum_messenger.py | 4 +- pyro/poutine/escape_messenger.py | 11 ++--- pyro/poutine/indep_messenger.py | 60 ++++++++++++++++---------- pyro/poutine/infer_config_messenger.py | 15 ++++--- pyro/poutine/lift_messenger.py | 27 +++++++----- pyro/poutine/markov_messenger.py | 2 +- pyro/poutine/mask_messenger.py | 10 +++-- pyro/poutine/runtime.py | 6 ++- pyro/poutine/trace_struct.py | 6 +-- pyro/poutine/uncondition_messenger.py | 7 +-- 10 files changed, 89 insertions(+), 59 deletions(-) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 7fbcd253c3..2d30800f97 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -22,7 +22,7 @@ def _tmc_mixture_sample(msg: Message) -> torch.Tensor: # find batch dims that aren't plate dims batch_shape = [1] * len(dist.batch_shape) for f in msg["cond_indep_stack"]: - if f.vectorized: + if f.dim is not None: batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim] batch_shape_tuple = tuple(batch_shape) @@ -71,7 +71,7 @@ def _tmc_diagonal_sample(msg: Message) -> torch.Tensor: # find batch dims that aren't plate dims batch_shape = [1] * len(dist.batch_shape) for f in msg["cond_indep_stack"]: - if f.vectorized: + if f.dim is not None: batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim] batch_shape_tuple = tuple(batch_shape) diff --git a/pyro/poutine/escape_messenger.py b/pyro/poutine/escape_messenger.py index fd0bf6a5e6..ab7c499886 100644 --- a/pyro/poutine/escape_messenger.py +++ b/pyro/poutine/escape_messenger.py @@ -1,8 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .messenger import Messenger -from .runtime import NonlocalExit +from typing import Callable + +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message, NonlocalExit class EscapeMessenger(Messenger): @@ -10,7 +12,7 @@ class EscapeMessenger(Messenger): Messenger that does a nonlocal exit by raising a util.NonlocalExit exception """ - def __init__(self, escape_fn): + def __init__(self, escape_fn: Callable[[Message], bool]) -> None: """ :param escape_fn: function that takes a msg as input and returns True if the poutine should perform a nonlocal exit at that site. @@ -20,7 +22,7 @@ def __init__(self, escape_fn): super().__init__() self.escape_fn = escape_fn - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: """ :param msg: current message at a trace site :returns: a sample from the stochastic function at the site. @@ -38,4 +40,3 @@ def cont(m): raise NonlocalExit(m) msg["continuation"] = cont - return None diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index ead1c7d613..d325ae3366 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,40 +2,45 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from collections import namedtuple +from typing import Iterator, NamedTuple, Optional, Tuple import torch +from typing_extensions import Self +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import _DIM_ALLOCATOR, Message from pyro.util import ignore_jit_warnings -from .messenger import Messenger -from .runtime import _DIM_ALLOCATOR +class CondIndepStackFrame(NamedTuple): + name: str + dim: Optional[int] + size: int + counter: int -class CondIndepStackFrame( - namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"]) -): @property - def vectorized(self): + def vectorized(self) -> bool: return self.dim is not None - def _key(self): + def _key(self) -> Tuple[Optional[str], Optional[int], int, int]: with ignore_jit_warnings(["Converting a tensor to a Python number"]): size = ( - self.size.item() if isinstance(self.size, torch.Tensor) else self.size + self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined] ) return self.name, self.dim, size, self.counter - def __eq__(self, other): - return type(self) == type(other) and self._key() == other._key() + def __eq__(self, other: object) -> bool: + if not isinstance(other, CondIndepStackFrame): + return False + return self._key() == other._key() - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self._key()) - def __str__(self): + def __str__(self) -> str: return self.name @@ -59,7 +64,13 @@ class IndepMessenger(Messenger): """ - def __init__(self, name=None, size=None, dim=None, device=None): + def __init__( + self, + name: str, + size: int, + dim: Optional[int] = None, + device: Optional[str] = None, + ): if not torch._C._get_tracing_state() and size == 0: raise ZeroDivisionError("size cannot be zero") @@ -68,34 +79,36 @@ def __init__(self, name=None, size=None, dim=None, device=None): if dim is not None: self._vectorized = True - self._indices = None + self._indices: Optional[torch.Tensor] = None self.name = name self.dim = dim self.size = size self.device = device self.counter = 0 - def next_context(self): + def next_context(self) -> None: """ Increments the counter. """ self.counter += 1 - def __enter__(self): + def __enter__(self) -> Self: if self._vectorized is not False: self._vectorized = True if self._vectorized is True: + assert self.dim is not None self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() - def __exit__(self, *args): + def __exit__(self, *args) -> None: if self._vectorized is True: + assert self.dim is not None _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) - def __iter__(self): + def __iter__(self) -> Iterator[int]: if self._vectorized is True or self.dim is not None: raise ValueError( "cannot use plate {} as both vectorized and non-vectorized" @@ -110,18 +123,19 @@ def __iter__(self): with self: yield i if isinstance(i, numbers.Number) else i.item() - def _reset(self): + def _reset(self) -> None: if self._vectorized: + assert self.dim is not None _DIM_ALLOCATOR.free(self.name, self.dim) self._vectorized = None self.counter = 0 @property - def indices(self): + def indices(self) -> torch.Tensor: if self._indices is None: self._indices = torch.arange(self.size, dtype=torch.long).to(self.device) return self._indices - def _process_message(self, msg): + def _process_message(self, msg: Message) -> None: frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter) msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] diff --git a/pyro/poutine/infer_config_messenger.py b/pyro/poutine/infer_config_messenger.py index 507ee4f40f..f70e8cbfb8 100644 --- a/pyro/poutine/infer_config_messenger.py +++ b/pyro/poutine/infer_config_messenger.py @@ -1,7 +1,10 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .messenger import Messenger +from typing import Callable + +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import InferDict, Message class InferConfigMessenger(Messenger): @@ -15,7 +18,7 @@ class InferConfigMessenger(Messenger): :returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger` """ - def __init__(self, config_fn): + def __init__(self, config_fn: Callable[[Message], InferDict]): """ :param config_fn: a callable taking a site and returning an infer dict @@ -25,7 +28,7 @@ def __init__(self, config_fn): super().__init__() self.config_fn = config_fn - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: """ :param msg: current message at a trace site. @@ -35,10 +38,10 @@ def _pyro_sample(self, msg): Otherwise, implements default sampling behavior with no additional effects. """ + assert msg["infer"] is not None msg["infer"].update(self.config_fn(msg)) - return None - def _pyro_param(self, msg): + def _pyro_param(self, msg: Message) -> None: """ :param msg: current message at a trace site. @@ -48,5 +51,5 @@ def _pyro_param(self, msg): Otherwise, implements default param behavior with no additional effects. """ + assert msg["infer"] is not None msg["infer"].update(self.config_fn(msg)) - return None diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index 3b72f66534..09703e6825 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -2,13 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +from typing import Callable, Dict, Set, Union + +from typing_extensions import Self from pyro import params from pyro.distributions.distribution import Distribution +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled -from .messenger import Messenger - class LiftMessenger(Messenger): """ @@ -40,7 +43,10 @@ class LiftMessenger(Messenger): :returns: ``fn`` decorated with a :class:`~pyro.poutine.lift_messenger.LiftMessenger` """ - def __init__(self, prior): + def __init__( + self, + prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]], + ) -> None: """ :param prior: prior used to lift parameters. Prior can be of type dict, pyro.distributions, or a python stochastic fn @@ -49,16 +55,16 @@ def __init__(self, prior): """ super().__init__() self.prior = prior - self._samples_cache = {} + self._samples_cache: Dict[str, Message] = {} - def __enter__(self): + def __enter__(self) -> Self: self._samples_cache = {} if is_validation_enabled() and isinstance(self.prior, dict): - self._param_hits = set() - self._param_misses = set() + self._param_hits: Set[str] = set() + self._param_misses: Set[str] = set() return super().__enter__() - def __exit__(self, *args, **kwargs): + def __exit__(self, *args, **kwargs) -> None: self._samples_cache = {} if is_validation_enabled() and isinstance(self.prior, dict): extra = set(self.prior) - self._param_hits @@ -71,10 +77,10 @@ def __exit__(self, *args, **kwargs): ) return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: return None - def _pyro_param(self, msg): + def _pyro_param(self, msg: Message) -> None: """ Overrides the `pyro.param` call with samples sampled from the distribution specified in the prior. The prior can be a @@ -82,6 +88,7 @@ def _pyro_param(self, msg): on the param names. If the param name does not match the name the keys in the prior, that param name is unchanged. """ + assert msg["name"] is not None name = msg["name"] param_name = params.user_param_name(name) if isinstance(self.prior, dict): diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 1d68c9e06a..b1d2fb15b7 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -4,7 +4,7 @@ from collections import Counter from contextlib import ExitStack # python 3 -from .reentrant_messenger import ReentrantMessenger +from pyro.poutine.reentrant_messenger import ReentrantMessenger class MarkovMessenger(ReentrantMessenger): diff --git a/pyro/poutine/mask_messenger.py b/pyro/poutine/mask_messenger.py index 35d9375827..c3c375d8a2 100644 --- a/pyro/poutine/mask_messenger.py +++ b/pyro/poutine/mask_messenger.py @@ -1,9 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from typing import Union + import torch -from .messenger import Messenger +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message class MaskMessenger(Messenger): @@ -17,7 +20,7 @@ class MaskMessenger(Messenger): :returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.MaskMessenger` """ - def __init__(self, mask): + def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None: if isinstance(mask, torch.Tensor): if mask.dtype != torch.bool: raise ValueError( @@ -31,6 +34,5 @@ def __init__(self, mask): super().__init__() self.mask = mask - def _process_message(self, msg): + def _process_message(self, msg: Message) -> None: msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask - return None diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index f3dc3146f6..036cec7812 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -50,8 +50,10 @@ class InferDict(TypedDict, total=False): is_auxiliary: bool is_observed: bool num_samples: int + obs: Optional[torch.Tensor] prior: TorchDistributionMixin tmc: Literal["diagonal", "mixture"] + was_observed: bool _deterministic: bool _dim_to_symbol: Dict[int, str] _do_not_trace: bool @@ -98,7 +100,7 @@ def __init__(self) -> None: # in reverse orientation of log_prob.shape self._stack: List[Optional[str]] = [] - def allocate(self, name: str, dim: Optional[int]) -> int: + def allocate(self, name: Optional[str], dim: Optional[int]) -> int: """ Allocate a dimension to an :class:`plate` with given name. Dim should be either None for automatic allocation or a negative @@ -133,7 +135,7 @@ def allocate(self, name: str, dim: Optional[int]) -> int: self._stack[-1 - dim] = name return dim - def free(self, name: str, dim: int) -> None: + def free(self, name: Optional[str], dim: int) -> None: """ Free a dimension. """ diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 029483afbd..41791c5efe 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -388,7 +388,7 @@ def iter_stochastic_nodes(self) -> Iterator[Tuple[str, Message]]: if node["type"] == "sample" and not node["is_observed"]: yield name, node - def symbolize_dims(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> None: + def symbolize_dims(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None: """ Assign unique symbols to all tensor dimensions. """ @@ -401,7 +401,7 @@ def symbolize_dims(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> No # allocate even symbols for plate dims dim_to_symbol: Dict[int, str] = {} for frame in site["cond_indep_stack"]: - if frame.vectorized: + if frame.dim is not None: if frame.name in plate_to_symbol: symbol = plate_to_symbol[frame.name] else: @@ -424,7 +424,7 @@ def symbolize_dims(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> No self.plate_to_symbol = plate_to_symbol self.symbol_to_dim = symbol_to_dim - def pack_tensors(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> None: + def pack_tensors(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None: """ Computes packed representations of tensors in the trace. This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`. diff --git a/pyro/poutine/uncondition_messenger.py b/pyro/poutine/uncondition_messenger.py index 8db8a56ae9..1978ba9a85 100644 --- a/pyro/poutine/uncondition_messenger.py +++ b/pyro/poutine/uncondition_messenger.py @@ -1,7 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .messenger import Messenger +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message class UnconditionMessenger(Messenger): @@ -13,7 +14,7 @@ class UnconditionMessenger(Messenger): def __init__(self): super().__init__() - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: """ :param msg: current message at a trace site. @@ -22,8 +23,8 @@ def _pyro_sample(self, msg): """ if msg["is_observed"]: msg["is_observed"] = False + assert msg["infer"] is not None msg["infer"]["was_observed"] = True msg["infer"]["obs"] = msg["value"] msg["value"] = None msg["done"] = False - return None From 2d684e952feac3224bd6754a0dab1e2ebb520ffa Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 16 Dec 2023 00:10:52 +0000 Subject: [PATCH 2/7] fixes --- pyro/poutine/indep_messenger.py | 2 +- pyro/poutine/runtime.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index d325ae3366..3d68136b9a 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -22,7 +22,7 @@ class CondIndepStackFrame(NamedTuple): def vectorized(self) -> bool: return self.dim is not None - def _key(self) -> Tuple[Optional[str], Optional[int], int, int]: + def _key(self) -> Tuple[str, Optional[int], int, int]: with ignore_jit_warnings(["Converting a tensor to a Python number"]): size = ( self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined] diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 036cec7812..581dd87d03 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -100,7 +100,7 @@ def __init__(self) -> None: # in reverse orientation of log_prob.shape self._stack: List[Optional[str]] = [] - def allocate(self, name: Optional[str], dim: Optional[int]) -> int: + def allocate(self, name: str, dim: Optional[int]) -> int: """ Allocate a dimension to an :class:`plate` with given name. Dim should be either None for automatic allocation or a negative @@ -135,7 +135,7 @@ def allocate(self, name: Optional[str], dim: Optional[int]) -> int: self._stack[-1 - dim] = name return dim - def free(self, name: Optional[str], dim: int) -> None: + def free(self, name: str, dim: int) -> None: """ Free a dimension. """ From f96df2c5ae4455e800973968b6d80cd311a86ac2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 16 Dec 2023 00:16:08 +0000 Subject: [PATCH 3/7] revert markov --- pyro/poutine/markov_messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index b1d2fb15b7..1d68c9e06a 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -4,7 +4,7 @@ from collections import Counter from contextlib import ExitStack # python 3 -from pyro.poutine.reentrant_messenger import ReentrantMessenger +from .reentrant_messenger import ReentrantMessenger class MarkovMessenger(ReentrantMessenger): From 24ea9769e0ebbda893844fc9e5987c260a9eb927 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 16 Dec 2023 00:35:04 +0000 Subject: [PATCH 4/7] pass doctest --- pyro/poutine/indep_messenger.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 3d68136b9a..c224b032fd 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from typing import Iterator, NamedTuple, Optional, Tuple +from dataclasses import dataclass +from typing import Iterator, Optional, Tuple import torch from typing_extensions import Self @@ -12,7 +13,8 @@ from pyro.util import ignore_jit_warnings -class CondIndepStackFrame(NamedTuple): +@dataclass +class CondIndepStackFrame: name: str dim: Optional[int] size: int @@ -97,7 +99,6 @@ def __enter__(self) -> Self: self._vectorized = True if self._vectorized is True: - assert self.dim is not None self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() From 49dbab1e333ec35d9b673bcca8bfc08c9de92880 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 17 Dec 2023 00:19:42 +0000 Subject: [PATCH 5/7] address comments --- pyro/poutine/enum_messenger.py | 6 ++++-- pyro/poutine/indep_messenger.py | 2 +- pyro/poutine/trace_struct.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 2d30800f97..fe87f76d7f 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -22,7 +22,8 @@ def _tmc_mixture_sample(msg: Message) -> torch.Tensor: # find batch dims that aren't plate dims batch_shape = [1] * len(dist.batch_shape) for f in msg["cond_indep_stack"]: - if f.dim is not None: + if f.vectorized: + assert f.dim is not None batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim] batch_shape_tuple = tuple(batch_shape) @@ -71,7 +72,8 @@ def _tmc_diagonal_sample(msg: Message) -> torch.Tensor: # find batch dims that aren't plate dims batch_shape = [1] * len(dist.batch_shape) for f in msg["cond_indep_stack"]: - if f.dim is not None: + if f.vectorized: + assert f.dim is not None batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim] batch_shape_tuple = tuple(batch_shape) diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index c224b032fd..54a3eeec94 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -13,7 +13,7 @@ from pyro.util import ignore_jit_warnings -@dataclass +@dataclass(eq=False) class CondIndepStackFrame: name: str dim: Optional[int] diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 41791c5efe..4c2b58bb23 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -401,7 +401,8 @@ def symbolize_dims(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> No # allocate even symbols for plate dims dim_to_symbol: Dict[int, str] = {} for frame in site["cond_indep_stack"]: - if frame.dim is not None: + if frame.vectorized: + assert frame.dim is not None if frame.name in plate_to_symbol: symbol = plate_to_symbol[frame.name] else: From 9151f4340044636fd983fd104019d61341fd2d18 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 18 Dec 2023 08:59:00 -0800 Subject: [PATCH 6/7] Make CondIndepStackFrame.full_size immutable --- pyro/infer/autoguide/effect.py | 2 +- pyro/infer/autoguide/guides.py | 6 +++--- pyro/poutine/indep_messenger.py | 7 +++---- pyro/poutine/subsample_messenger.py | 3 +-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index d7abe537b0..37ed99351d 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -68,7 +68,7 @@ def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: Adjusts plates for generating initial values of parameters. """ for f in get_plates(): - full_size = getattr(f, "full_size", f.size) + full_size = f.full_size or f.size dim = f.dim - event_dim if f in self._outer_plates or f.name in self.amortized_plates: if -value.dim() <= dim: diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 129b21b1d7..4fb5ab221a 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -138,7 +138,7 @@ def _create_plates(self, *args, **kwargs): self.plates = {p.name: p for p in plates} for name, frame in sorted(self._prototype_frames.items()): if name not in self.plates: - full_size = getattr(frame, "full_size", frame.size) + full_size = frame.full_size or frame.size self.plates[name] = pyro.plate( name, full_size, dim=frame.dim, subsample_size=frame.size ) @@ -363,7 +363,7 @@ def _setup_prototype(self, *args, **kwargs): # If subsampling, repeat init_value to full size. for frame in site["cond_indep_stack"]: - full_size = getattr(frame, "full_size", frame.size) + full_size = frame.full_size or frame.size if full_size != frame.size: dim = frame.dim - event_dim value = periodic_repeat(value, full_size, dim).contiguous() @@ -475,7 +475,7 @@ def _setup_prototype(self, *args, **kwargs): # If subsampling, repeat init_value to full size. for frame in site["cond_indep_stack"]: - full_size = getattr(frame, "full_size", frame.size) + full_size = frame.full_size or frame.size if full_size != frame.size: dim = frame.dim - event_dim init_loc = periodic_repeat(init_loc, full_size, dim).contiguous() diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 54a3eeec94..bfb3f3d1ae 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from dataclasses import dataclass -from typing import Iterator, Optional, Tuple +from typing import Iterator, NamedTuple, Optional, Tuple import torch from typing_extensions import Self @@ -13,12 +12,12 @@ from pyro.util import ignore_jit_warnings -@dataclass(eq=False) -class CondIndepStackFrame: +class CondIndepStackFrame(NamedTuple): name: str dim: Optional[int] size: int counter: int + full_size: Optional[int] = None @property def vectorized(self) -> bool: diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 042ebc1d0f..5894944bf6 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -149,9 +149,8 @@ def _reset(self): def _process_message(self, msg): frame = CondIndepStackFrame( - self.name, self.dim, self.subsample_size, self.counter + self.name, self.dim, self.subsample_size, self.counter, self.size ) - frame.full_size = self.size # Used for param initialization. msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] if isinstance(self.size, torch.Tensor) or isinstance( self.subsample_size, torch.Tensor From 6d4a030b54998e2b4b94d152935c3da0f51b48fb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 18 Dec 2023 09:05:44 -0800 Subject: [PATCH 7/7] Preserve comment --- pyro/poutine/subsample_messenger.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 5894944bf6..cc05c332ad 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -149,7 +149,11 @@ def _reset(self): def _process_message(self, msg): frame = CondIndepStackFrame( - self.name, self.dim, self.subsample_size, self.counter, self.size + name=self.name, + dim=self.dim, + size=self.subsample_size, + counter=self.counter, + full_size=self.size, # used for param initialization ) msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] if isinstance(self.size, torch.Tensor) or isinstance(