Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarified comments in minipyro.py to fix #3003 #3004

Merged
merged 2 commits into from
Jan 14, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pyro/contrib/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, fn=None):
self.fn = fn

# Effect handlers push themselves onto the PYRO_STACK.
# Handlers earlier in the PYRO_STACK are applied first.
# Handlers later in the PYRO_STACK are applied first.
fritzo marked this conversation as resolved.
Show resolved Hide resolved
def __enter__(self):
PYRO_STACK.append(self)

Expand Down Expand Up @@ -162,17 +162,19 @@ def __iter__(self):
# apply_stack is called by pyro.sample and pyro.param.
# It is responsible for applying each Messenger to each effectful operation.
def apply_stack(msg):
# PYRO_STACK is reversed so that effect handlers later
fritzo marked this conversation as resolved.
Show resolved Hide resolved
# in the stack are first applied.
for pointer, handler in enumerate(reversed(PYRO_STACK)):
handler.process_message(msg)
# When a Messenger sets the "stop" field of a message,
# it prevents any Messengers above it on the stack from being applied.
# it prevents any Messengers below it on the stack from being applied.
if msg.get("stop"):
break
if msg["value"] is None:
msg["value"] = msg["fn"](*msg["args"])

# A Messenger that sets msg["stop"] == True also prevents application
# of postprocess_message by Messengers above it on the stack
# of postprocess_message by Messengers below it on the stack
# via the pointer variable from the process_message loop
for handler in PYRO_STACK[-pointer - 1 :]:
handler.postprocess_message(msg)
Expand Down