Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate poutines #3306

Merged
merged 7 commits into from
Dec 18, 2023
Merged

Type annotate poutines #3306

merged 7 commits into from
Dec 18, 2023

Conversation

ordabayevy
Copy link
Member

  • escape_messenger
  • indep_messenger
  • infer_config_messenger
  • lift_messenger
  • mask_messenger
  • uncondition_messenger

pyro/poutine/enum_messenger.py Outdated Show resolved Hide resolved
name: str
dim: Optional[int]
size: int
counter: int
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I tried to use class CondIndepStackFrame(NamedTuple), however, there is code in other places that tries to assign frame.full_size attribute which throws an error because NamedTuple is immutable(?).

Copy link
Member

@fritzo fritzo Dec 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, immutability is a great property to reason about when reading & maintaining a codebase. I'd be a shame to lose that property even as we get type hints. Is there some way we could keep this class clean and immutable, even if that hack gets uglier? Even something like this:

@dataclass(frozen=True, slots=True)  # frozen + slots is close to namedtuple
class CondIndepStackFrame:
    name: str
    dim: Optional[int]
    size: int
    counter: int
    _full_size: Optional[int] = None  # hack: this is actually mutable

    @property
    def full_size(self) -> int:
        return self.size if self._full_size is None else self._full_size

    @full_size.setter
    def full_size(self, value: int) -> None:
        object.__setattr__(self, "_full_size", value)

    # ... __eq__ etc. that ignore .full_size

Copy link
Member Author

@ordabayevy ordabayevy Dec 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried this approach, however, python complained that you cannot use setattr with frame.full_size = size if frozen=True. Also slots is new in Python 3.10. I'll think more about other ways of keeping immutability, don't have any ideas at the moment.

Copy link
Member

@fritzo fritzo Dec 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we kept CondIndepStackFrame as NamedTuple but moved FullSize to a global WeakKeyDictionary? There are only 4 code locations that use the hacky .full_size attribute. We could either change those to access the WeakKeyDictionary directly, or make a property to mediate access:

class CondIndepStackFrame(NamedTuple):
    name: str
    dim: Optional[int]
    size: int
    counter: int

    @property
    def full_size(self) -> Optional[int]:
        return COND_INDEP_FULL_SIZE.get(self)

    @full_size.getter
    def full_size(self, value: int) -> None:
        COND_INDEP_FULL_SIZE[self] = value

COND_INDEP_FULL_SIZE: WeakKeyDictionary[CondIndepStackFrame, int] = WeakKeyDictionary()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And apologies for that full_size hack 😊

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you meant @full_size.setter. Using NamedTuple and WeakKeyDictionary I get this error:

  File "/mnt/disks/dev/repos/pyro/pyro/poutine/subsample_messenger.py", line 154, in _process_message
    frame.full_size = self.size  # Used for param initialization.
  File "/mnt/disks/dev/repos/pyro/pyro/poutine/indep_messenger.py", line 28, in full_size
    COND_INDEP_FULL_SIZE[self] = value
  File "/home/yordabay/anaconda3/envs/pyro/lib/python3.8/weakref.py", line 396, in __setitem__
    self.data[ref(key, self._remove)] = value
TypeError: cannot create weak reference to 'CondIndepStackFrame' object

It seems like you cannot create weakrefs to tuples (https://stackoverflow.com/questions/58312618/is-there-a-way-to-support-weakrefs-with-collections-namedtuple)

I also tried it with a dataclass and using frozen=True but then it doesn't allow to use setter:

  File "/mnt/disks/dev/repos/pyro/pyro/poutine/subsample_messenger.py", line 154, in _process_message
    frame.full_size = self.size  # Used for param initialization.
  File "<string>", line 4, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'full_size'

Copy link
Member Author

@ordabayevy ordabayevy Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like Python really doesn't want to allow setting attributes on immutable objects. It is also strange that an old namedtuple had no problems with that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind if I push a fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! That would be great. I don't have any solution on my end

pyro/poutine/indep_messenger.py Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great to see such careful maintenance work, kudos!

I just have a couple comments about code clarity and immutability.

name: str
dim: Optional[int]
size: int
counter: int
Copy link
Member

@fritzo fritzo Dec 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, immutability is a great property to reason about when reading & maintaining a codebase. I'd be a shame to lose that property even as we get type hints. Is there some way we could keep this class clean and immutable, even if that hack gets uglier? Even something like this:

@dataclass(frozen=True, slots=True)  # frozen + slots is close to namedtuple
class CondIndepStackFrame:
    name: str
    dim: Optional[int]
    size: int
    counter: int
    _full_size: Optional[int] = None  # hack: this is actually mutable

    @property
    def full_size(self) -> int:
        return self.size if self._full_size is None else self._full_size

    @full_size.setter
    def full_size(self, value: int) -> None:
        object.__setattr__(self, "_full_size", value)

    # ... __eq__ etc. that ignore .full_size

pyro/poutine/trace_struct.py Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ordabayevy your changes LGTM. Could you PTAL at my .full_size solution? Feel free to merge when tests pass.

@ordabayevy
Copy link
Member Author

@fritzo LGTM. Thanks for the careful review!

@ordabayevy ordabayevy merged commit 834ff63 into dev Dec 18, 2023
9 checks passed
@ordabayevy ordabayevy deleted the type-poutines branch December 18, 2023 18:13
@ordabayevy ordabayevy mentioned this pull request Feb 5, 2024
23 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants