diff --git a/pyro/contrib/minipyro.py b/pyro/contrib/minipyro.py index f0f4452c92..03e52bd8de 100644 --- a/pyro/contrib/minipyro.py +++ b/pyro/contrib/minipyro.py @@ -45,7 +45,7 @@ def __init__(self, fn=None): self.fn = fn # Effect handlers push themselves onto the PYRO_STACK. - # Handlers earlier in the PYRO_STACK are applied first. + # Handlers later in the PYRO_STACK are applied first. def __enter__(self): PYRO_STACK.append(self) @@ -162,17 +162,19 @@ def __iter__(self): # apply_stack is called by pyro.sample and pyro.param. # It is responsible for applying each Messenger to each effectful operation. def apply_stack(msg): + # PYRO_STACK is reversed so that effect handlers higher + # in the stack are first applied. for pointer, handler in enumerate(reversed(PYRO_STACK)): handler.process_message(msg) # When a Messenger sets the "stop" field of a message, - # it prevents any Messengers above it on the stack from being applied. + # it prevents any Messengers below it on the stack from being applied. if msg.get("stop"): break if msg["value"] is None: msg["value"] = msg["fn"](*msg["args"]) # A Messenger that sets msg["stop"] == True also prevents application - # of postprocess_message by Messengers above it on the stack + # of postprocess_message by Messengers below it on the stack # via the pointer variable from the process_message loop for handler in PYRO_STACK[-pointer - 1 :]: handler.postprocess_message(msg)