Skip to content

Commit

Permalink
original pseudo waiting kernel code (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigmarkarl authored Oct 9, 2024
1 parent 717ab33 commit 5628e18
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 70 deletions.
38 changes: 38 additions & 0 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import mimetypes
import os
import re
import sys
import types
import typing as ty
import logging
import warnings
from http.client import responses
from logging import Logger
Expand Down Expand Up @@ -61,6 +64,7 @@
# -----------------------------------------------------------------------------

# Module-level cache slot, empty (None) at import time.
# NOTE(review): presumably filled by json_sys_info(); its body is not visible here, confirm.
_sys_info_cache = None
# Name of the synthetic module registered in sys.modules whose `token` attribute
# holds the most recently seen request token (written by
# AuthenticatedHandler.prepare, read by get_current_token).
_my_globals = "myglobals"


def json_sys_info():
Expand All @@ -79,9 +83,36 @@ def log() -> Logger:
return app_log


def get_token_value(request: ty.Any, prev: str) -> str:
    """Extract the credential from a request's ``Authorization`` header.

    The header is expected to have the form ``"<scheme> <token>"`` (for
    example ``"Bearer XXX"``); everything after the first space is returned.

    Parameters
    ----------
    request
        Any object exposing a mapping-like ``headers`` attribute.
    prev
        Fallback value returned when the header is missing, empty, or does
        not contain a space-separated token.

    Returns
    -------
    str
        The token portion of the header, or ``prev`` if none can be read.
    """
    header = "Authorization"
    if header not in request.headers:
        logging.error(f'Header "{header}" is missing')
        return prev
    logging.debug(f'Getting value from header "{header}"')
    auth_header_value: str = request.headers[header]
    if not auth_header_value:
        logging.error(f'Header "{header}" is empty')
        return prev

    # Never log the header value itself: it carries the caller's credential.
    _scheme, separator, token = auth_header_value.partition(" ")
    if not separator:
        logging.error(
            'Could not read token from auth header: expected "<scheme> <token>"'
        )
        return prev
    return token


class AuthenticatedHandler(web.RequestHandler):
"""A RequestHandler with an authenticated user."""

def prepare(self):
    """Store the token from this request's Authorization header in the shared module.

    The value is kept as the ``token`` attribute of a synthetic module
    registered in ``sys.modules`` under ``_my_globals``. When the request
    carries no usable header, the previously stored token (or ``""``) is kept.
    """
    # NOTE(review): the store lives in sys.modules, so it is shared by every
    # request handled in this process - confirm that is intended.
    if _my_globals not in sys.modules:
        sys.modules[_my_globals] = types.ModuleType(_my_globals)
    store = sys.modules[_my_globals]
    last_token = getattr(store, "token", "")
    store.token = get_token_value(self.request, last_token)

@property
def base_url(self) -> str:
    """The ``base_url`` entry of the handler settings, ``"/"`` when it is not set."""
    configured = self.settings.get("base_url", "/")
    return cast(str, configured)
Expand Down Expand Up @@ -1175,6 +1206,13 @@ def get(self) -> None:
self.set_header("Content-Type", prometheus_client.CONTENT_TYPE_LATEST)
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))

def get_current_token():
    """Return the token most recently stored by ``AuthenticatedHandler.prepare``.

    The token is read from the ``token`` attribute of the synthetic module
    registered in ``sys.modules`` under ``_my_globals``. That store is
    process-wide, so the value reflects the last request that supplied a
    token, not necessarily the request being handled by the caller.

    Returns
    -------
    str
        The stored token, or ``""`` if no store or no token exists yet.
    """
    store = sys.modules.get(_my_globals)
    # getattr with a default also covers the "store not registered" case (None).
    return getattr(store, "token", "")

class PublicStaticFileHandler(web.StaticFileHandler):
"""Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required."""
Expand Down
97 changes: 66 additions & 31 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
if TYPE_CHECKING:
from logging import Logger

