Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix activity tracking and nudge issues when kernel ports change on restarts #482

Merged
merged 1 commit into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions jupyter_server/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,13 @@ def open(self, kernel_id):
buffer_info = km.get_buffer(kernel_id, self.session_key)
if buffer_info and buffer_info['session_key'] == self.session_key:
self.log.info("Restoring connection for %s", self.session_key)
self.channels = buffer_info['channels']
if km.ports_changed(kernel_id):
# If the kernel's ports have changed (some restarts trigger this)
# then reset the channels so nudge() is using the correct iopub channel
self.create_stream()
else:
# The kernel's ports have not changed; use the channels captured in the buffer
self.channels = buffer_info['channels']

connected = self.nudge()

Expand All @@ -381,15 +387,14 @@ def replay(value):
stream = self.channels[channel]
self._on_zmq_reply(stream, msg_list)


connected.add_done_callback(replay)
else:
try:
self.create_stream()
connected = self.nudge()
except web.HTTPError as e:
self.log.error("Error opening stream: %s", e)
# WebSockets don't response to traditional error codes so we
# WebSockets don't respond to traditional error codes so we
# close the connection.
for channel, stream in self.channels.items():
if not stream.closed():
Expand Down
60 changes: 52 additions & 8 deletions jupyter_server/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def _default_kernel_manager_class(self):

_kernel_connections = Dict()

_kernel_ports = Dict()

_culler_callback = None

_initialized_culler = False
Expand Down Expand Up @@ -183,6 +185,7 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
kwargs['cwd'] = self.cwd_for_path(path)
kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
self._kernel_connections[kernel_id] = 0
self._kernel_ports[kernel_id] = self._kernels[kernel_id].ports
self.start_watching_activity(kernel_id)
self.log.info("Kernel started: %s" % kernel_id)
self.log.debug("Kernel args: %r" % kwargs)
Expand All @@ -208,6 +211,40 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):

return kernel_id

def ports_changed(self, kernel_id):
    """Tell the caller whether a kernel's ports differ from the captured set.

    The ports recorded in ``self._kernel_ports`` are compared (through
    ``_get_changed_ports``) with the ports the kernel's manager currently
    reports. When they differ, the recorded entry is overwritten with the
    new ports and True is returned; otherwise nothing is modified and
    False is returned.

    NOTE: Because a positive answer also re-captures the ports, a second
    call right after a detected change reports False. This method is meant
    to be called only by ZMQChannelsHandler, which uses the answer to decide
    how to coordinate nudge and replays: this manager is a singleton while
    handler instances exist per WebSocket connection.
    """
    new_ports = self._get_changed_ports(kernel_id)
    if not new_ports:
        # Nothing differs from the captured set; leave the record alone.
        return False

    # A different set of ports is in use: log it and re-capture.
    self.log.debug(f"Port change detected for kernel: {kernel_id}")
    self._kernel_ports[kernel_id] = new_ports
    return True

def _get_changed_ports(self, kernel_id):
    """Return the kernel's current ports if they differ from the captured ones, else None.

    This is a read-only check: the entry in ``self._kernel_ports`` is never
    modified here (re-capturing is left to ``ports_changed``). Keeping the
    check free of side effects lets a restart ask "did the ports move?" and
    conditionally restart activity monitoring without consuming the change
    that ZMQChannelsHandler still needs to observe.
    """
    manager = self.get_kernel(kernel_id)
    # Compare what the manager reports now against what was recorded for this kernel.
    return manager.ports if manager.ports != self._kernel_ports[kernel_id] else None

def start_buffering(self, kernel_id, session_key, channels):
"""Start buffering messages for a kernel

Expand Down Expand Up @@ -300,10 +337,7 @@ def stop_buffering(self, kernel_id):
def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by kernel_id"""
self._check_kernel_id(kernel_id)
kernel = self._kernels[kernel_id]
if kernel._activity_stream:
kernel._activity_stream.close()
kernel._activity_stream = None
self.stop_watching_activity(kernel_id)
self.stop_buffering(kernel_id)
self._kernel_connections.pop(kernel_id, None)

