Skip to content

Commit

Permalink
Use pending kernels (#593)
Browse files Browse the repository at this point in the history
* wip use pending kernels

yup

wip

wip

handle getattr

use ensure_future

add more backwards compat

clean up tests

more cleanup

Fix ping handling for shut down kernels

fix handling of kernel startup

lint

* fix handling of kernel activity

* clean up pre-commit

* make tests less brittle

* make tests less brittle

Co-authored-by: Steven Silvester <ssilvester@apple.com>
  • Loading branch information
blink1073 and Steven Silvester authored Nov 23, 2021
1 parent c5c515b commit 758dba6
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 37 deletions.
4 changes: 4 additions & 0 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def send_ping(self):
self.ping_callback.stop()
return

if self.ws_connection.client_terminated:
self.close()
return

# check for timeout on pong. Make sure that we really have sent a recent ping in
# case the machine with both server and client has been suspended since the last ping.
now = ioloop.IOLoop.current().time()
Expand Down
27 changes: 21 additions & 6 deletions jupyter_server/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,16 @@ def client_fetch(*parts, headers=None, params=None, **kwargs):
@pytest.fixture
def jp_kernelspecs(jp_data_dir):
"""Configures some sample kernelspecs in the Jupyter data directory."""
spec_names = ["sample", "sample 2"]
spec_names = ["sample", "sample 2", "bad"]
for name in spec_names:
sample_kernel_dir = jp_data_dir.joinpath("kernels", name)
sample_kernel_dir.mkdir(parents=True)
# Create kernel json file
sample_kernel_file = sample_kernel_dir.joinpath("kernel.json")
sample_kernel_file.write_text(json.dumps(sample_kernel_json))
kernel_json = sample_kernel_json.copy()
if name == "bad":
kernel_json["argv"] = ["non_existent_path"]
sample_kernel_file.write_text(json.dumps(kernel_json))
# Create resources text
sample_kernel_resources = sample_kernel_dir.joinpath("resource.txt")
sample_kernel_resources.write_text(some_resource)
Expand Down Expand Up @@ -474,12 +477,24 @@ async def _():
terminal_cleanup = jp_serverapp.web_app.settings["terminal_manager"].terminate_all
kernel_cleanup = jp_serverapp.kernel_manager.shutdown_all
if asyncio.iscoroutinefunction(terminal_cleanup):
await terminal_cleanup()
try:
await terminal_cleanup()
except Exception as e:
print(e)
else:
terminal_cleanup()
try:
await terminal_cleanup()
except Exception as e:
print(e)
if asyncio.iscoroutinefunction(kernel_cleanup):
await kernel_cleanup()
try:
await kernel_cleanup()
except Exception as e:
print(e)
else:
kernel_cleanup()
try:
kernel_cleanup()
except Exception as e:
print(e)

return _
16 changes: 15 additions & 1 deletion jupyter_server/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Distributed under the terms of the Modified BSD License.
import json
from textwrap import dedent
from traceback import format_tb

from ipython_genutils.py3compat import cast_unicode
from jupyter_client import protocol_version as client_protocol_version
Expand Down Expand Up @@ -78,7 +79,10 @@ async def post(self, kernel_id, action):
try:
await km.restart_kernel(kernel_id)
except Exception as e:
self.log.error("Exception restarting kernel", exc_info=True)
message = "Exception restarting kernel"
self.log.error(message, exc_info=True)
traceback = format_tb(e.__traceback__)
self.write(json.dumps(dict(message=message, traceback=traceback)))
self.set_status(500)
else:
model = await ensure_async(km.kernel_model(kernel_id))
Expand Down Expand Up @@ -325,6 +329,15 @@ async def pre_get(self):
# We don't want to wait forever, because browsers don't take it well when
# servers never respond to websocket connection requests.
kernel = self.kernel_manager.get_kernel(self.kernel_id)

if hasattr(kernel, "ready"):
try:
await kernel.ready
except Exception as e:
kernel.execution_state = "dead"
kernel.reason = str(e)
raise web.HTTPError(500, str(e)) from e

self.session.key = kernel.session.key
future = self.request_kernel_info()

Expand Down Expand Up @@ -445,6 +458,7 @@ def on_message(self, msg):
def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)

parent = msg["parent_header"]

def write_stderr(error_message):
Expand Down
49 changes: 40 additions & 9 deletions jupyter_server/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
from collections import defaultdict
from datetime import datetime
Expand Down Expand Up @@ -209,16 +210,14 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
kwargs["kernel_id"] = kernel_id
kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
self._kernel_connections[kernel_id] = 0
self._kernel_ports[kernel_id] = self._kernels[kernel_id].ports
self.start_watching_activity(kernel_id)
asyncio.ensure_future(self._finish_kernel_start(kernel_id))
# add busy/activity markers:
kernel = self.get_kernel(kernel_id)
kernel.execution_state = "starting"
kernel.reason = ""
kernel.last_activity = utcnow()
self.log.info("Kernel started: %s" % kernel_id)
self.log.debug("Kernel args: %r" % kwargs)
# register callback for failed auto-restart
self.add_restart_callback(
kernel_id,
lambda: self._handle_kernel_died(kernel_id),
"dead",
)

