Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pending kernels #593

Merged
merged 5 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def send_ping(self):
self.ping_callback.stop()
return

if self.ws_connection.client_terminated:
self.close()
return

# check for timeout on pong. Make sure that we really have sent a recent ping in
# case the machine with both server and client has been suspended since the last ping.
now = ioloop.IOLoop.current().time()
Expand Down
27 changes: 21 additions & 6 deletions jupyter_server/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,16 @@ def client_fetch(*parts, headers=None, params=None, **kwargs):
@pytest.fixture
def jp_kernelspecs(jp_data_dir):
"""Configures some sample kernelspecs in the Jupyter data directory."""
spec_names = ["sample", "sample 2"]
spec_names = ["sample", "sample 2", "bad"]
for name in spec_names:
sample_kernel_dir = jp_data_dir.joinpath("kernels", name)
sample_kernel_dir.mkdir(parents=True)
# Create kernel json file
sample_kernel_file = sample_kernel_dir.joinpath("kernel.json")
sample_kernel_file.write_text(json.dumps(sample_kernel_json))
kernel_json = sample_kernel_json.copy()
if name == "bad":
kernel_json["argv"] = ["non_existent_path"]
sample_kernel_file.write_text(json.dumps(kernel_json))
# Create resources text
sample_kernel_resources = sample_kernel_dir.joinpath("resource.txt")
sample_kernel_resources.write_text(some_resource)
Expand Down Expand Up @@ -474,12 +477,24 @@ async def _():
terminal_cleanup = jp_serverapp.web_app.settings["terminal_manager"].terminate_all
kernel_cleanup = jp_serverapp.kernel_manager.shutdown_all
if asyncio.iscoroutinefunction(terminal_cleanup):
await terminal_cleanup()
try:
await terminal_cleanup()
except Exception as e:
print(e)
else:
terminal_cleanup()
try:
await terminal_cleanup()
except Exception as e:
print(e)
if asyncio.iscoroutinefunction(kernel_cleanup):
await kernel_cleanup()
try:
await kernel_cleanup()
except Exception as e:
print(e)
else:
kernel_cleanup()
try:
kernel_cleanup()
except Exception as e:
print(e)

return _
16 changes: 15 additions & 1 deletion jupyter_server/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Distributed under the terms of the Modified BSD License.
import json
from textwrap import dedent
from traceback import format_tb

from ipython_genutils.py3compat import cast_unicode
from jupyter_client import protocol_version as client_protocol_version
Expand Down Expand Up @@ -78,7 +79,10 @@ async def post(self, kernel_id, action):
try:
await km.restart_kernel(kernel_id)
except Exception as e:
self.log.error("Exception restarting kernel", exc_info=True)
message = "Exception restarting kernel"
self.log.error(message, exc_info=True)
traceback = format_tb(e.__traceback__)
self.write(json.dumps(dict(message=message, traceback=traceback)))
self.set_status(500)
else:
model = await ensure_async(km.kernel_model(kernel_id))
Expand Down Expand Up @@ -325,6 +329,15 @@ async def pre_get(self):
# We don't want to wait forever, because browsers don't take it well when
# servers never respond to websocket connection requests.
kernel = self.kernel_manager.get_kernel(self.kernel_id)

if hasattr(kernel, "ready"):
try:
await kernel.ready
except Exception as e:
kernel.execution_state = "dead"
kernel.reason = str(e)
raise web.HTTPError(500, str(e)) from e

self.session.key = kernel.session.key
future = self.request_kernel_info()

Expand Down Expand Up @@ -445,6 +458,7 @@ def on_message(self, msg):
def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)

parent = msg["parent_header"]

def write_stderr(error_message):
Expand Down
49 changes: 40 additions & 9 deletions jupyter_server/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
from collections import defaultdict
from datetime import datetime
Expand Down Expand Up @@ -209,16 +210,14 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
kwargs["kernel_id"] = kernel_id
kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
self._kernel_connections[kernel_id] = 0
self._kernel_ports[kernel_id] = self._kernels[kernel_id].ports
self.start_watching_activity(kernel_id)
asyncio.ensure_future(self._finish_kernel_start(kernel_id))
# add busy/activity markers:
kernel = self.get_kernel(kernel_id)
kernel.execution_state = "starting"
kernel.reason = ""
kernel.last_activity = utcnow()
self.log.info("Kernel started: %s" % kernel_id)
self.log.debug("Kernel args: %r" % kwargs)
# register callback for failed auto-restart
self.add_restart_callback(
kernel_id,
lambda: self._handle_kernel_died(kernel_id),
"dead",
)

# Increase the metric of number of kernels running
# for the relevant kernel type by 1
Expand All @@ -233,6 +232,24 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):

return kernel_id

