Skip to content

Commit

Permalink
Make outputs go to correct cell when generated in threads/asyncio (#1186
Browse files Browse the repository at this point in the history
)

Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
  • Loading branch information
krassowski and blink1073 authored Jan 16, 2024
1 parent e8185df commit 8495548
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 36 deletions.
105 changes: 69 additions & 36 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import asyncio
import atexit
import contextvars
import io
import os
import sys
import threading
import traceback
import warnings
from binascii import b2a_hex
from collections import deque
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import local
from typing import Any, Callable, Deque, Dict, Optional
Expand Down Expand Up @@ -412,7 +413,7 @@ def __init__(
name : str {'stderr', 'stdout'}
the name of the standard stream to replace
pipe : object
the pip object
the pipe object
echo : bool
whether to echo output
watchfd : bool (default, True)
Expand Down Expand Up @@ -446,13 +447,19 @@ def __init__(
self.pub_thread = pub_thread
self.name = name
self.topic = b"stream." + name.encode()
self.parent_header = {}
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
"parent_header"
)
self._parent_header.set({})
self._thread_to_parent = {}
self._thread_to_parent_header = {}
self._parent_header_global = {}
self._master_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._io_loop = pub_thread.io_loop
self._buffer_lock = threading.RLock()
self._buffer = StringIO()
self._buffers = defaultdict(StringIO)
self.echo = None
self._isatty = bool(isatty)
self._should_watch = False
Expand Down Expand Up @@ -495,6 +502,30 @@ def __init__(
msg = "echo argument must be a file-like object"
raise ValueError(msg)

@property
def parent_header(self):
try:
# asyncio-specific
return self._parent_header.get()
except LookupError:
try:
# thread-specific
identity = threading.current_thread().ident
# retrieve the outermost (oldest ancestor,
# discounting the kernel thread) thread identity
while identity in self._thread_to_parent:
identity = self._thread_to_parent[identity]
# use the header of the oldest ancestor
return self._thread_to_parent_header[identity]
except KeyError:
# global (fallback)
return self._parent_header_global

@parent_header.setter
def parent_header(self, value):
self._parent_header_global = value
return self._parent_header.set(value)

def isatty(self):
"""Return a bool indicating whether this is an 'interactive' stream.
Expand Down Expand Up @@ -598,28 +629,28 @@ def _flush(self):
if self.echo is not sys.__stderr__:
print(f"Flush failed: {e}", file=sys.__stderr__)

data = self._flush_buffer()
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=self.parent_header)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)
for parent, data in self._flush_buffers():
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=parent)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
"""Write to current stream after encoding if necessary
Expand All @@ -630,6 +661,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
number of items from input parameter written to stream.
"""
parent = self.parent_header

if not isinstance(string, str):
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
Expand All @@ -649,7 +681,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
is_child = not self._is_master_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffer.write(string)
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
Expand All @@ -675,19 +707,20 @@ def writable(self):
"""Test whether the stream is writable."""
return True

def _flush_buffer(self):
def _flush_buffers(self):
"""clear the current buffer and return the current buffer data."""
buf = self._rotate_buffer()
data = buf.getvalue()
buf.close()
return data
buffers = self._rotate_buffers()
for frozen_parent, buffer in buffers.items():
data = buffer.getvalue()
buffer.close()
yield dict(frozen_parent), data

def _rotate_buffer(self):
def _rotate_buffers(self):
"""Returns the current buffer and replaces it with an empty buffer."""
with self._buffer_lock:
old_buffer = self._buffer
self._buffer = StringIO()
return old_buffer
old_buffers = self._buffers
self._buffers = defaultdict(StringIO)
return old_buffers

@property
def _hooks(self):
Expand Down
94 changes: 94 additions & 0 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import builtins
import gc
import getpass
import os
import signal
Expand All @@ -14,6 +15,7 @@
import comm
from IPython.core import release
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
from jupyter_client.session import extract_header
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
from zmq.eventloop.zmqstream import ZMQStream

Expand All @@ -22,6 +24,7 @@
from .compiler import XCachingCompiler
from .debugger import Debugger, _is_debugpy_available
from .eventloops import _use_appnope
from .iostream import OutStream
from .kernelbase import Kernel as KernelBase
from .kernelbase import _accepts_parameters
from .zmqshell import ZMQInteractiveShell
Expand Down Expand Up @@ -151,6 +154,14 @@ def __init__(self, **kwargs):

appnope.nope()

self._new_threads_parent_header = {}
self._initialize_thread_hooks()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

help_links = List(
[
{
Expand Down Expand Up @@ -341,6 +352,12 @@ def set_sigint_result():
# restore the previous sigint handler
signal.signal(signal.SIGINT, save_sigint)

async def execute_request(self, stream, ident, parent):
"""Override for cell output - cell reconciliation."""
parent_header = extract_header(parent)
self._associate_new_top_level_threads_with(parent_header)
await super().execute_request(stream, ident, parent)

async def do_execute(
self,
code,
Expand Down Expand Up @@ -706,6 +723,83 @@ def do_clear(self):
self.shell.reset(False)
return dict(status="ok")

def _associate_new_top_level_threads_with(self, parent_header):
"""Store the parent header to associate it with new top-level threads"""
self._new_threads_parent_header = parent_header

def _initialize_thread_hooks(self):
"""Store thread hierarchy and thread-parent_header associations."""
stdout = self._stdout
stderr = self._stderr
kernel_thread_ident = threading.get_ident()
kernel = self
_threading_Thread_run = threading.Thread.run
_threading_Thread__init__ = threading.Thread.__init__

def run_closure(self: threading.Thread):
"""Wrap the `threading.Thread.start` to intercept thread identity.
This is needed because there is no "start" hook yet, but there
might be one in the future: https://bugs.python.org/issue14073
This is a no-op if the `self._stdout` and `self._stderr` are not
sub-classes of `OutStream`.
"""

try:
parent = self._ipykernel_parent_thread_ident # type:ignore[attr-defined]
except AttributeError:
return
for stream in [stdout, stderr]:
if isinstance(stream, OutStream):
if parent == kernel_thread_ident:
stream._thread_to_parent_header[
self.ident
] = kernel._new_threads_parent_header
else:
stream._thread_to_parent[self.ident] = parent
_threading_Thread_run(self)

def init_closure(self: threading.Thread, *args, **kwargs):
_threading_Thread__init__(self, *args, **kwargs)
self._ipykernel_parent_thread_ident = threading.get_ident() # type:ignore[attr-defined]

threading.Thread.__init__ = init_closure # type:ignore[method-assign]
threading.Thread.run = run_closure # type:ignore[method-assign]

def _clean_thread_parent_frames(
self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
):
"""Clean parent frames of threads which are no longer running.
This is meant to be invoked by garbage collector callback hook.
The implementation enumerates the threads because there is no "exit" hook yet,
but there might be one in the future: https://bugs.python.org/issue14073
This is a no-op if the `self._stdout` and `self._stderr` are not
sub-classes of `OutStream`.
"""
# Only run before the garbage collector starts
if phase != "start":
return
active_threads = {thread.ident for thread in threading.enumerate()}
for stream in [self._stdout, self._stderr]:
if isinstance(stream, OutStream):
thread_to_parent_header = stream._thread_to_parent_header
for identity in list(thread_to_parent_header.keys()):
if identity not in active_threads:
try:
del thread_to_parent_header[identity]
except KeyError:
pass
thread_to_parent = stream._thread_to_parent
for identity in list(thread_to_parent.keys()):
if identity not in active_threads:
try:
del thread_to_parent[identity]
except KeyError:
pass


# This exists only for backwards compatibility - use IPythonKernel instead

Expand Down
8 changes: 8 additions & 0 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ipykernel.jsonutil import json_clean

from ._version import kernel_protocol_version
from .iostream import OutStream


def _accepts_parameters(meth, param_names):
Expand Down Expand Up @@ -272,6 +273,13 @@ def _parent_header(self):
def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)

# Kernel application may swap stdout and stderr to OutStream,
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
# can already by different from TextIO at initialization time.
self._stdout: OutStream | t.TextIO = sys.stdout
self._stderr: OutStream | t.TextIO = sys.stderr

# Build dict of handlers for message types
self.shell_handlers = {}
for msg_type in self.msg_types:
Expand Down
Loading

0 comments on commit 8495548

Please sign in to comment.