Expand All @@ -319,6 +353,7 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False):
# method is synchronous. However, we'll keep the relative call orders the same from
# a maintenance perspective.
self._kernel_connections.pop(kernel_id, None)
self._kernel_ports.pop(kernel_id, None)

async def restart_kernel(self, kernel_id, now=False):
"""Restart a kernel by kernel_id"""
Expand Down Expand Up @@ -359,6 +394,10 @@ def on_restart_failed():
channel.on_recv(on_reply)
loop = IOLoop.current()
timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout)
# Re-establish activity watching if ports have changed...
if self._get_changed_ports(kernel_id) is not None:
self.stop_watching_activity(kernel_id)
self.start_watching_activity(kernel_id)
return future

def notify_connect(self, kernel_id):
Expand Down Expand Up @@ -440,6 +479,13 @@ def record_activity(msg_list):

kernel._activity_stream.on_recv(record_activity)

def stop_watching_activity(self, kernel_id):
    """Close and forget the activity stream of the given kernel, if it has one.

    Looks the kernel up in ``self._kernels``; when its ``_activity_stream``
    is set, the stream is closed and the attribute reset to None. A kernel
    without an activity stream is left untouched.
    """
    watched = self._kernels[kernel_id]
    stream = watched._activity_stream
    if not stream:
        # No stream is being watched for this kernel; nothing to tear down.
        return
    stream.close()
    watched._activity_stream = None

def initialize_culler(self):
"""Start idle culler if 'cull_idle_timeout' is greater than zero.

Expand Down Expand Up @@ -511,10 +557,7 @@ def __init__(self, **kwargs):
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by kernel_id"""
self._check_kernel_id(kernel_id)
kernel = self._kernels[kernel_id]
if kernel._activity_stream:
kernel._activity_stream.close()
kernel._activity_stream = None
self.stop_watching_activity(kernel_id)
self.stop_buffering(kernel_id)

# Decrease the metric of number of kernels
Expand All @@ -526,4 +569,5 @@ async def shutdown_kernel(self, kernel_id, now=False, restart=False):
# Finish shutting down the kernel before clearing state to avoid a race condition.
ret = await self.pinned_superclass.shutdown_kernel(self, kernel_id, now=now, restart=restart)
self._kernel_connections.pop(kernel_id, None)
self._kernel_ports.pop(kernel_id, None)
return ret
108 changes: 106 additions & 2 deletions jupyter_server/tests/services/sessions/test_api.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,47 @@
import sys
import time
import json
import shutil
import pytest

import tornado

from jupyter_client.ioloop import AsyncIOLoopKernelManager

from nbformat.v4 import new_notebook
from nbformat import writes
from traitlets import default

from ...utils import expected_http_error
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_server.utils import url_path_join


def j(r):
    """Decode the body of an HTTP response object and parse it as JSON.

    ``r`` must expose a ``body`` attribute holding bytes; the bytes are
    decoded with the default codec and handed to ``json.loads``.
    """
    return json.loads(r.body.decode())


@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager"])
class NewPortsKernelManager(AsyncIOLoopKernelManager):
    """Kernel manager variant used by the tests to exercise restarts with ``newports=True``.

    Two things differ from the base class: the ``cache_ports`` trait defaults
    to False, and ``restart_kernel`` defaults ``newports`` to True.
    """

    @default('cache_ports')
    def _default_cache_ports(self) -> bool:
        # Default for the ``cache_ports`` trait: off.
        # NOTE(review): presumably so a restart is free to pick different ports - confirm
        # against jupyter_client's handling of ``cache_ports``.
        return False

    async def restart_kernel(self, now: bool = False, newports: bool = True, **kw) -> None:
        """Delegate to the parent restart, with ``newports`` defaulting to True."""
        self.log.debug(f"DEBUG**** calling super().restart_kernel with newports={newports}")
        outcome = await super().restart_kernel(now=now, newports=newports, **kw)
        return outcome


