Skip to content

Commit

Permalink
Remove control queue
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 committed Feb 13, 2024
1 parent 1b4eb5e commit fb8e192
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 56 deletions.
3 changes: 0 additions & 3 deletions ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ def start(self):
def _abort_queues(self):
"""The in-process kernel doesn't abort requests."""

async def _flush_control_queue(self):
"""No need to flush control queues for in-process"""

def _input_request(self, prompt, ident, parent, password=False):
# Flush output before making the request.
self.raw_input_str = None
Expand Down
55 changes: 14 additions & 41 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import inspect
import itertools
import logging
Expand Down Expand Up @@ -270,6 +269,9 @@ def _parent_header(self):
"usage_request",
]

# Flag to ensure a single control request is processed at a time.
_block_control = False

def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)
Expand All @@ -289,48 +291,24 @@ def __init__(self, **kwargs):
for msg_type in self.control_msg_types:
self.control_handlers[msg_type] = getattr(self, msg_type)

self.control_queue: Queue[t.Any] = Queue()

# Storing the accepted parameters for do_execute, used in execute_request
self._do_exec_accepted_params = _accepts_parameters(
self.do_execute, ["cell_meta", "cell_id"]
)

def dispatch_control(self, msg):
self.control_queue.put_nowait(msg)

async def poll_control_queue(self):
while True:
msg = await self.control_queue.get()
# handle tracers from _flush_control_queue
if isinstance(msg, (concurrent.futures.Future, asyncio.Future)):
msg.set_result(None)
continue
await self.process_control(msg)

async def _flush_control_queue(self):
"""Flush the control queue, wait for processing of any pending messages"""
tracer_future: concurrent.futures.Future[object] | asyncio.Future[object]
if self.control_thread:
control_loop = self.control_thread.io_loop
# concurrent.futures.Futures are threadsafe
# and can be used to await across threads
tracer_future = concurrent.futures.Future()
awaitable_future = asyncio.wrap_future(tracer_future)
else:
control_loop = self.io_loop
tracer_future = awaitable_future = asyncio.Future()
async def dispatch_control(self, msg):
# Ensure only one control message is processed at a time
while self._block_control:
await asyncio.sleep(0)

def _flush():
# control_stream.flush puts messages on the queue
if self.control_stream:
self.control_stream.flush()
# put Future on the queue after all of those,
# so we can wait for all queued messages to be processed
self.control_queue.put(tracer_future)
self._block_control = True

control_loop.add_callback(_flush)
return awaitable_future
try:
await self.process_control(msg)
except:
raise
finally:
self._block_control = False

async def process_control(self, msg):
"""dispatch control requests"""
Expand Down Expand Up @@ -387,8 +365,6 @@ async def dispatch_shell(self, msg):
"""dispatch shell requests"""
if not self.session:
return
# flush control queue before handling shell requests
await self._flush_control_queue()

idents, msg = self.session.feed_identities(msg, copy=False)
try:
Expand Down Expand Up @@ -575,9 +551,6 @@ def start(self):
if self.control_stream:
self.control_stream.on_recv(self.dispatch_control, copy=False)

control_loop = self.control_thread.io_loop if self.control_thread else self.io_loop

asyncio.run_coroutine_threadsafe(self.poll_control_queue(), control_loop.asyncio_loop)
if self.shell_stream:
self.shell_stream.on_recv(
partial(
Expand Down
12 changes: 0 additions & 12 deletions tests/test_ipkernel_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,41 +164,29 @@ def test_dispatch_debugpy(ipkernel: IPythonKernel) -> None:

async def test_start(ipkernel: IPythonKernel) -> None:
shell_future: asyncio.Future = asyncio.Future()
control_future: asyncio.Future = asyncio.Future()

async def fake_dispatch_queue():
shell_future.set_result(None)

async def fake_poll_control_queue():
control_future.set_result(None)

ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore
ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore
ipkernel.start()
ipkernel.debugpy_stream = None
ipkernel.start()
await ipkernel.process_one(False)
await shell_future
await control_future


async def test_start_no_debugpy(ipkernel: IPythonKernel) -> None:
shell_future: asyncio.Future = asyncio.Future()
control_future: asyncio.Future = asyncio.Future()

async def fake_dispatch_queue():
shell_future.set_result(None)

async def fake_poll_control_queue():
control_future.set_result(None)

ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore
ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore
ipkernel.debugpy_stream = None
ipkernel.start()

await shell_future
await control_future


def test_create_comm():
Expand Down
38 changes: 38 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import sys
import time
from datetime import datetime, timedelta
from subprocess import Popen
from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -597,6 +598,43 @@ def test_control_thread_priority():
assert control_dates[-1] <= shell_dates[0]


def test_sequential_control_messages():
with new_kernel() as kc:
msg_id = kc.execute("import time")
get_reply(kc, msg_id)

# Send multiple messages on the control channel.
# Using execute messages to vary duration.
sleeps = [0.6, 0.3, 0.1]

# Prepare messages
msgs = [
kc.session.msg("execute_request", {"code": f"time.sleep({sleep})"}) for sleep in sleeps
]
msg_ids = [msg["header"]["msg_id"] for msg in msgs]

# Submit messages
for msg in msgs:
kc.control_channel.send(msg)

# Get replies
replies = [get_reply(kc, msg_id, channel="control") for msg_id in msg_ids]

# Check messages are processed in order, one at a time, and of a sensible duration.
previous_end = None
for reply, sleep in zip(replies, sleeps):
start = datetime.fromisoformat(reply["metadata"]["started"])
end = reply["header"]["date"]
if isinstance(end, str):
end = datetime.fromisoformat(end)

if previous_end is not None:
assert start > previous_end
previous_end = end

assert end >= start + timedelta(seconds=sleep)


def _child():
print("in child", os.getpid())

Expand Down

0 comments on commit fb8e192

Please sign in to comment.