# Process-wide registry of kernels started through the local (non-gateway) path,
# keyed by kernel id.
# NOTE(review): start_kernel stores the kernel manager under the id, while
# kernel_model later overwrites that entry with the kernel model and
# refresh_model reads the entry as a model - confirm the intended value type.
_local_kernels: dict[str, ServerKernelManager] = {}

class GatewayMappingKernelManager(AsyncMappingKernelManager):
"""Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway."""
Expand Down Expand Up @@ -78,20 +79,38 @@ async def start_kernel(self, *, kernel_id=None, path=None, **kwargs):
The API path (unicode, '/' delimited) for the cwd.
Will be transformed to an OS path relative to root_dir.
"""
self.log.info(f"Request start kernel: kernel_id={kernel_id}, path='{path}'")

if kernel_id is None and path is not None:
kwargs["cwd"] = self.cwd_for_path(path)

km = self.kernel_manager_factory(parent=self, log=self.log)
await km.start_kernel(kernel_id=kernel_id, **kwargs)
kernel_id = km.kernel_id
self._kernels[kernel_id] = km
# Initialize culling if not already
if not self._initialized_culler:
self.initialize_culler()

return kernel_id
kernel_name = kwargs.get("kernel_name")
if kernel_name == 'python3' or kernel_name.startswith('sc-'):
kwargs["kernel_name"] = "python3"
kwargs["local"] = True
env = kwargs["env"]
if kernel_name.startswith('sc-'):
app = kernel_name
startup_file = f"/tmp/{app}.py"
startup_content = f"""from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('{kernel_name}').remote('sc://{app}-driver-svc.spark-apps.svc.cluster.local').getOrCreate()
"""
env["PYTHONSTARTUP"] = startup_file
with open(startup_file, "w") as f:
f.write(startup_content)
kernel_id = await super().start_kernel(kernel_id=kernel_id, path=path, **kwargs)
_local_kernels[kernel_id] = self._kernels[kernel_id]
return kernel_id
else:
self.log.info(f"Request start kernel: kernel_id={kernel_id}, path='{path}'")

if kernel_id is None and path is not None:
kwargs["cwd"] = self.cwd_for_path(path)

km = self.kernel_manager_factory(parent=self, log=self.log)
await km.start_kernel(kernel_id=kernel_id, **kwargs)
kernel_id = km.kernel_id
self._kernels[kernel_id] = km
# Initialize culling if not already
if not self._initialized_culler:
self.initialize_culler()

return kernel_id

async def kernel_model(self, kernel_id):
"""Return a dictionary of kernel information described in the
Expand All @@ -102,11 +121,17 @@ async def kernel_model(self, kernel_id):
kernel_id : uuid
The uuid of the kernel.
"""
model = None
km = self.get_kernel(str(kernel_id))
if km: # type:ignore[truthy-bool]
model = km.kernel # type:ignore[attr-defined]
return model
if kernel_id in _local_kernels:
str_kernel_id = str(kernel_id)
model = super().kernel_model(kernel_id)
_local_kernels[str_kernel_id] = model
return model
else:
model = None
km = self.get_kernel(str(kernel_id))
if km: # type:ignore[truthy-bool]
model = km.kernel # type:ignore[attr-defined]
return model

async def list_kernels(self, **kwargs):
"""Get a list of running kernels from the Gateway server.
Expand Down Expand Up @@ -437,19 +462,22 @@ async def refresh_model(self, model=None):
model is fetched from the Gateway server.
"""
if model is None:
self.log.debug("Request kernel at: %s" % self.kernel_url)
try:
response = await gateway_request(self.kernel_url, method="GET")

except web.HTTPError as error:
if error.status_code == 404:
self.log.warning("Kernel not found at: %s" % self.kernel_url)
model = None
else:
raise
if self.kernel_id in _local_kernels:
model = _local_kernels[self.kernel_id]
else:
model = json_decode(response.body)
self.log.debug("Kernel retrieved: %s" % model)
self.log.debug("Request kernel at: %s" % self.kernel_url)
try:
response = await gateway_request(self.kernel_url, method="GET")

except web.HTTPError as error:
if error.status_code == 404:
self.log.warning("Kernel not found at: %s" % self.kernel_url)
model = None
else:
raise
else:
model = json_decode(response.body)
self.log.debug("Kernel retrieved: %s" % model)

if model: # Update activity markers
self.last_activity = datetime.datetime.strptime(
Expand Down Expand Up @@ -484,6 +512,13 @@ async def start_kernel(self, **kwargs):
"""
kernel_id = kwargs.get("kernel_id")

