# Reviewer preamble -- these lines precede the first "diff --git" header and belong to no hunk.
#
# jupyter_server/services/kernels/handlers.py
#   open(): when restoring a buffered connection, calls km.ports_changed(kernel_id); if it
#   returns True the handler calls create_stream(), otherwise it reuses buffer_info['channels'].
# jupyter_server/services/kernels/kernelmanager.py
#   Adds the _kernel_ports dict (written in start_kernel, popped in both shutdown_kernel
#   variants), ports_changed(), _get_changed_ports() and stop_watching_activity();
#   restart_kernel() stops/starts activity watching when _get_changed_ports() is not None.
# jupyter_server/tests/services/sessions/test_api.py
#   Adds NewPortsKernelManager / NewPortsMappingKernelManager, a third jp_argv param and
#   test_restart_kernel.
#
# NOTE(review): line breaks, blank lines and indentation were reconstructed so that every hunk
#   matches its "@@ -a,b +c,d @@" counts; whitespace inside context lines (e.g. runs of spaces
#   in comments) could not be recovered -- verify against the target tree before applying.
# NOTE(review): _get_changed_ports() indexes self._kernel_ports[kernel_id] directly; within this
#   patch the entry is only created in start_kernel -- confirm no kernel can reach it without one.
# NOTE(review): test_restart_kernel calls time.sleep(0.1) inside an async def; presumably this
#   blocks the event loop while polling -- consider an awaitable sleep. TODO confirm.
# NOTE(review): the second jp_ws_fetch(...) result in test_restart_kernel is not kept or closed
#   in the test body shown here.
diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py
index 77eca92dc9..b83046d858 100644
--- a/jupyter_server/services/kernels/handlers.py
+++ b/jupyter_server/services/kernels/handlers.py
@@ -369,7 +369,13 @@ def open(self, kernel_id):
         buffer_info = km.get_buffer(kernel_id, self.session_key)
         if buffer_info and buffer_info['session_key'] == self.session_key:
             self.log.info("Restoring connection for %s", self.session_key)
-            self.channels = buffer_info['channels']
+            if km.ports_changed(kernel_id):
+                # If the kernel's ports have changed (some restarts trigger this)
+                # then reset the channels so nudge() is using the correct iopub channel
+                self.create_stream()
+            else:
+                # The kernel's ports have not changed; use the channels captured in the buffer
+                self.channels = buffer_info['channels']
 
             connected = self.nudge()
 
@@ -381,7 +387,6 @@ def replay(value):
                         stream = self.channels[channel]
                         self._on_zmq_reply(stream, msg_list)
 
-
             connected.add_done_callback(replay)
         else:
             try:
@@ -389,7 +394,7 @@ def replay(value):
                 connected = self.nudge()
             except web.HTTPError as e:
                 self.log.error("Error opening stream: %s", e)
-                # WebSockets don't response to traditional error codes so we
+                # WebSockets don't respond to traditional error codes so we
                 # close the connection.
                 for channel, stream in self.channels.items():
                     if not stream.closed():
diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py
index c95fd832bc..b5f14dbde0 100644
--- a/jupyter_server/services/kernels/kernelmanager.py
+++ b/jupyter_server/services/kernels/kernelmanager.py
@@ -47,6 +47,8 @@ def _default_kernel_manager_class(self):
 
     _kernel_connections = Dict()
 
+    _kernel_ports = Dict()
+
     _culler_callback = None
 
     _initialized_culler = False
@@ -183,6 +185,7 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
                 kwargs['cwd'] = self.cwd_for_path(path)
             kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
             self._kernel_connections[kernel_id] = 0
+            self._kernel_ports[kernel_id] = self._kernels[kernel_id].ports
             self.start_watching_activity(kernel_id)
             self.log.info("Kernel started: %s" % kernel_id)
             self.log.debug("Kernel args: %r" % kwargs)
@@ -208,6 +211,40 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
 
         return kernel_id
 
