From 3f15bc8c6f2e7899b63b9483df7610ed2b13c972 Mon Sep 17 00:00:00 2001 From: Sigmar Stefansson Date: Wed, 9 Oct 2024 10:32:44 +0000 Subject: [PATCH] original pseudo waiting kernel code --- jupyter_server/base/handlers.py | 38 +++++ jupyter_server/gateway/managers.py | 97 +++++++---- jupyter_server/serverapp.py | 17 ++ .../services/sessions/sessionmanager.py | 151 +++++++++++++----- pyproject.toml | 2 +- 5 files changed, 235 insertions(+), 70 deletions(-) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 770fff1866..c662c5513d 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -11,7 +11,10 @@ import mimetypes import os import re +import sys import types +import typing as ty +import logging import warnings from http.client import responses from logging import Logger @@ -61,6 +64,7 @@ # ----------------------------------------------------------------------------- _sys_info_cache = None +_my_globals = "myglobals" def json_sys_info(): @@ -79,9 +83,36 @@ def log() -> Logger: return app_log +def get_token_value(request: ty.Any, prev: str) -> str: + header = "Authorization" + if header not in request.headers: + logging.error(f'Header "{header}" is missing') + return prev + logging.debug(f'Getting value from header "{header}"') + auth_header_value: str = request.headers[header] + if len(auth_header_value) == 0: + logging.error(f'Header "{header}" is empty') + return prev + + try: + logging.info(f"Auth header value: {auth_header_value}") + # We expect the header value to be of the form "Bearer: XXX" + return auth_header_value.split(" ", maxsplit=1)[1] + except Exception as e: + logging.error(f"Could not read token from auth header: {str(e)}") + + return prev + + class AuthenticatedHandler(web.RequestHandler): """A RequestHandler with an authenticated user.""" + def prepare(self): + if _my_globals not in sys.modules: + sys.modules[_my_globals] = types.ModuleType(_my_globals) + prevtoken = sys.modules[_my_globals].token if hasattr(sys.modules[_my_globals], "token") else "" + sys.modules[_my_globals].token = get_token_value(self.request, prevtoken) + @property def base_url(self) -> str: return cast(str, self.settings.get("base_url", "/")) @@ -1175,6 +1206,13 @@ def get(self) -> None: self.set_header("Content-Type", prometheus_client.CONTENT_TYPE_LATEST) self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY)) +def get_current_token(): + """ + Get :class:`tornado.httputil.HTTPServerRequest` that is currently being processed. + """ + if _my_globals in sys.modules and hasattr(sys.modules[_my_globals], "token"): + return sys.modules[_my_globals].token + return "" class PublicStaticFileHandler(web.StaticFileHandler): """Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required.""" diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 0ac47f8f57..bb04076478 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from logging import Logger +_local_kernels: dict[str, ServerKernelManager] = {} class GatewayMappingKernelManager(AsyncMappingKernelManager): """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" @@ -78,20 +79,38 @@ async def start_kernel(self, *, kernel_id=None, path=None, **kwargs): The API path (unicode, '/' delimited) for the cwd. Will be transformed to an OS path relative to root_dir. """ - self.log.info(f"Request start kernel: kernel_id={kernel_id}, path='{path}'") - - if kernel_id is None and path is not None: - kwargs["cwd"] = self.cwd_for_path(path) - - km = self.kernel_manager_factory(parent=self, log=self.log) - await km.start_kernel(kernel_id=kernel_id, **kwargs) - kernel_id = km.kernel_id - self._kernels[kernel_id] = km - # Initialize culling if not already - if not self._initialized_culler: - self.initialize_culler() - - return kernel_id + kernel_name = kwargs.get("kernel_name") + if kernel_name == 'python3' or kernel_name.startswith('sc-'): + kwargs["kernel_name"] = "python3" + kwargs["local"] = True + env = kwargs["env"] + if kernel_name.startswith('sc-'): + app = kernel_name + startup_file = f"/tmp/{app}.py" + startup_content = f"""from pyspark.sql import SparkSession +spark = SparkSession.builder.appName('{kernel_name}').remote('sc://{app}-driver-svc.spark-apps.svc.cluster.local').getOrCreate() +""" + env["PYTHONSTARTUP"] = startup_file + with open(startup_file, "w") as f: + f.write(startup_content) + kernel_id = await super().start_kernel(kernel_id=kernel_id, path=path, **kwargs) + _local_kernels[kernel_id] = self._kernels[kernel_id] + return kernel_id + else: + self.log.info(f"Request start kernel: kernel_id={kernel_id}, path='{path}'") + + if kernel_id is None and path is not None: + kwargs["cwd"] = self.cwd_for_path(path) + + km = self.kernel_manager_factory(parent=self, log=self.log) + await km.start_kernel(kernel_id=kernel_id, **kwargs) + kernel_id = km.kernel_id + self._kernels[kernel_id] = km + # Initialize culling if not already + if not self._initialized_culler: + self.initialize_culler() + + return kernel_id async def kernel_model(self, kernel_id): """Return a dictionary of kernel information described in the @@ -102,11 +121,17 @@ async def kernel_model(self, kernel_id): kernel_id : uuid The uuid of the kernel. """ - model = None - km = self.get_kernel(str(kernel_id)) - if km: # type:ignore[truthy-bool] - model = km.kernel # type:ignore[attr-defined] - return model + if kernel_id in _local_kernels: + str_kernel_id = str(kernel_id) + model = super().kernel_model(kernel_id) + _local_kernels[str_kernel_id] = model + return model + else: + model = None + km = self.get_kernel(str(kernel_id)) + if km: # type:ignore[truthy-bool] + model = km.kernel # type:ignore[attr-defined] + return model async def list_kernels(self, **kwargs): """Get a list of running kernels from the Gateway server. @@ -437,19 +462,22 @@ async def refresh_model(self, model=None): model is fetched from the Gateway server. """ if model is None: - self.log.debug("Request kernel at: %s" % self.kernel_url) - try: - response = await gateway_request(self.kernel_url, method="GET") - - except web.HTTPError as error: - if error.status_code == 404: - self.log.warning("Kernel not found at: %s" % self.kernel_url) - model = None - else: - raise + if self.kernel_id in _local_kernels: + model = _local_kernels[self.kernel_id] else: - model = json_decode(response.body) - self.log.debug("Kernel retrieved: %s" % model) + self.log.debug("Request kernel at: %s" % self.kernel_url) + try: + response = await gateway_request(self.kernel_url, method="GET") + + except web.HTTPError as error: + if error.status_code == 404: + self.log.warning("Kernel not found at: %s" % self.kernel_url) + model = None + else: + raise + else: + model = json_decode(response.body) + self.log.debug("Kernel retrieved: %s" % model) if model: # Update activity markers self.last_activity = datetime.datetime.strptime( @@ -484,6 +512,13 @@ async def start_kernel(self, **kwargs): """ kernel_id = kwargs.get("kernel_id") + if "local" in kwargs: + kwargs.pop("local") + self.kernel_id = kernel_id + self.kernel_url = url_path_join(self.kernels_url, url_escape(str(self.kernel_id))) + self.kernel = await self.refresh_model() + await super().start_kernel(**kwargs) + if kernel_id is None: kernel_name = kwargs.get("kernel_name", "python3") self.log.debug("Request new kernel at: %s" % self.kernels_url) diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 3dedd5634f..3c84b9cfe6 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -251,6 +251,7 @@ def __init__( authorizer=None, identity_provider=None, kernel_websocket_connection_class=None, + local_kernel_websocket_connection_class=None, websocket_ping_interval=None, websocket_ping_timeout=None, ): @@ -290,6 +291,7 @@ def __init__( authorizer=authorizer, identity_provider=identity_provider, kernel_websocket_connection_class=kernel_websocket_connection_class, + local_kernel_websocket_connection_class=local_kernel_websocket_connection_class, websocket_ping_interval=websocket_ping_interval, websocket_ping_timeout=websocket_ping_timeout, ) @@ -357,6 +359,7 @@ def init_settings( authorizer=None, identity_provider=None, kernel_websocket_connection_class=None, + local_kernel_websocket_connection_class=None, websocket_ping_interval=None, websocket_ping_timeout=None, ): @@ -442,6 +445,7 @@ def init_settings( "identity_provider": identity_provider, "event_logger": event_logger, "kernel_websocket_connection_class": kernel_websocket_connection_class, + "local_kernel_websocket_connection_class": local_kernel_websocket_connection_class, "websocket_ping_interval": websocket_ping_interval, "websocket_ping_timeout": websocket_ping_timeout, # handlers @@ -1630,6 +1634,12 @@ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]: help=_i18n("The kernel websocket connection class to use."), ) + local_kernel_websocket_connection_class = Type( + klass=BaseKernelWebsocketConnection, + config=True, + help=_i18n("The local kernel websocket connection class to use."), + ) + @default("kernel_websocket_connection_class") def _default_kernel_websocket_connection_class( self, @@ -1638,6 +1648,12 @@ def _default_kernel_websocket_connection_class( return "jupyter_server.gateway.connections.GatewayWebSocketConnection" return ZMQChannelsWebsocketConnection + @default("local_kernel_websocket_connection_class") + def _default_local_kernel_websocket_connection_class( + self, + ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]: + return ZMQChannelsWebsocketConnection + websocket_ping_interval = Integer( config=True, help=""" @@ -2252,6 +2268,7 @@ def init_webapp(self) -> None: authorizer=self.authorizer, identity_provider=self.identity_provider, kernel_websocket_connection_class=self.kernel_websocket_connection_class, + local_kernel_websocket_connection_class=self.local_kernel_websocket_connection_class, websocket_ping_interval=self.websocket_ping_interval, websocket_ping_timeout=self.websocket_ping_timeout, ) diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 8b392b4e1b..69a1763c3a 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -5,8 +5,11 @@ import os import pathlib import uuid +from asyncio import Task from typing import Any, Dict, List, NewType, Optional, Union, cast +from requests import session + KernelName = NewType("KernelName", str) ModelName = NewType("ModelName", str) @@ -16,6 +19,7 @@ # fallback on pysqlite2 if Python was build without sqlite from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef] +import asyncio from dataclasses import dataclass, fields from jupyter_core.utils import ensure_async @@ -210,6 +214,8 @@ def __init__(self, *args, **kwargs): _connection = None _columns = {"session_id", "path", "name", "type", "kernel_id"} + fut_kernel_id_dict: Optional[Dict[str, Task[str]]] = None + @property def cursor(self): """Start a cursor and create a database called 'session'""" @@ -267,6 +273,7 @@ async def create_session( type: Optional[str] = None, kernel_name: Optional[KernelName] = None, kernel_id: Optional[str] = None, + session_id: Optional[str] = None, ) -> Dict[str, Any]: """Creates a session and returns its model @@ -276,7 +283,13 @@ async def create_session( Usually the model name, like the filename associated with current kernel. """ - session_id = self.new_session_id() + + if session_id is not None and self.fut_kernel_id_dict is None: + self.fut_kernel_id_dict = {} + + if session_id is None or session_id == "": + session_id = self.new_session_id() + record = KernelSessionRecord(session_id=session_id) self._pending_sessions.update(record) if kernel_id is not None and kernel_id in self.kernel_manager: @@ -312,6 +325,21 @@ def get_kernel_env( assert isinstance(path, str) return {**os.environ, "JPY_SESSION_NAME": path} + async def start_kernel_async( + self, + session_id: str, + path: Optional[str], + kernel_name: Optional[KernelName], + ): + kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path)) + kernel_env = self.get_kernel_env(path) + self.log.info(f"starting kernel ${path} ${kernel_name} ${kernel_path} ${kernel_env}") + self.fut_kernel_id_dict[session_id] = asyncio.create_task(self.kernel_manager.start_kernel( + path=kernel_path, + kernel_name=kernel_name, + env=kernel_env, + )) + async def start_kernel_for_session( self, session_id: str, @@ -337,15 +365,57 @@ async def start_kernel_for_session( the name of the kernel specification to use. The default kernel name will be used if not provided. """ # allow contents manager to specify kernels cwd - kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path)) + if self.fut_kernel_id_dict is not None: + if session_id in self.fut_kernel_id_dict: + fut_kernel_id = self.fut_kernel_id_dict[session_id] + if fut_kernel_id.done(): + kernel_id = await fut_kernel_id + self.fut_kernel_id_dict.pop(session_id) + return kernel_id + else: + await self.start_kernel_async(session_id, path, kernel_name) + kernel_id = "waiting" + else: + kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path)) - kernel_env = self.get_kernel_env(path, name) - kernel_id = await self.kernel_manager.start_kernel( - path=kernel_path, - kernel_name=kernel_name, - env=kernel_env, - ) + kernel_env = self.get_kernel_env(path, name) + kernel_id = await self.kernel_manager.start_kernel( + path=kernel_path, + kernel_name=kernel_name, + env=kernel_env, + ) return cast(str, kernel_id) + + def waiting_kernel( + self + ) -> Dict[str, Any]: + return { + "id": "waiting", + "name": "Waiting for kernel to start", + "last_activity": "2024-10-02T16:56:56.423328Z", + "execution_state": "waiting", + "connections": 0, + } + + def waiting_session( + self, + session_id: str, + path: str, + type: str, + name: str, + ) -> Dict[str, Any]: + result = { + "id": session_id, + "kernel": self.waiting_kernel(), + "path": path, + "type": type, + "name": name, + "notebook": {"path": path, "name": name}, + "last_activity": "2024-10-02T16:56:56.423328Z", + "execution_state": "waiting", + "connections": 0 + } + return result async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None): """Saves the items for the session with the given session_id @@ -397,37 +467,42 @@ async def get_session(self, **kwargs): returns a dictionary that includes all the information from the session described by the kwarg. """ - if not kwargs: - msg = "must specify a column to query" - raise TypeError(msg) - - conditions = [] - for column in kwargs: - if column not in self._columns: - msg = f"No such column: {column}" + session_id = kwargs["session_id"] + if self.fut_kernel_id_dict is not None and session_id in self.fut_kernel_id_dict: + model = self.waiting_session(session_id, "unknown", "notebook", "Waiting for kernel to start") + else: + if not kwargs: + msg = "must specify a column to query" raise TypeError(msg) - conditions.append("%s=?" % column) - - query = "SELECT * FROM session WHERE %s" % (" AND ".join(conditions)) # noqa: S608 - - self.cursor.execute(query, list(kwargs.values())) - try: - row = self.cursor.fetchone() - except KeyError: - # The kernel is missing, so the session just got deleted. - row = None - - if row is None: - q = [] - for key, value in kwargs.items(): - q.append(f"{key}={value!r}") - - raise web.HTTPError(404, "Session not found: %s" % (", ".join(q))) - - try: - model = await self.row_to_model(row) - except KeyError as e: - raise web.HTTPError(404, "Session not found: %s" % str(e)) from e + + conditions = [] + for column in kwargs: + if column not in self._columns: + msg = f"No such column: {column}" + raise TypeError(msg) + conditions.append("%s=?" % column) + + query = "SELECT * FROM session WHERE %s" % (" AND ".join(conditions)) # noqa: S608 + + self.cursor.execute(query, list(kwargs.values())) + try: + row = self.cursor.fetchone() + except KeyError: + # The kernel is missing, so the session just got deleted. + row = None + + if row is None: + q = [] + for key, value in kwargs.items(): + q.append(f"{key}={value!r}") + + raise web.HTTPError(404, "Session not found: %s" % (", ".join(q))) + + try: + model = await self.row_to_model(row) + except KeyError as e: + raise web.HTTPError(404, "Session not found: %s" % str(e)) from e + return model async def update_session(self, session_id, **kwargs): diff --git a/pyproject.toml b/pyproject.toml index 9bfcc74eef..4e897a5a5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", ] -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ "anyio>=3.1.0", "argon2-cffi>=21.1",