Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge the gateway handlers into the standard handlers. #1261

Merged
merged 21 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6bb3da4
Merge the gateway handlers into the standard handlers.
ojarjur Apr 20, 2023
ef586e5
Cleanups for the GatewayWebSocketConnection class and related code:
ojarjur Apr 20, 2023
4fd437a
Remove test reference to deprecated gateway handlers package
ojarjur Apr 20, 2023
1a170c5
Restore the gateway handlers module but with a deprecation warning
ojarjur Apr 20, 2023
704029d
Fix a typing error in the KernelSpecResourceHandler class
ojarjur Apr 20, 2023
cabadd7
Fix a bug in GatewayWebsocketConnection setup.
ojarjur Apr 20, 2023
113b529
Update the gateway docs for the new connections module
ojarjur Apr 20, 2023
9681c56
Add a test for gateway websockets logging a stacktrace when the clien…
ojarjur Apr 22, 2023
df75330
Copy logic for handling closed client connections from gateway.WebSoc…
ojarjur Apr 22, 2023
a388115
Fix gateway websocket connection typing errors
ojarjur Apr 22, 2023
4f74cf2
Merge branch 'main' into ojarjur/kernel-servers
ojarjur Apr 24, 2023
4ec6255
Merge branch 'jupyter-server:main' into ojarjur/kernel-servers
ojarjur Apr 25, 2023
6b81720
Merge branch 'main' of https://github.com/jupyter-server/jupyter_serv…
ojarjur Apr 25, 2023
0c77f56
Merge branch 'ojarjur/kernel-servers' of github.com:ojarjur/jupyter_s…
ojarjur Apr 25, 2023
543a5f9
Use traitlets instead of __init__ arguments to configure the GatewayW…
ojarjur Apr 25, 2023
f4dfad1
Update test_gateway.py to incorporate traitlet changes to GatewayWebS…
ojarjur Apr 25, 2023
4fb4323
Merge branch 'main' into ojarjur/kernel-servers
ojarjur May 8, 2023
0383f9a
Merge branch 'jupyter-server:main' into ojarjur/kernel-servers
ojarjur May 9, 2023
de03891
Merge branch 'main' into ojarjur/kernel-servers
blink1073 May 10, 2023
c1c9bc0
Merge branch 'main' into ojarjur/kernel-servers
blink1073 May 10, 2023
995f579
Update to reflect `kernel_ws_protocol` being moved into BaseKernelWeb…
ojarjur May 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/api/jupyter_server.gateway.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Submodules
----------


.. automodule:: jupyter_server.gateway.connections
:members:
:undoc-members:
:show-inheritance:


.. automodule:: jupyter_server.gateway.gateway_client
:members:
:undoc-members:
Expand Down
175 changes: 175 additions & 0 deletions jupyter_server/gateway/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Gateway connection classes."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import logging
import random
from typing import Any, cast

import tornado.websocket as tornado_websocket
from tornado.concurrent import Future
from tornado.escape import json_decode, url_escape, utf8
from tornado.httpclient import HTTPRequest
from tornado.ioloop import IOLoop
from traitlets import Bool, Instance, Int

from ..services.kernels.connection.base import BaseKernelWebsocketConnection
from ..utils import url_path_join
from .managers import GatewayClient