if "local" in kwargs:
kwargs.pop("local")
self.kernel_id = kernel_id
self.kernel_url = url_path_join(self.kernels_url, url_escape(str(self.kernel_id)))
self.kernel = await self.refresh_model()
await super().start_kernel(**kwargs)

if kernel_id is None:
kernel_name = kwargs.get("kernel_name", "python3")
self.log.debug("Request new kernel at: %s" % self.kernels_url)
Expand Down
17 changes: 17 additions & 0 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __init__(
authorizer=None,
identity_provider=None,
kernel_websocket_connection_class=None,
local_kernel_websocket_connection_class=None,
websocket_ping_interval=None,
websocket_ping_timeout=None,
):
Expand Down Expand Up @@ -290,6 +291,7 @@ def __init__(
authorizer=authorizer,
identity_provider=identity_provider,
kernel_websocket_connection_class=kernel_websocket_connection_class,
local_kernel_websocket_connection_class=local_kernel_websocket_connection_class,
websocket_ping_interval=websocket_ping_interval,
websocket_ping_timeout=websocket_ping_timeout,
)
Expand Down Expand Up @@ -357,6 +359,7 @@ def init_settings(
authorizer=None,
identity_provider=None,
kernel_websocket_connection_class=None,
local_kernel_websocket_connection_class=None,
websocket_ping_interval=None,
websocket_ping_timeout=None,
):
Expand Down Expand Up @@ -442,6 +445,7 @@ def init_settings(
"identity_provider": identity_provider,
"event_logger": event_logger,
"kernel_websocket_connection_class": kernel_websocket_connection_class,
"local_kernel_websocket_connection_class": local_kernel_websocket_connection_class,
"websocket_ping_interval": websocket_ping_interval,
"websocket_ping_timeout": websocket_ping_timeout,
# handlers
Expand Down Expand Up @@ -1630,6 +1634,12 @@ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]:
help=_i18n("The kernel websocket connection class to use."),
)

# Configurable trait selecting the websocket connection class for local kernels;
# its value is passed through to the web application settings under the same name.
local_kernel_websocket_connection_class = Type(
klass=BaseKernelWebsocketConnection,
config=True,
help=_i18n("The local kernel websocket connection class to use."),
)

@default("kernel_websocket_connection_class")
def _default_kernel_websocket_connection_class(
self,
Expand All @@ -1638,6 +1648,12 @@ def _default_kernel_websocket_connection_class(
return "jupyter_server.gateway.connections.GatewayWebSocketConnection"
return ZMQChannelsWebsocketConnection

@default("local_kernel_websocket_connection_class")
def _default_local_kernel_websocket_connection_class(self) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]:
    """Default for ``local_kernel_websocket_connection_class``: always the ZMQ channels connection."""
    return ZMQChannelsWebsocketConnection

websocket_ping_interval = Integer(
config=True,
help="""
Expand Down Expand Up @@ -2252,6 +2268,7 @@ def init_webapp(self) -> None:
authorizer=self.authorizer,
identity_provider=self.identity_provider,
kernel_websocket_connection_class=self.kernel_websocket_connection_class,
local_kernel_websocket_connection_class=self.local_kernel_websocket_connection_class,
websocket_ping_interval=self.websocket_ping_interval,
websocket_ping_timeout=self.websocket_ping_timeout,
)
Expand Down
Loading

0 comments on commit 5628e18

Please sign in to comment.