# Increase the metric of number of kernels running
# for the relevant kernel type by 1
Expand All @@ -233,6 +232,24 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):

return kernel_id

async def _finish_kernel_start(self, kernel_id):
    """Finish the bookkeeping for a kernel once its startup has settled.

    If the kernel object exposes a ``ready`` awaitable, wait for it first.
    When waiting raises, the error is logged and nothing else is done: no
    ports are recorded, activity is not watched and no restart callback is
    registered.  Otherwise (or when there is no ``ready`` attribute) the
    kernel's ports are stored, activity watching is started and a callback
    for the ``"dead"`` restart event is registered.

    Parameters
    ----------
    kernel_id : str
        Identifier passed to ``get_kernel`` and used as the key for the
        per-kernel bookkeeping.
    """
    km = self.get_kernel(kernel_id)
    if hasattr(km, "ready"):
        try:
            await km.ready
        except Exception as e:
            # Log the exception that was actually caught.  Asking the future
            # again via ``km.ready.exception()`` can itself raise (e.g. for a
            # cancelled future on interpreters where ``CancelledError`` is an
            # ``Exception`` subclass), which would escape this handler.
            self.log.exception(e)
            return

    self._kernel_ports[kernel_id] = km.ports
    self.start_watching_activity(kernel_id)
    # register callback for failed auto-restart
    self.add_restart_callback(
        kernel_id,
        lambda: self._handle_kernel_died(kernel_id),
        "dead",
    )

