diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index 873b96d2..7af64aed 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -88,9 +88,6 @@ def start(self): def _abort_queues(self): """The in-process kernel doesn't abort requests.""" - async def _flush_control_queue(self): - """No need to flush control queues for in-process""" - def _input_request(self, prompt, ident, parent, password=False): # Flush output before making the request. self.raw_input_str = None diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index a24e3238..01539fd2 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -5,7 +5,6 @@ from __future__ import annotations import asyncio -import concurrent.futures import inspect import itertools import logging @@ -289,49 +288,16 @@ def __init__(self, **kwargs): for msg_type in self.control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) - self.control_queue: Queue[t.Any] = Queue() - # Storing the accepted parameters for do_execute, used in execute_request self._do_exec_accepted_params = _accepts_parameters( self.do_execute, ["cell_meta", "cell_id"] ) - def dispatch_control(self, msg): - self.control_queue.put_nowait(msg) - - async def poll_control_queue(self): - while True: - msg = await self.control_queue.get() - # handle tracers from _flush_control_queue - if isinstance(msg, (concurrent.futures.Future, asyncio.Future)): - msg.set_result(None) - continue + async def dispatch_control(self, msg): + # Ensure only one control message is processed at a time + async with asyncio.Lock(): await self.process_control(msg) - async def _flush_control_queue(self): - """Flush the control queue, wait for processing of any pending messages""" - tracer_future: concurrent.futures.Future[object] | asyncio.Future[object] - if self.control_thread: - control_loop = self.control_thread.io_loop - # concurrent.futures.Futures are threadsafe - # and can be used to await across threads - tracer_future = concurrent.futures.Future() - awaitable_future = asyncio.wrap_future(tracer_future) - else: - control_loop = self.io_loop - tracer_future = awaitable_future = asyncio.Future() - - def _flush(): - # control_stream.flush puts messages on the queue - if self.control_stream: - self.control_stream.flush() - # put Future on the queue after all of those, - # so we can wait for all queued messages to be processed - self.control_queue.put(tracer_future) - - control_loop.add_callback(_flush) - return awaitable_future - async def process_control(self, msg): """dispatch control requests""" if not self.session: @@ -387,8 +353,6 @@ async def dispatch_shell(self, msg): """dispatch shell requests""" if not self.session: return - # flush control queue before handling shell requests - await self._flush_control_queue() idents, msg = self.session.feed_identities(msg, copy=False) try: @@ -531,6 +495,19 @@ async def process_one(self, wait=True): t, dispatch, args = self.msg_queue.get_nowait() except (asyncio.QueueEmpty, QueueEmpty): return + + if self.control_thread is None and self.control_stream is not None: + # If there isn't a separate control thread then this main thread handles both shell + # and control messages. Before processing a shell message we need to flush all control + # messages and allow them all to be processed. + await asyncio.sleep(0) + self.control_stream.flush() + + socket = self.control_stream.socket + while socket.poll(1): + await asyncio.sleep(0) + self.control_stream.flush() + await dispatch(*args) async def dispatch_queue(self): @@ -578,9 +555,6 @@ def start(self): if self.control_stream: self.control_stream.on_recv(self.dispatch_control, copy=False) - control_loop = self.control_thread.io_loop if self.control_thread else self.io_loop - - asyncio.run_coroutine_threadsafe(self.poll_control_queue(), control_loop.asyncio_loop) if self.shell_stream: self.shell_stream.on_recv( partial( diff --git a/pyproject.toml b/pyproject.toml index c2ed3fc4..1bd260c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,14 +24,14 @@ dependencies = [ "ipython>=7.23.1", "comm>=0.1.1", "traitlets>=5.4.0", - "jupyter_client>=6.1.12", + "jupyter_client>=8.0.0", "jupyter_core>=4.12,!=5.0.*", # For tk event loop support only. "nest_asyncio", - "tornado>=6.1", + "tornado>=6.2", "matplotlib-inline>=0.1", 'appnope;platform_system=="Darwin"', - "pyzmq>=24", + "pyzmq>=25", "psutil", "packaging", ] diff --git a/tests/test_ipkernel_direct.py b/tests/test_ipkernel_direct.py index c9201348..037489f3 100644 --- a/tests/test_ipkernel_direct.py +++ b/tests/test_ipkernel_direct.py @@ -164,41 +164,29 @@ def test_dispatch_debugpy(ipkernel: IPythonKernel) -> None: async def test_start(ipkernel: IPythonKernel) -> None: shell_future: asyncio.Future = asyncio.Future() - control_future: asyncio.Future = asyncio.Future() async def fake_dispatch_queue(): shell_future.set_result(None) - async def fake_poll_control_queue(): - control_future.set_result(None) - ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore - ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore ipkernel.start() ipkernel.debugpy_stream = None ipkernel.start() await ipkernel.process_one(False) await shell_future - await control_future async def test_start_no_debugpy(ipkernel: IPythonKernel) -> None: shell_future: asyncio.Future = asyncio.Future() - control_future: asyncio.Future = asyncio.Future() async def fake_dispatch_queue(): shell_future.set_result(None) - async def fake_poll_control_queue(): - control_future.set_result(None) - ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore - ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore ipkernel.debugpy_stream = None ipkernel.start() await shell_future - await control_future def test_create_comm(): diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 31338896..a0bd8334 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -10,6 +10,7 @@ import subprocess import sys import time +from datetime import datetime, timedelta from subprocess import Popen from tempfile import TemporaryDirectory @@ -597,6 +598,47 @@ def test_control_thread_priority(): assert control_dates[-1] <= shell_dates[0] +def test_sequential_control_messages(): + with new_kernel() as kc: + msg_id = kc.execute("import time") + get_reply(kc, msg_id) + + # Send multiple messages on the control channel. + # Using execute messages to vary duration. + sleeps = [0.6, 0.3, 0.1] + + # Prepare messages + msgs = [ + kc.session.msg("execute_request", {"code": f"time.sleep({sleep})"}) for sleep in sleeps + ] + msg_ids = [msg["header"]["msg_id"] for msg in msgs] + + # Submit messages + for msg in msgs: + kc.control_channel.send(msg) + + # Get replies + replies = [get_reply(kc, msg_id, channel="control") for msg_id in msg_ids] + + # Check messages are processed in order, one at a time, and of a sensible duration. + previous_end = None + for reply, sleep in zip(replies, sleeps): + start_str = reply["metadata"]["started"] + if sys.version_info[:2] < (3, 11) and start_str.endswith("Z"): + # Python < 3.11 doesn't support "Z" suffix in datetime.fromisoformat, + # so use alternative timezone format. + # https://github.com/python/cpython/issues/80010 + start_str = start_str[:-1] + "+00:00" + start = datetime.fromisoformat(start_str) + end = reply["header"]["date"] # Already a datetime + + if previous_end is not None: + assert start > previous_end + previous_end = end + + assert end >= start + timedelta(seconds=sleep) + + def _child(): print("in child", os.getpid())