async def _finish_kernel_start(self, kernel_id):
    """Complete startup bookkeeping for a newly started (possibly pending) kernel.

    Scheduled as a background task from ``start_kernel`` so the HTTP request
    can return before the kernel is fully up.  If the kernel manager exposes
    a ``ready`` future (pending-kernel support), wait for it; on startup
    failure, log the exception and skip the port/activity registration.

    Parameters
    ----------
    kernel_id : str
        Identifier of the kernel whose startup should be finalized.
    """
    km = self.get_kernel(kernel_id)
    # Kernel managers with pending-kernel support expose a ``ready`` future;
    # older jupyter_client versions do not, so guard with hasattr.
    if hasattr(km, "ready"):
        try:
            await km.ready
        except Exception:
            # Startup failed: record the failure and bail out without
            # registering ports, activity watchers, or restart callbacks.
            self.log.exception(km.ready.exception())
            return

    self._kernel_ports[kernel_id] = km.ports
    self.start_watching_activity(kernel_id)
    # register callback for failed auto-restart
    self.add_restart_callback(
        kernel_id,
        lambda: self._handle_kernel_died(kernel_id),
        "dead",
    )

def ports_changed(self, kernel_id):
"""Used by ZMQChannelsHandler to determine how to coordinate nudge and replays.

Expand Down Expand Up @@ -448,6 +465,8 @@ def kernel_model(self, kernel_id):
"execution_state": kernel.execution_state,
"connections": self._kernel_connections.get(kernel_id, 0),
}
if getattr(kernel, "reason", None):
model["reason"] = kernel.reason
return model

def list_kernels(self):
Expand Down Expand Up @@ -479,6 +498,7 @@ def start_watching_activity(self, kernel_id):
kernel = self._kernels[kernel_id]
# add busy/activity markers:
kernel.execution_state = "starting"
kernel.reason = ""
kernel.last_activity = utcnow()
kernel._activity_stream = kernel.connect_iopub()
session = Session(
Expand Down Expand Up @@ -507,7 +527,7 @@ def record_activity(msg_list):
def stop_watching_activity(self, kernel_id):
"""Stop watching IOPub messages on a kernel for activity."""
kernel = self._kernels[kernel_id]
if kernel._activity_stream:
if getattr(kernel, "_activity_stream", None):
kernel._activity_stream.close()
kernel._activity_stream = None

Expand Down Expand Up @@ -561,6 +581,17 @@ async def cull_kernels(self):

async def cull_kernel_if_idle(self, kernel_id):
kernel = self._kernels[kernel_id]

if getattr(kernel, "execution_state") == "dead":
self.log.warning(
"Culling '%s' dead kernel '%s' (%s).",
kernel.execution_state,
kernel.kernel_name,
kernel_id,
)
await ensure_async(self.shutdown_kernel(kernel_id))
return

if hasattr(
kernel, "last_activity"
): # last_activity is monkey-patched, so ensure that has occurred
Expand Down
8 changes: 7 additions & 1 deletion jupyter_server/services/sessions/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json

try:
Expand Down Expand Up @@ -78,6 +79,8 @@ async def post(self):
self.set_status(501)
self.finish(json.dumps(dict(message=msg, short_message=status_msg)))
return
except Exception as e:
raise web.HTTPError(500, str(e)) from e

location = url_path_join(self.base_url, "api", "sessions", model["id"])
self.set_header("Location", location)
Expand Down Expand Up @@ -144,7 +147,10 @@ async def patch(self, session_id):
if model["kernel"]["id"] != before["kernel"]["id"]:
# kernel_id changed because we got a new kernel
# shutdown the old one
await ensure_async(km.shutdown_kernel(before["kernel"]["id"]))
fut = asyncio.ensure_future(ensure_async(km.shutdown_kernel(before["kernel"]["id"])))
# If we are not using pending kernels, wait for the kernel to shut down
if not getattr(km, "use_pending_kernels", None):
await fut
self.finish(json.dumps(model, default=json_default))

@web.authenticated
Expand Down
50 changes: 48 additions & 2 deletions jupyter_server/tests/services/kernels/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
import pytest
import tornado
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
from tornado.httpclient import HTTPClientError

from ...utils import expected_http_error
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_server.utils import url_path_join


@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager"])
class TestMappingKernelManager(AsyncMappingKernelManager):
    """A no-op AsyncMappingKernelManager subclass used by the ``jp_argv``
    fixture to exercise a custom kernel manager class (optionally with
    pending kernels enabled)."""


@pytest.fixture(
params=["MappingKernelManager", "AsyncMappingKernelManager", "TestMappingKernelManager"]
)
def jp_argv(request):
if request.param == "TestMappingKernelManager":
extra = []
if hasattr(AsyncMappingKernelManager, "use_pending_kernels"):
extra = ["--AsyncMappingKernelManager.use_pending_kernels=True"]
return [
"--ServerApp.kernel_manager_class=jupyter_server.tests.services.kernels.test_api."
+ request.param
] + extra
return [
"--ServerApp.kernel_manager_class=jupyter_server.services.kernels.kernelmanager."
+ request.param
Expand Down Expand Up @@ -38,7 +54,7 @@ async def test_default_kernels(jp_fetch, jp_base_url, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesses):
async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesses, jp_serverapp):
# Start the first kernel
r = await jp_fetch(
"api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME})
Expand Down Expand Up @@ -83,6 +99,10 @@ async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_cleanup_subprocesse
assert r.code == 204

# Restart a kernel
kernel = jp_serverapp.kernel_manager.get_kernel(kernel2["id"])
if hasattr(kernel, "ready"):
await kernel.ready

r = await jp_fetch(
"api", "kernels", kernel2["id"], "restart", method="POST", allow_nonstandard_methods=True
)
Expand Down Expand Up @@ -143,6 +163,32 @@ async def test_kernel_handler(jp_fetch, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_kernel_handler_startup_error(
    jp_fetch, jp_cleanup_subprocesses, jp_serverapp, jp_kernelspecs
):
    """A kernelspec whose argv cannot be launched should fail the create POST.

    Only meaningful when pending kernels are disabled: the startup failure is
    then reported synchronously as an HTTP error on the POST itself.
    """
    manager = jp_serverapp.kernel_manager
    if getattr(manager, "use_pending_kernels", False):
        # The pending-kernel path is covered by the *_pending variant.
        return

    request_body = json.dumps({"name": "bad"})
    with pytest.raises(HTTPClientError):
        await jp_fetch("api", "kernels", method="POST", body=request_body)


async def test_kernel_handler_startup_error_pending(
    jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses, jp_serverapp, jp_kernelspecs
):
    """With pending kernels, a broken kernelspec still creates a kernel entry.

    The create POST succeeds while the kernel is pending; the startup failure
    only surfaces when a websocket connection to its channels is attempted.
    """
    manager = jp_serverapp.kernel_manager
    if not getattr(manager, "use_pending_kernels", False):
        # Nothing to verify on managers without pending-kernel support.
        return

    manager.use_pending_kernels = True  # redundant given the guard above; kept as-is
    response = await jp_fetch(
        "api", "kernels", method="POST", body=json.dumps({"name": "bad"})
    )
    kid = json.loads(response.body.decode())["id"]

    # The startup error is reported when the channels websocket is opened.
    with pytest.raises(HTTPClientError):
        await jp_ws_fetch("api", "kernels", kid, "channels")


async def test_connection(
jp_fetch, jp_ws_fetch, jp_http_port, jp_auth_header, jp_cleanup_subprocesses
):
Expand Down
26 changes: 25 additions & 1 deletion jupyter_server/tests/services/kernels/test_cull.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def jp_server_config():
)


async def test_culling(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
async def test_cull_idle(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
kernel = json.loads(r.body.decode())
kid = kernel["id"]
Expand All @@ -53,6 +53,30 @@ async def test_culling(jp_fetch, jp_ws_fetch, jp_cleanup_subprocesses):
await jp_cleanup_subprocesses()


async def test_cull_dead(
    jp_fetch, jp_ws_fetch, jp_serverapp, jp_cleanup_subprocesses, jp_kernelspecs
):
    """A kernel that dies during startup should be culled despite having no connections."""
    if not hasattr(jp_serverapp.kernel_manager, "use_pending_kernels"):
        # Pending kernels unsupported by this jupyter_client version; nothing to test.
        return

    jp_serverapp.kernel_manager.use_pending_kernels = True
    # The "bad" kernelspec (from jp_kernelspecs) has a non-launchable argv,
    # so the kernel will fail to start and end up dead.
    jp_serverapp.kernel_manager.default_kernel_name = "bad"
    r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
    kernel = json.loads(r.body.decode())
    kid = kernel["id"]

    # Open a websocket connection; it should fail because startup failed.
    with pytest.raises(HTTPClientError):
        await jp_ws_fetch("api", "kernels", kid, "channels")

    r = await jp_fetch("api", "kernels", kid, method="GET")
    model = json.loads(r.body.decode())
    assert model["connections"] == 0
    culled = await get_cull_status(kid, jp_fetch)  # dead kernel, should be culled
    assert culled
    await jp_cleanup_subprocesses()


async def get_cull_status(kid, jp_fetch):
frequency = 0.5
culled = False
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/tests/services/kernelspecs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


async def test_list_kernelspecs_bad(jp_fetch, jp_kernelspecs, jp_data_dir):
bad_kernel_dir = jp_data_dir.joinpath(jp_data_dir, "kernels", "bad")
bad_kernel_dir = jp_data_dir.joinpath(jp_data_dir, "kernels", "bad2")
bad_kernel_dir.mkdir(parents=True)
bad_kernel_json = bad_kernel_dir.joinpath("kernel.json")
bad_kernel_json.write_text("garbage")
Expand Down
Loading