+    def ports_changed(self, kernel_id):
+        """Used by ZMQChannelsHandler to determine how to coordinate nudge and replays.
+
+        Ports are captured when starting a kernel (via MappingKernelManager). Ports
+        are considered changed (following restarts) if the referenced KernelManager
+        is using a set of ports different from those captured at startup. If changes
+        are detected, the captured set is updated and a value of True is returned.
+
+        NOTE: Use is exclusive to ZMQChannelsHandler because this object is a singleton
+        instance while ZMQChannelsHandler instances are per WebSocket connection that
+        can vary per kernel lifetime.
+        """
+        changed_ports = self._get_changed_ports(kernel_id)
+        if changed_ports:
+            # If changed, update captured ports and return True, else return False.
+            self.log.debug(f"Port change detected for kernel: {kernel_id}")
+            self._kernel_ports[kernel_id] = changed_ports
+            return True
+        return False
+
+    def _get_changed_ports(self, kernel_id):
+        """Internal method to test if a kernel's ports have changed and, if so, return their values.
+
+        This method does NOT update the captured ports for the kernel as that can only be done
+        by ZMQChannelsHandler, but instead returns the new list of ports if they are different
+        than those captured at startup. This enables the ability to conditionally restart
+        activity monitoring immediately following a kernel's restart (if ports have changed).
+        """
+        # Get current ports and return comparison with ports captured at startup.
+        km = self.get_kernel(kernel_id)
+        if km.ports != self._kernel_ports[kernel_id]:
+            return km.ports
+        return None
+
     def start_buffering(self, kernel_id, session_key, channels):
         """Start buffering messages for a kernel
 
@@ -300,10 +337,7 @@ def stop_buffering(self, kernel_id):
     def shutdown_kernel(self, kernel_id, now=False, restart=False):
         """Shutdown a kernel by kernel_id"""
         self._check_kernel_id(kernel_id)
-        kernel = self._kernels[kernel_id]
-        if kernel._activity_stream:
-            kernel._activity_stream.close()
-            kernel._activity_stream = None
+        self.stop_watching_activity(kernel_id)
         self.stop_buffering(kernel_id)
         self._kernel_connections.pop(kernel_id, None)
 
@@ -319,6 +353,7 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False):
         # method is synchronous. However, we'll keep the relative call orders the same from
         # a maintenance perspective.
         self._kernel_connections.pop(kernel_id, None)
+        self._kernel_ports.pop(kernel_id, None)
 
     async def restart_kernel(self, kernel_id, now=False):
         """Restart a kernel by kernel_id"""
@@ -359,6 +394,10 @@ def on_restart_failed():
         channel.on_recv(on_reply)
         loop = IOLoop.current()
         timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout)
+        # Re-establish activity watching if ports have changed...
+        if self._get_changed_ports(kernel_id) is not None:
+            self.stop_watching_activity(kernel_id)
+            self.start_watching_activity(kernel_id)
         return future
 
     def notify_connect(self, kernel_id):
@@ -440,6 +479,13 @@ def record_activity(msg_list):
 
         kernel._activity_stream.on_recv(record_activity)
 
+    def stop_watching_activity(self, kernel_id):
+        """Stop watching IOPub messages on a kernel for activity."""
+        kernel = self._kernels[kernel_id]
+        if kernel._activity_stream:
+            kernel._activity_stream.close()
+            kernel._activity_stream = None
+
     def initialize_culler(self):
         """Start idle culler if 'cull_idle_timeout' is greater than zero.
 
@@ -511,10 +557,7 @@ def __init__(self, **kwargs):
     async def shutdown_kernel(self, kernel_id, now=False, restart=False):
         """Shutdown a kernel by kernel_id"""
         self._check_kernel_id(kernel_id)
-        kernel = self._kernels[kernel_id]
-        if kernel._activity_stream:
-            kernel._activity_stream.close()
-            kernel._activity_stream = None
+        self.stop_watching_activity(kernel_id)
         self.stop_buffering(kernel_id)
 
         # Decrease the metric of number of kernels
@@ -526,4 +569,5 @@ async def shutdown_kernel(self, kernel_id, now=False, restart=False):
         # Finish shutting down the kernel before clearing state to avoid a race condition.
         ret = await self.pinned_superclass.shutdown_kernel(self, kernel_id, now=now, restart=restart)
         self._kernel_connections.pop(kernel_id, None)
+        self._kernel_ports.pop(kernel_id, None)
         return ret
diff --git a/jupyter_server/tests/services/sessions/test_api.py b/jupyter_server/tests/services/sessions/test_api.py
index 858b97004c..0208880b5f 100644
--- a/jupyter_server/tests/services/sessions/test_api.py
+++ b/jupyter_server/tests/services/sessions/test_api.py
@@ -1,4 +1,3 @@
-import sys
 import time
 import json
 import shutil