class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
    """Web socket connection that proxies to a kernel/enterprise gateway.

    Messages read from the gateway websocket are forwarded to the notebook
    client through ``self.websocket_handler``; messages received from the
    notebook client are written to the gateway websocket. When the gateway
    side of the connection is lost without a client-initiated disconnect, the
    connection is re-established with exponential backoff plus jitter.
    """

    # Established client connection to the gateway; ``None`` until a
    # connection attempt has completed successfully.
    ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)

    # Future tracking the in-flight connection attempt. It stays ``None`` until
    # ``connect`` is called: creating a ``Future`` here as a static default
    # would build it at class-definition time and share that single object
    # between every instance of this class.
    ws_future = Instance(klass=Future, allow_none=True)

    # Set once ``disconnect`` has been called; stops reads, writes and retries.
    disconnected = Bool(False)

    # Number of reconnection attempts made since the last successful connect.
    retry = Int(0)

    async def connect(self):
        """Start connecting to the gateway's kernel channels websocket.

        The connection attempt is stored in ``self.ws_future``; once it
        resolves, ``_connection_done`` records the outcome and
        ``_read_messages`` starts relaying messages.
        """
        # websocket is initialized before connection
        self.ws = None
        client = GatewayClient.instance()
        ws_url = url_path_join(
            client.ws_url,
            client.kernels_endpoint,
            url_escape(self.kernel_id),
            "channels",
        )
        self.log.info(f"Connecting to {ws_url}")
        kwargs: dict = client.load_connection_args()

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = cast(Future, tornado_websocket.websocket_connect(request))
        self.ws_future.add_done_callback(self._connection_done)

        loop = IOLoop.current()
        loop.add_future(self.ws_future, lambda future: self._read_messages())

    def _connection_done(self, fut):
        """Record the outcome of a finished connection attempt.

        On success the websocket is stored in ``self.ws`` and the retry counter
        is reset; otherwise a warning is logged.
        """
        # ``disconnected`` is checked first so that ``fut.exception()`` is not
        # called on a future cancelled by ``disconnect`` (that would raise
        # CancelledError).
        if not self.disconnected and fut.exception() is None:
            self.ws = fut.result()
            self.retry = 0
            self.log.debug(f"Connection is ready: ws: {self.ws}")
        else:
            self.log.warning(
                "Websocket connection has been closed via client disconnect or due to error. "
                "Kernel with ID '%s' may not be terminated on GatewayClient: %s",
                self.kernel_id,
                GatewayClient.instance().url,
            )

    def disconnect(self):
        """Close the gateway connection, or cancel it if still pending."""
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif self.ws_future is not None and not self.ws_future.done():
            # Cancel the pending connection. Cancellation is also tracked
            # locally through ``disconnected`` so that callbacks on the future
            # do not treat the attempt as successful.
            self.ws_future.cancel()
            self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

    async def _read_messages(self):
        """Relay messages from the gateway until the connection ends.

        When the loop ends without a client-initiated disconnect and the retry
        budget is not exhausted, a reconnection is scheduled after a backoff
        delay.
        """
        while self.ws is not None:
            if self.disconnected:
                # ws cancelled - stop reading
                break

            message = None
            try:
                message = await self.ws.read_message()
            except Exception as e:
                # A failed read leaves ``message`` as None and is therefore
                # handled as a lost connection below.
                self.log.error(f"Exception reading message from websocket: {e}")

            if message is None:
                if not self.disconnected:
                    self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
                break

            # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
            self.handle_outgoing_message(message)

        # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
        client = GatewayClient.instance()
        if not self.disconnected and self.retry < client.gateway_retry_max:
            jitter = random.randint(10, 100) * 0.01  # noqa
            # Exponential backoff capped at the configured maximum, plus jitter.
            backoff = min(
                client.gateway_retry_interval * (2**self.retry),
                client.gateway_retry_interval_max,
            )
            retry_interval = backoff + jitter
            self.retry += 1
            self.log.info(
                "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
                retry_interval,
                self.retry,
                client.gateway_retry_max,
                self.kernel_id,
            )
            await asyncio.sleep(retry_interval)
            loop = IOLoop.current()
            loop.spawn_callback(self.connect)

    def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
        """Send a message received from the gateway to the notebook client.

        If the notebook client has already closed its websocket the message is
        dropped; a summary of it is logged at debug level.
        """
        try:
            self.websocket_handler.write_message(incoming_msg)
        except tornado_websocket.WebSocketClosedError:
            # Only decode the message when the summary will actually be logged.
            if self.log.isEnabledFor(logging.DEBUG):
                msg_summary = GatewayWebSocketConnection._get_message_summary(
                    json_decode(utf8(incoming_msg))
                )
                self.log.debug(
                    "Notebook client closed websocket connection - message dropped: %s",
                    msg_summary,
                )

    def handle_incoming_message(self, message: str) -> None:
        """Send a message from the notebook client to the gateway server.

        While a connection attempt is still pending, the send is deferred
        until that attempt's future resolves.
        """
        if self.ws is None and self.ws_future is not None:
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
        else:
            # Either connected, or ``connect`` has not been called yet; in the
            # latter case ``_write_message`` does nothing because ``ws`` is None.
            self._write_message(message)

    def _write_message(self, message):
        """Write a message to the gateway websocket, logging any failure."""
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error(f"Exception writing message to websocket: {e}")

    @staticmethod
    def _get_message_summary(message):
        """Return a short, loggable summary of a decoded kernel message.

        Only ``status`` and ``error`` messages have details included; the
        content of every other message type is elided.
        """
        message_type = message["msg_type"]

        if message_type == "status":
            detail = ", state: {}".format(message["content"]["execution_state"])
        elif message_type == "error":
            content = message["content"]
            detail = ", {}:{}:{}".format(
                content["ename"],
                content["evalue"],
                content["traceback"],
            )
        else:
            detail = ", ..."  # don't display potentially sensitive data

        return f"type: {message_type}{detail}"