class NewPortsMappingKernelManager(AsyncMappingKernelManager):
    """Async mapping kernel manager whose per-kernel manager class is ``NewPortsKernelManager``."""

    @default('kernel_manager_class')
    def _default_kernel_manager_class(self):
        # Default for the ``kernel_manager_class`` trait, given as a dotted
        # import path to the test-local kernel manager in this module.
        self.log.debug("NewPortsMappingKernelManager in _default_kernel_manager_class!")
        dotted_path = "jupyter_server.tests.services.sessions.test_api.NewPortsKernelManager"
        return dotted_path


@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager", "NewPortsMappingKernelManager"])
def jp_argv(request):
    """Build the server argv selecting the kernel manager class named by the fixture param.

    ``NewPortsMappingKernelManager`` is defined in this test module, so it is
    addressed through the test module's dotted path; the other params name
    classes in ``jupyter_server.services.kernels.kernelmanager``.
    """
    if request.param == "NewPortsMappingKernelManager":
        option_prefix = "--ServerApp.kernel_manager_class=jupyter_server.tests.services.sessions.test_api."
    else:
        option_prefix = "--ServerApp.kernel_manager_class=jupyter_server.services.kernels.kernelmanager."
    return [option_prefix + request.param]


Expand Down Expand Up @@ -339,3 +363,83 @@ async def test_modify_kernel_id(session_client, jp_fetch):

# Need to find a better solution to this.
await session_client.cleanup()


async def test_restart_kernel(session_client, jp_base_url, jp_fetch, jp_ws_fetch):
    """Check that websocket connection counts stay correct across a kernel restart.

    Flow: create a session, open a websocket to its kernel (count 0 -> 1),
    restart the kernel through the REST API, close the websocket and wait
    for the count to drop back to 0, then open a new websocket and expect
    the count to be 1 again.
    """
    # Imported locally so this test does not depend on the module's import block.
    import asyncio

    async def connection_count():
        """Return the 'connections' value the kernels API reports for ``kid``."""
        r = await jp_fetch(
            'api', 'kernels', kid,
            method='GET'
        )
        return j(r)['connections']

    # Create a session.
    resp = await session_client.create('foo/nb1.ipynb')
    assert resp.code == 201
    new_session = j(resp)
    assert 'id' in new_session
    assert new_session['path'] == 'foo/nb1.ipynb'
    assert new_session['type'] == 'notebook'
    assert resp.headers['Location'] == url_path_join(jp_base_url, '/api/sessions/', new_session['id'])

    kid = new_session['kernel']['id']

    # Nothing is connected to the kernel yet.
    assert await connection_count() == 0

    # Open a websocket connection.
    ws = await jp_ws_fetch(
        'api', 'kernels', kid, 'channels'
    )

    # Test that it was opened.
    assert await connection_count() == 1

    # Restart kernel
    r = await jp_fetch(
        'api', 'kernels', kid, 'restart',
        method='POST',
        allow_nonstandard_methods=True
    )
    restarted_kernel = j(r)
    assert restarted_kernel['id'] == kid

    # Close/open websocket
    ws.close()
    # Give it some time to close on the other side. The wait must yield to
    # the event loop (asyncio.sleep, not time.sleep): a blocking sleep stalls
    # every other coroutine on the loop, including whatever has to process
    # the close.
    for _ in range(10):
        if await connection_count() > 0:
            await asyncio.sleep(0.1)
        else:
            break

    assert await connection_count() == 0

    # Open a websocket connection.
    await jp_ws_fetch(
        'api', 'kernels', kid, 'channels'
    )

    assert await connection_count() == 1

    # Need to find a better solution to this.
    await session_client.cleanup()