@@ -6,18 +5,43 @@
 
 import tornado
 
+from jupyter_client.ioloop import AsyncIOLoopKernelManager
+
 from nbformat.v4 import new_notebook
 from nbformat import writes
+from traitlets import default
 
 from ...utils import expected_http_error
 
+from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
 from jupyter_server.utils import url_path_join
 
 j = lambda r: json.loads(r.body.decode())
 
 
-@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager"])
+class NewPortsKernelManager(AsyncIOLoopKernelManager):
+
+    @default('cache_ports')
+    def _default_cache_ports(self) -> bool:
+        return False
+
+    async def restart_kernel(self, now: bool = False, newports: bool = True, **kw) -> None:
+        self.log.debug(f"DEBUG**** calling super().restart_kernel with newports={newports}")
+        return await super().restart_kernel(now=now, newports=newports, **kw)
+
+
+class NewPortsMappingKernelManager(AsyncMappingKernelManager):
+
+    @default('kernel_manager_class')
+    def _default_kernel_manager_class(self):
+        self.log.debug("NewPortsMappingKernelManager in _default_kernel_manager_class!")
+        return "jupyter_server.tests.services.sessions.test_api.NewPortsKernelManager"
+
+
+@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager", "NewPortsMappingKernelManager"])
 def jp_argv(request):
+    if request.param == "NewPortsMappingKernelManager":
+        return ["--ServerApp.kernel_manager_class=jupyter_server.tests.services.sessions.test_api." + request.param]
     return ["--ServerApp.kernel_manager_class=jupyter_server.services.kernels.kernelmanager." + request.param]
 
 
@@ -339,3 +363,83 @@ async def test_modify_kernel_id(session_client, jp_fetch):
 
     # Need to find a better solution to this.
     await session_client.cleanup()
+
+
+async def test_restart_kernel(session_client, jp_base_url, jp_fetch, jp_ws_fetch):
+
+    # Create a session.
+    resp = await session_client.create('foo/nb1.ipynb')
+    assert resp.code == 201
+    new_session = j(resp)
+    assert 'id' in new_session
+    assert new_session['path'] == 'foo/nb1.ipynb'
+    assert new_session['type'] == 'notebook'
+    assert resp.headers['Location'] == url_path_join(jp_base_url, '/api/sessions/', new_session['id'])
+
+    kid = new_session['kernel']['id']
+
+    # Get kernel info
+    r = await jp_fetch(
+        'api', 'kernels', kid,
+        method='GET'
+    )
+    model = json.loads(r.body.decode())
+    assert model['connections'] == 0
+
+    # Open a websocket connection.
+    ws = await jp_ws_fetch(
+        'api', 'kernels', kid, 'channels'
+    )
+
+    # Test that it was opened.
+    r = await jp_fetch(
+        'api', 'kernels', kid,
+        method='GET'
+    )
+    model = json.loads(r.body.decode())
+    assert model['connections'] == 1
+
+    # Restart kernel
+    r = await jp_fetch(
+        'api', 'kernels', kid, 'restart',
+        method='POST',
+        allow_nonstandard_methods=True
+    )
+    restarted_kernel = json.loads(r.body.decode())
+    assert restarted_kernel['id'] == kid
+
+    # Close/open websocket
+    ws.close()
+    # give it some time to close on the other side:
+    for i in range(10):
+        r = await jp_fetch(
+            'api', 'kernels', kid,
+            method='GET'
+        )
+        model = json.loads(r.body.decode())
+        if model['connections'] > 0:
+            time.sleep(0.1)
+        else:
+            break
+
+    r = await jp_fetch(
+        'api', 'kernels', kid,
+        method='GET'
+    )
+    model = json.loads(r.body.decode())
+    assert model['connections'] == 0
+
+    # Open a websocket connection.
+    await jp_ws_fetch(
+        'api', 'kernels', kid, 'channels'
+    )
+
+    r = await jp_fetch(
+        'api', 'kernels', kid,
+        method='GET'
+    )
+    model = json.loads(r.body.decode())
+    assert model['connections'] == 1
+
+    # Need to find a better solution to this.
+    await session_client.cleanup()