8 changes: 8 additions & 0 deletions jupyter_server/gateway/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import mimetypes
import os
import random
import warnings
from typing import Optional, cast

from jupyter_client.session import Session
Expand All @@ -21,6 +22,13 @@
from ..utils import url_path_join
from .managers import GatewayClient

# Runs at module top level, so the deprecation warning is raised when this
# module is imported rather than when a particular handler is used.
warnings.warn(
"The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
DeprecationWarning,
stacklevel=2,
)


# Keepalive ping interval (default: 30 seconds)
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))

Expand Down
22 changes: 22 additions & 0 deletions jupyter_server/kernelspecs/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Kernelspecs API Handlers."""
import mimetypes

from jupyter_core.utils import ensure_async
from tornado import web

Expand Down Expand Up @@ -27,6 +29,26 @@ async def get(self, kernel_name, path, include_body=True):
ksm = self.kernel_spec_manager
if path.lower().endswith(".png"):
self.set_header("Cache-Control", f"max-age={60*60*24*30}")
ksm = self.kernel_spec_manager
if hasattr(ksm, "get_kernel_spec_resource"):
# If the kernel spec manager defines a method to get kernelspec resources,
# then use that instead of trying to read from disk.
kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
if kernel_spec_res is not None:
# We have to explicitly specify the `absolute_path` attribute so that
# the underlying StaticFileHandler methods can calculate an etag.
self.absolute_path = path
mimetype: str = mimetypes.guess_type(path)[0] or "text/plain"
self.set_header("Content-Type", mimetype)
self.finish(kernel_spec_res)
return
else:
self.log.warning(
"Kernelspec resource '{}' for '{}' not found. Kernel spec manager may"
" not support resource serving. Falling back to reading from disk".format(
path, kernel_name
)
)
try:
kspec = await ensure_async(ksm.get_kernel_spec(kernel_name))
self.root = kspec.resource_dir
Expand Down
34 changes: 21 additions & 13 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from jupyter_server.extension.config import ExtensionConfigManager
from jupyter_server.extension.manager import ExtensionManager
from jupyter_server.extension.serverextension import ServerExtensionApp
from jupyter_server.gateway.connections import GatewayWebSocketConnection
from jupyter_server.gateway.managers import (
GatewayClient,
GatewayKernelSpecManager,
Expand Down Expand Up @@ -433,17 +434,6 @@ def init_handlers(self, default_services, settings):
# And from identity provider
handlers.extend(settings["identity_provider"].get_handlers())

# If gateway mode is enabled, replace appropriate handlers to perform redirection
if GatewayClient.instance().gateway_enabled:
# for each handler required for gateway, locate its pattern
# in the current list and replace that entry...
gateway_handlers = load_handlers("jupyter_server.gateway.handlers")
for _, gwh in enumerate(gateway_handlers):
for j, h in enumerate(handlers):
if gwh[0] == h[0]:
handlers[j] = (gwh[0], gwh[1])
break

# register base handlers last
handlers.extend(load_handlers("jupyter_server.base.handlers"))

Expand Down Expand Up @@ -796,6 +786,7 @@ class ServerApp(JupyterApp):
GatewayMappingKernelManager,
GatewayKernelSpecManager,
GatewaySessionManager,
GatewayWebSocketConnection,
GatewayClient,
Authorizer,
EventLogger,
Expand Down Expand Up @@ -1505,12 +1496,17 @@ def _default_session_manager_class(self):
return SessionManager

kernel_websocket_connection_class = Type(
default_value=ZMQChannelsWebsocketConnection,
klass=BaseKernelWebsocketConnection,
config=True,
help=_i18n("The kernel websocket connection class to use."),
)

@default("kernel_websocket_connection_class")
def _default_kernel_websocket_connection_class(self):
    """Choose the default kernel websocket connection class.

    The gateway connection class is selected (by dotted import path) when the
    gateway is enabled; otherwise the ZMQ channels connection class is used.
    """
    gateway_on = self.gateway_config.gateway_enabled
    return (
        "jupyter_server.gateway.connections.GatewayWebSocketConnection"
        if gateway_on
        else ZMQChannelsWebsocketConnection
    )

config_manager_class = Type(
default_value=ConfigManager,
config=True,
Expand Down Expand Up @@ -2876,7 +2872,19 @@ async def _cleanup(self):
self.remove_browser_open_files()
await self.cleanup_extensions()
await self.cleanup_kernels()
await self.kernel_websocket_connection_class.close_all()
try:
await self.kernel_websocket_connection_class.close_all()
except AttributeError:
# This can happen in two different scenarios:
#
# 1. During tests, where the _cleanup method is invoked without
# the corresponding initialize method having been invoked.
# 2. If the provided `kernel_websocket_connection_class` does not
# implement the `close_all` class method.
#
# In either case, we don't need to do anything and just want to treat
# the raised error as a no-op.
pass
if getattr(self, "kernel_manager", None):
self.kernel_manager.__del__()
if getattr(self, "session_manager", None):
Expand Down
60 changes: 60 additions & 0 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
import pytest
import tornado
from jupyter_core.utils import ensure_async
from tornado.concurrent import Future
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.httputil import HTTPServerRequest
from tornado.queues import Queue
from tornado.web import HTTPError
from traitlets import Int, Unicode
from traitlets.config import Config

from jupyter_server.gateway.connections import GatewayWebSocketConnection
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler

from .utils import expected_http_error

Expand Down Expand Up @@ -659,6 +664,61 @@ async def test_channel_queue_get_msg_when_response_router_had_finished():
await queue.get_msg()


class MockWebSocketClientConnection(tornado.websocket.WebSocketClientConnection):
    """Queue-backed stand-in for a gateway websocket client connection.

    The parent initializer is not invoked; the only state is a bounded queue
    that is pre-loaded with a single ``status`` message.
    """

    def __init__(self, *args, **kwargs):
        """Create the message queue and seed it with a ``starting`` status message."""
        pending: Queue = Queue(2)
        pending.put_nowait('{"msg_type": "status", "content": {"execution_state": "starting"}}')
        self._msgs = pending

    def write_message(self, message, *args, **kwargs):
        """Enqueue ``message`` and return the awaitable produced by the queue."""
        queue = self._msgs
        return queue.put(message)

    def read_message(self, *args, **kwargs):
        """Return an awaitable for the next queued message."""
        queue = self._msgs
        return queue.get()


def mock_websocket_connect():
    """Build a replacement for ``tornado.websocket.websocket_connect``.

    The returned callable ignores its request argument and hands back a future
    that is already resolved with a ``MockWebSocketClientConnection``.
    """

    def helper(request):
        resolved: Future = Future()
        resolved.set_result(MockWebSocketClientConnection())
        return resolved

    return helper


@patch("tornado.websocket.websocket_connect", mock_websocket_connect())
async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch, caplog):
    """Check that relaying to an already-closed client websocket logs no errors."""
    # Start a kernel and look up its manager.
    kernel_id = await create_kernel(jp_fetch, "kspec_foo")
    km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)

    # Build the KernelWebsocketHandler around a minimal request.
    request = HTTPServerRequest("foo", "GET")
    request.connection = MagicMock()
    handler = KernelWebsocketHandler(jp_serverapp.web_app, request)

    # Make the handler's client-side websocket report that it is closing, so
    # that writing a message to the client raises a closed error.
    closing_ws = MagicMock()
    closing_ws.is_closing = lambda: True
    handler.ws_connection = closing_ws

    # Wire a GatewayWebSocketConnection to the handler and connect it.
    conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler)
    handler.connection = conn
    await conn.connect()

    # Websocket messages are processed in separate coroutines, so failures
    # there surface only in the logs and never propagate to this caller.
    # Wait for the server to shut down, then inspect what was logged.
    await jp_serverapp._cleanup()
    logged_errors = [
        message for _, level, message in caplog.record_tuples if level >= logging.ERROR
    ]
    if logged_errors:
        pytest.fail(f"Logs contain an error: {logged_errors[0]}")


#
# Test methods below...
#
Expand Down
Loading