def ports_changed(self, kernel_id):
"""Used by ZMQChannelsHandler to determine how to coordinate nudge and replays.
Expand Down Expand Up @@ -448,6 +465,8 @@ def kernel_model(self, kernel_id):
"execution_state": kernel.execution_state,
"connections": self._kernel_connections.get(kernel_id, 0),
}
if getattr(kernel, "reason", None):
model["reason"] = kernel.reason
return model

def list_kernels(self):
Expand Down Expand Up @@ -479,6 +498,7 @@ def start_watching_activity(self, kernel_id):
kernel = self._kernels[kernel_id]
# add busy/activity markers:
kernel.execution_state = "starting"
kernel.reason = ""
kernel.last_activity = utcnow()
kernel._activity_stream = kernel.connect_iopub()
session = Session(
Expand Down Expand Up @@ -507,7 +527,7 @@ def record_activity(msg_list):
def stop_watching_activity(self, kernel_id):
"""Stop watching IOPub messages on a kernel for activity."""
kernel = self._kernels[kernel_id]
if kernel._activity_stream:
if getattr(kernel, "_activity_stream", None):
kernel._activity_stream.close()
kernel._activity_stream = None

Expand Down Expand Up @@ -561,6 +581,17 @@ async def cull_kernels(self):

async def cull_kernel_if_idle(self, kernel_id):
kernel = self._kernels[kernel_id]

if getattr(kernel, "execution_state") == "dead":
self.log.warning(
"Culling '%s' dead kernel '%s' (%s).",
kernel.execution_state,
kernel.kernel_name,
kernel_id,
)
await ensure_async(self.shutdown_kernel(kernel_id))
return

if hasattr(
kernel, "last_activity"
): # last_activity is monkey-patched, so ensure that has occurred
Expand Down
8 changes: 7 additions & 1 deletion jupyter_server/services/sessions/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json

try:
Expand Down Expand Up @@ -78,6 +79,8 @@ async def post(self):
self.set_status(501)
self.finish(json.dumps(dict(message=msg, short_message=status_msg)))
return
except Exception as e:
raise web.HTTPError(500, str(e)) from e

location = url_path_join(self.base_url, "api", "sessions", model["id"])
self.set_header("Location", location)
Expand Down Expand Up @@ -144,7 +147,10 @@ async def patch(self, session_id):
if model["kernel"]["id"] != before["kernel"]["id"]:
# kernel_id changed because we got a new kernel
# shutdown the old one
await ensure_async(km.shutdown_kernel(before["kernel"]["id"]))
fut = asyncio.ensure_future(ensure_async(km.shutdown_kernel(before["kernel"]["id"])))
# If we are not using pending kernels, wait for the kernel to shut down
if not getattr(km, "use_pending_kernels", None):
await fut
self.finish(json.dumps(model, default=json_default))

@web.authenticated
Expand Down
50 changes: 48 additions & 2 deletions jupyter_server/tests/services/kernels/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
import pytest
import tornado
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
from tornado.httpclient import HTTPClientError

from ...utils import expected_http_error
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_server.utils import url_path_join


@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager"])
class TestMappingKernelManager(AsyncMappingKernelManager):
"""A no-op subclass to use in a fixture"""


@pytest.fixture(
params=["MappingKernelManager", "AsyncMappingKernelManager", "TestMappingKernelManager"]
)
def jp_argv(request):
if request.param == "TestMappingKernelManager":
extra = []
if hasattr(AsyncMappingKernelManager, "use_pending_kernels"):
extra = ["--AsyncMappingKernelManager.use_pending_kernels=True"]
return [
"--ServerApp.kernel_manager_class=jupyter_server.tests.services.kernels.test_api."
+ request.param
] + extra
return [
"--ServerApp.kernel_manager_class=jupyter_server.services.kernels.kernelmanager."
+ request.param
Expand Down Expand Up @@ -38,7 +54,7 @@ async def test_default_kernels(jp_fetch, jp_base_url, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesses):
async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesses, jp_serverapp):
# Start the first kernel
r = await jp_fetch(
"api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME})
Expand Down Expand Up @@ -83,6 +99,10 @@ async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesse
assert r.code == 204

# Restart a kernel
kernel = jp_serverapp.kernel_manager.get_kernel(kernel2["id"])
if hasattr(kernel, "ready"):
await kernel.ready

r = await jp_fetch(
"api", "kernels", kernel2["id"], "restart", method="POST", allow_nonstandard_methods=True
)
Expand Down Expand Up @@ -143,6 +163,32 @@ async def test_kernel_handler(jp_fetch, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_kernel_handler_startup_error(
    jp_fetch, jp_cleanup_subprocesses, jp_serverapp, jp_kernelspecs
):
    """Creating a kernel from the "bad" spec fails with an HTTP error.

    Only meaningful when pending kernels are disabled; with pending kernels
    the behaviour is covered by the ``_pending`` variant of this test.
    """
    if getattr(jp_serverapp.kernel_manager, "use_pending_kernels", False):
        # Report the test as skipped instead of silently passing.
        pytest.skip("not applicable when pending kernels are enabled")

    # Create a kernel
    with pytest.raises(HTTPClientError):
        await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": "bad"}))


async def test_kernel_handler_startup_error_pending(
    jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses, jp_serverapp, jp_kernelspecs
):
    """With pending kernels, the "bad" spec fails at websocket connect time.

    The POST that creates the kernel succeeds and returns an id; the error
    only shows up when a client tries to open the channels websocket.
    """
    if not getattr(jp_serverapp.kernel_manager, "use_pending_kernels", False):
        # Report the test as skipped instead of silently passing.
        pytest.skip("requires a kernel manager with pending kernels enabled")

    jp_serverapp.kernel_manager.use_pending_kernels = True
    # Create a kernel
    r = await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": "bad"}))
    kid = json.loads(r.body.decode())["id"]

    with pytest.raises(HTTPClientError):
        await jp_ws_fetch("api", "kernels", kid, "channels")


async def test_connection(
jp_fetch, jp_ws_fetch, jp_http_port, jp_auth_header, jp_cleanup_subprocesses
):
Expand Down
26 changes: 25 additions & 1 deletion jupyter_server/tests/services/kernels/test_cull.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def jp_server_config():
)


async def test_culling(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
async def test_cull_idle(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
kernel = json.loads(r.body.decode())
kid = kernel["id"]
Expand All @@ -53,6 +53,30 @@ async def test_culling(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_cull_dead(
    jp_fetch, jp_ws_fetch, jp_serverapp, jp_cleanup_subprocesses, jp_kernelspecs
):
    """A kernel created from the "bad" spec ends up culled.

    The kernel is created through the REST API with pending kernels enabled,
    the websocket connection to it is expected to fail, the kernel model must
    report zero connections, and the cull poll must report it as culled.
    """
    if not hasattr(jp_serverapp.kernel_manager, "use_pending_kernels"):
        # Report the test as skipped instead of silently passing.
        pytest.skip("kernel manager does not support pending kernels")

    jp_serverapp.kernel_manager.use_pending_kernels = True
    jp_serverapp.kernel_manager.default_kernel_name = "bad"
    r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
    kernel = json.loads(r.body.decode())
    kid = kernel["id"]

    # Opening a websocket connection is expected to fail.
    with pytest.raises(HTTPClientError):
        await jp_ws_fetch("api", "kernels", kid, "channels")

    r = await jp_fetch("api", "kernels", kid, method="GET")
    model = json.loads(r.body.decode())
    assert model["connections"] == 0
    culled = await get_cull_status(kid, jp_fetch)  # no connections, should be culled
    assert culled
    await jp_cleanup_subprocesses()


async def get_cull_status(kid, jp_fetch):
frequency = 0.5
culled = False
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/tests/services/kernelspecs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


async def test_list_kernelspecs_bad(jp_fetch, jp_kernelspecs, jp_data_dir):
bad_kernel_dir = jp_data_dir.joinpath(jp_data_dir, "kernels", "bad")
bad_kernel_dir = jp_data_dir.joinpath(jp_data_dir, "kernels", "bad2")
bad_kernel_dir.mkdir(parents=True)
bad_kernel_json = bad_kernel_dir.joinpath("kernel.json")
bad_kernel_json.write_text("garbage")
Expand Down
Loading

0 comments on commit 758dba6

Please sign in to comment.