Skip to content

Commit

Permalink
add spottokenrenewer
Browse files Browse the repository at this point in the history
  • Loading branch information
sigmarkarl committed Nov 29, 2023
1 parent 3755794 commit 023e524
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 8 deletions.
10 changes: 10 additions & 0 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import contextvars
import functools
import inspect
import ipaddress
Expand Down Expand Up @@ -55,6 +56,7 @@
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_server.services.sessions.sessionmanager import SessionManager

_current_request_var: contextvars.ContextVar = contextvars.ContextVar("current_request")
# -----------------------------------------------------------------------------
# Top-level handlers
# -----------------------------------------------------------------------------
Expand All @@ -81,6 +83,9 @@ def log() -> Logger:
class AuthenticatedHandler(web.RequestHandler):
"""A RequestHandler with an authenticated user."""

def prepare(self):
_current_request_var.set(self.request)

@property
def base_url(self) -> str:
return cast(str, self.settings.get("base_url", "/"))
Expand Down Expand Up @@ -1134,6 +1139,11 @@ def get(self) -> None:
self.set_header("Content-Type", prometheus_client.CONTENT_TYPE_LATEST)
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))

def get_current_request():
"""
Get :class:`tornado.httputil.HTTPServerRequest` that is currently being processed.
"""
return _current_request_var.get(None)

# -----------------------------------------------------------------------------
# URL pattern fragments for reuse
Expand Down
43 changes: 43 additions & 0 deletions jupyter_server/gateway/spottokenrenewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing as ty

import logging
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase
import jupyter_server.base.handlers
import jupyter_server.serverapp


def get_header_value(request: ty.Any, header: str) -> str:
if header not in request.headers:
logging.error(f'Header "{header}" is missing')
return ""
logging.debug(f'Getting value from header "{header}"')
value = request.headers[header]
if len(value) == 0:
logging.error(f'Header "{header}" is empty')
return ""
return value


class SpotTokenRenewer(GatewayTokenRenewerBase):

def get_token(
self,
auth_header_key: str,
auth_scheme: ty.Union[str, None],
auth_token: str,
**kwargs: ty.Any,
) -> str:
request = jupyter_server.base.handlers.get_current_request()
if request is None:
logging.error("Could not get current request")
return auth_token

auth_header_value = get_header_value(request, auth_header_key)
if auth_header_value:
try:
# We expect the header value to be of the form "Bearer: XXX"
auth_token = auth_header_value.split(" ", maxsplit=1)[1]
except Exception as e:
logging.error(f"Could not read token from auth header: {str(e)}")

return auth_token
2 changes: 2 additions & 0 deletions jupyter_server/services/sessions/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def post(self):

name = model.get("name", None)
kernel = model.get("kernel", {})
session_id = model.get("id", None)
kernel_name = kernel.get("name", None)
kernel_id = kernel.get("id", None)

Expand All @@ -92,6 +93,7 @@ async def post(self):
kernel_id=kernel_id,
name=name,
type=mtype,
session_id=session_id,
)
except NoSuchKernel:
msg = (
Expand Down
43 changes: 35 additions & 8 deletions jupyter_server/services/sessions/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# fallback on pysqlite2 if Python was build without sqlite
from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef]

import asyncio
from dataclasses import dataclass, fields

from jupyter_core.utils import ensure_async
Expand Down Expand Up @@ -210,6 +211,8 @@ def __init__(self, *args, **kwargs):
_connection = None
_columns = {"session_id", "path", "name", "type", "kernel_id"}

fut_kernel_id_dict = None

@property
def cursor(self):
"""Start a cursor and create a database called 'session'"""
Expand Down Expand Up @@ -267,6 +270,7 @@ async def create_session(
type: Optional[str] = None,
kernel_name: Optional[KernelName] = None,
kernel_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Creates a session and returns its model
Expand All @@ -276,7 +280,13 @@ async def create_session(
Usually the model name, like the filename associated with current
kernel.
"""
session_id = self.new_session_id()

if session_id is not None and self.fut_kernel_id_dict is None:
self.fut_kernel_id_dict = {}

if session_id is None or session_id == "":
session_id = self.new_session_id()

record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
if kernel_id is not None and kernel_id in self.kernel_manager:
Expand Down Expand Up @@ -337,14 +347,31 @@ async def start_kernel_for_session(
the name of the kernel specification to use. The default kernel name will be used if not provided.
"""
# allow contents manager to specify kernels cwd
kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path))
if self.fut_kernel_id_dict is not None:
if session_id in self.fut_kernel_id_dict:
fut_kernel_id = self.fut_kernel_id_dict[session_id]
if fut_kernel_id.done():
kernel_id = await fut_kernel_id
self.fut_kernel_id_dict.pop(session_id)
return kernel_id
else:
kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path))
kernel_env = self.get_kernel_env(path)
self.fut_kernel_id_dict[session_id] = asyncio.create_task(self.kernel_manager.start_kernel(
path=kernel_path,
kernel_name=kernel_name,
env=kernel_env,
))
kernel_id = "waiting"
else:
kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path))

kernel_env = self.get_kernel_env(path, name)
kernel_id = await self.kernel_manager.start_kernel(
path=kernel_path,
kernel_name=kernel_name,
env=kernel_env,
)
kernel_env = self.get_kernel_env(path, name)
kernel_id = await self.kernel_manager.start_kernel(
path=kernel_path,
kernel_name=kernel_name,
env=kernel_env,
)
return cast(str, kernel_id)

async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None):
Expand Down

0 comments on commit 023e524

Please sign in to comment.