diff --git a/docs/source/api/jupyter_server.gateway.rst b/docs/source/api/jupyter_server.gateway.rst index 57f2c6235e..c08595e672 100644 --- a/docs/source/api/jupyter_server.gateway.rst +++ b/docs/source/api/jupyter_server.gateway.rst @@ -5,6 +5,12 @@ Submodules ---------- +.. automodule:: jupyter_server.gateway.connections + :members: + :undoc-members: + :show-inheritance: + + .. automodule:: jupyter_server.gateway.gateway_client :members: :undoc-members: diff --git a/jupyter_server/gateway/connections.py b/jupyter_server/gateway/connections.py new file mode 100644 index 0000000000..9a33195cb4 --- /dev/null +++ b/jupyter_server/gateway/connections.py @@ -0,0 +1,175 @@ +"""Gateway connection classes.""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import asyncio +import logging +import random +from typing import Any, cast + +import tornado.websocket as tornado_websocket +from tornado.concurrent import Future +from tornado.escape import json_decode, url_escape, utf8 +from tornado.httpclient import HTTPRequest +from tornado.ioloop import IOLoop +from traitlets import Bool, Instance, Int + +from ..services.kernels.connection.base import BaseKernelWebsocketConnection +from ..utils import url_path_join +from .managers import GatewayClient + + +class GatewayWebSocketConnection(BaseKernelWebsocketConnection): + """Web socket connection that proxies to a kernel/enterprise gateway.""" + + ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True) + + ws_future = Instance(default_value=Future(), klass=Future) + + disconnected = Bool(False) + + retry = Int(0) + + async def connect(self): + """Connect to the socket.""" + # websocket is initialized before connection + self.ws = None + ws_url = url_path_join( + GatewayClient.instance().ws_url, + GatewayClient.instance().kernels_endpoint, + url_escape(self.kernel_id), + "channels", + ) + self.log.info(f"Connecting to {ws_url}") + kwargs: dict = {} + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + + request = HTTPRequest(ws_url, **kwargs) + self.ws_future = cast(Future, tornado_websocket.websocket_connect(request)) + self.ws_future.add_done_callback(self._connection_done) + + loop = IOLoop.current() + loop.add_future(self.ws_future, lambda future: self._read_messages()) + + def _connection_done(self, fut): + """Handle a finished connection.""" + if ( + not self.disconnected and fut.exception() is None + ): # prevent concurrent.futures._base.CancelledError + self.ws = fut.result() + self.retry = 0 + self.log.debug(f"Connection is ready: ws: {self.ws}") + else: + self.log.warning( + "Websocket connection has been closed via client disconnect or due to error. " + "Kernel with ID '{}' may not be terminated on GatewayClient: {}".format( + self.kernel_id, GatewayClient.instance().url + ) + ) + + def disconnect(self): + """Handle a disconnect.""" + self.disconnected = True + if self.ws is not None: + # Close connection + self.ws.close() + elif not self.ws_future.done(): + # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally + self.ws_future.cancel() + self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}") + + async def _read_messages(self): + """Read messages from gateway server.""" + while self.ws is not None: + message = None + if not self.disconnected: + try: + message = await self.ws.read_message() + except Exception as e: + self.log.error( + f"Exception reading message from websocket: {e}" + ) # , exc_info=True) + if message is None: + if not self.disconnected: + self.log.warning(f"Lost connection to Gateway: {self.kernel_id}") + break + self.handle_outgoing_message( + message + ) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open) + else: # ws cancelled - stop reading + break + + # NOTE(esevan): if websocket is not disconnected by client, try to reconnect. + if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max: + jitter = random.randint(10, 100) * 0.01 # noqa + retry_interval = ( + min( + GatewayClient.instance().gateway_retry_interval * (2**self.retry), + GatewayClient.instance().gateway_retry_interval_max, + ) + + jitter + ) + self.retry += 1 + self.log.info( + "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s", + retry_interval, + self.retry, + GatewayClient.instance().gateway_retry_max, + self.kernel_id, + ) + await asyncio.sleep(retry_interval) + loop = IOLoop.current() + loop.spawn_callback(self.connect) + + def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None: + """Send message to the notebook client.""" + try: + self.websocket_handler.write_message(incoming_msg) + except tornado_websocket.WebSocketClosedError: + if self.log.isEnabledFor(logging.DEBUG): + msg_summary = GatewayWebSocketConnection._get_message_summary( + json_decode(utf8(incoming_msg)) + ) + self.log.debug( + "Notebook client closed websocket connection - message dropped: {}".format( + msg_summary + ) + ) + + def handle_incoming_message(self, message: str) -> None: + """Send message to gateway server.""" + if self.ws is None: + loop = IOLoop.current() + loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message)) + else: + self._write_message(message) + + def _write_message(self, message): + """Send message to gateway server.""" + try: + if not self.disconnected and self.ws is not None: + self.ws.write_message(message) + except Exception as e: + self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True) + + @staticmethod + def _get_message_summary(message): + """Get a summary of a message.""" + summary = [] + message_type = message["msg_type"] + summary.append(f"type: {message_type}") + + if message_type == "status": + summary.append(", state: {}".format(message["content"]["execution_state"])) + elif message_type == "error": + summary.append( + ", {}:{}:{}".format( + message["content"]["ename"], + message["content"]["evalue"], + message["content"]["traceback"], + ) + ) + else: + summary.append(", ...") # don't display potentially sensitive data + + return "".join(summary) diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index df1c82d4c5..ddddcea15b 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -6,6 +6,7 @@ import mimetypes import os import random +import warnings from typing import Optional, cast from jupyter_client.session import Session @@ -21,6 +22,13 @@ from ..utils import url_path_join from .managers import GatewayClient +warnings.warn( + "The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0", + DeprecationWarning, + stacklevel=2, +) + + # Keepalive ping interval (default: 30 seconds) GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30")) diff --git a/jupyter_server/kernelspecs/handlers.py b/jupyter_server/kernelspecs/handlers.py index 611a6f3a9a..301973223a 100644 --- a/jupyter_server/kernelspecs/handlers.py +++ b/jupyter_server/kernelspecs/handlers.py @@ -1,4 +1,6 @@ """Kernelspecs API Handlers.""" +import mimetypes + from jupyter_core.utils import ensure_async from tornado import web @@ -27,6 +29,26 @@ async def get(self, kernel_name, path, include_body=True): ksm = self.kernel_spec_manager if path.lower().endswith(".png"): self.set_header("Cache-Control", f"max-age={60*60*24*30}") + ksm = self.kernel_spec_manager + if hasattr(ksm, "get_kernel_spec_resource"): + # If the kernel spec manager defines a method to get kernelspec resources, + # then use that instead of trying to read from disk. + kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path) + if kernel_spec_res is not None: + # We have to explicitly specify the `absolute_path` attribute so that + # the underlying StaticFileHandler methods can calculate an etag. + self.absolute_path = path + mimetype: str = mimetypes.guess_type(path)[0] or "text/plain" + self.set_header("Content-Type", mimetype) + self.finish(kernel_spec_res) + return + else: + self.log.warning( + "Kernelspec resource '{}' for '{}' not found. Kernel spec manager may" + " not support resource serving. Falling back to reading from disk".format( + path, kernel_name + ) + ) try: kspec = await ensure_async(ksm.get_kernel_spec(kernel_name)) self.root = kspec.resource_dir diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 7cba08b5ac..b236ba7f4b 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -87,6 +87,7 @@ from jupyter_server.extension.config import ExtensionConfigManager from jupyter_server.extension.manager import ExtensionManager from jupyter_server.extension.serverextension import ServerExtensionApp +from jupyter_server.gateway.connections import GatewayWebSocketConnection from jupyter_server.gateway.managers import ( GatewayClient, GatewayKernelSpecManager, @@ -433,17 +434,6 @@ def init_handlers(self, default_services, settings): # And from identity provider handlers.extend(settings["identity_provider"].get_handlers()) - # If gateway mode is enabled, replace appropriate handlers to perform redirection - if GatewayClient.instance().gateway_enabled: - # for each handler required for gateway, locate its pattern - # in the current list and replace that entry... - gateway_handlers = load_handlers("jupyter_server.gateway.handlers") - for _, gwh in enumerate(gateway_handlers): - for j, h in enumerate(handlers): - if gwh[0] == h[0]: - handlers[j] = (gwh[0], gwh[1]) - break - # register base handlers last handlers.extend(load_handlers("jupyter_server.base.handlers")) @@ -796,6 +786,7 @@ class ServerApp(JupyterApp): GatewayMappingKernelManager, GatewayKernelSpecManager, GatewaySessionManager, + GatewayWebSocketConnection, GatewayClient, Authorizer, EventLogger, @@ -1505,12 +1496,17 @@ def _default_session_manager_class(self): return SessionManager kernel_websocket_connection_class = Type( - default_value=ZMQChannelsWebsocketConnection, klass=BaseKernelWebsocketConnection, config=True, help=_i18n("The kernel websocket connection class to use."), ) + @default("kernel_websocket_connection_class") + def _default_kernel_websocket_connection_class(self): + if self.gateway_config.gateway_enabled: + return "jupyter_server.gateway.connections.GatewayWebSocketConnection" + return ZMQChannelsWebsocketConnection + config_manager_class = Type( default_value=ConfigManager, config=True, @@ -2876,7 +2872,19 @@ async def _cleanup(self): self.remove_browser_open_files() await self.cleanup_extensions() await self.cleanup_kernels() - await self.kernel_websocket_connection_class.close_all() + try: + await self.kernel_websocket_connection_class.close_all() + except AttributeError: + # This can happen in two different scenarios: + # + # 1. During tests, where the _cleanup method is invoked without + # the corresponding initialize method having been invoked. + # 2. If the provided `kernel_websocket_connection_class` does not + # implement the `close_all` class method. + # + # In either case, we don't need to do anything and just want to treat + # the raised error as a no-op. + pass if getattr(self, "kernel_manager", None): self.kernel_manager.__del__() if getattr(self, "session_manager", None): diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 86fcf508ca..b69564220d 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -15,13 +15,18 @@ import pytest import tornado from jupyter_core.utils import ensure_async +from tornado.concurrent import Future from tornado.httpclient import HTTPRequest, HTTPResponse +from tornado.httputil import HTTPServerRequest +from tornado.queues import Queue from tornado.web import HTTPError from traitlets import Int, Unicode from traitlets.config import Config +from jupyter_server.gateway.connections import GatewayWebSocketConnection from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager +from jupyter_server.services.kernels.websocket import KernelWebsocketHandler from .utils import expected_http_error @@ -659,6 +664,61 @@ async def test_channel_queue_get_msg_when_response_router_had_finished(): await queue.get_msg() +class MockWebSocketClientConnection(tornado.websocket.WebSocketClientConnection): + def __init__(self, *args, **kwargs): + self._msgs: Queue = Queue(2) + self._msgs.put_nowait('{"msg_type": "status", "content": {"execution_state": "starting"}}') + + def write_message(self, message, *args, **kwargs): + return self._msgs.put(message) + + def read_message(self, *args, **kwargs): + return self._msgs.get() + + +def mock_websocket_connect(): + def helper(request): + fut: Future = Future() + mock_client = MockWebSocketClientConnection() + fut.set_result(mock_client) + return fut + + return helper + + +@patch("tornado.websocket.websocket_connect", mock_websocket_connect()) +async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch, caplog): + # Create the kernel and get the kernel manager... + kernel_id = await create_kernel(jp_fetch, "kspec_foo") + km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id) + + # Create the KernelWebsocketHandler... + request = HTTPServerRequest("foo", "GET") + request.connection = MagicMock() + handler = KernelWebsocketHandler(jp_serverapp.web_app, request) + + # Force the websocket handler to raise a closed error if we try to write a message + # to the client. + handler.ws_connection = MagicMock() + handler.ws_connection.is_closing = lambda: True + + # Create the GatewayWebSocketConnection and attach it to the handler... + conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler) + handler.connection = conn + await conn.connect() + + # Processing websocket messages happens in separate coroutines and any + # errors in that process will show up in logs, but not bubble up to the + # caller. + # + # To check for these, we wait for the server to stop and then check the + # logs for errors. + await jp_serverapp._cleanup() + for _, level, message in caplog.record_tuples: + if level >= logging.ERROR: + pytest.fail(f"Logs contain an error: {message}") + + # # Test methods below... # diff --git a/tests/test_serverapp.py b/tests/test_serverapp.py index 5f6c3eb16e..6f1cbd68ae 100644 --- a/tests/test_serverapp.py +++ b/tests/test_serverapp.py @@ -479,7 +479,7 @@ def test_server_web_application(jp_serverapp): server.kernel_manager, server.config_manager, server.event_logger, - ["jupyter_server.gateway.handlers"], + [], server.log, server.base_url, server.default_url,