diff --git a/docs/source/api/jupyter_server.base.rst b/docs/source/api/jupyter_server.base.rst index 42aa6211ff..0a8e12c6d1 100644 --- a/docs/source/api/jupyter_server.base.rst +++ b/docs/source/api/jupyter_server.base.rst @@ -5,6 +5,12 @@ Submodules ---------- +.. automodule:: jupyter_server.base.call_context + :members: + :undoc-members: + :show-inheritance: + + .. automodule:: jupyter_server.base.handlers :members: :undoc-members: diff --git a/jupyter_server/__init__.py b/jupyter_server/__init__.py index c6ca245c91..c0f77cba0a 100644 --- a/jupyter_server/__init__.py +++ b/jupyter_server/__init__.py @@ -15,6 +15,7 @@ del os from ._version import __version__, version_info # noqa +from .base.call_context import CallContext # noqa def _cleanup(): diff --git a/jupyter_server/base/call_context.py b/jupyter_server/base/call_context.py new file mode 100644 index 0000000000..3d989121c2 --- /dev/null +++ b/jupyter_server/base/call_context.py @@ -0,0 +1,88 @@ +"""Provides access to variables pertaining to specific call contexts.""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from contextvars import Context, ContextVar, copy_context +from typing import Any, Dict, List + + +class CallContext: + """CallContext essentially acts as a namespace for managing context variables. + + Although not required, it is recommended that any "file-spanning" context variable + names (i.e., variables that will be set or retrieved from multiple files or services) be + added as constants to this class definition. + """ + + # Add well-known (file-spanning) names here. + #: Provides access to the current request handler once set. + JUPYTER_HANDLER: str = "JUPYTER_HANDLER" + + # A map of variable name to value is maintained as the single ContextVar. This also enables + # easier management over maintaining a set of ContextVar instances, since the Context is a + # map of ContextVar instances to their values, and the "name" is no longer a lookup key. + _NAME_VALUE_MAP = "_name_value_map" + _name_value_map: ContextVar[Dict[str, Any]] = ContextVar(_NAME_VALUE_MAP) + + @classmethod + def get(cls, name: str) -> Any: + """Returns the value corresponding the named variable relative to this context. + + If the named variable doesn't exist, None will be returned. + + Parameters + ---------- + name : str + The name of the variable to get from the call context + + Returns + ------- + value: Any + The value associated with the named variable for this call context + """ + name_value_map = CallContext._get_map() + + if name in name_value_map: + return name_value_map[name] + return None # TODO - should this raise `LookupError` (or a custom error derived from said) + + @classmethod + def set(cls, name: str, value: Any) -> None: + """Sets the named variable to the specified value in the current call context. + + Parameters + ---------- + name : str + The name of the variable to store into the call context + value : Any + The value of the variable to store into the call context + + Returns + ------- + None + """ + name_value_map = CallContext._get_map() + name_value_map[name] = value + + @classmethod + def context_variable_names(cls) -> List[str]: + """Returns a list of variable names set for this call context. + + Returns + ------- + names: List[str] + A list of variable names set for this call context. + """ + name_value_map = CallContext._get_map() + return list(name_value_map.keys()) + + @classmethod + def _get_map(cls) -> Dict[str, Any]: + """Get the map of names to their values from the _NAME_VALUE_MAP context var. + + If the map does not exist in the current context, an empty map is created and returned. + """ + ctx: Context = copy_context() + if CallContext._name_value_map not in ctx: + CallContext._name_value_map.set({}) + return CallContext._name_value_map.get() diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index f749447482..061eea672a 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -26,6 +26,7 @@ from traitlets.config import Application import jupyter_server +from jupyter_server import CallContext from jupyter_server._sysinfo import get_sys_info from jupyter_server._tz import utcnow from jupyter_server.auth import authorized @@ -582,6 +583,9 @@ def check_host(self): async def prepare(self): """Pepare a response.""" + # Set the current Jupyter Handler context variable. + CallContext.set(CallContext.JUPYTER_HANDLER, self) + if not self.check_host(): self.current_user = self._jupyter_current_user = None raise web.HTTPError(403) diff --git a/tests/base/test_call_context.py b/tests/base/test_call_context.py new file mode 100644 index 0000000000..1c12338d61 --- /dev/null +++ b/tests/base/test_call_context.py @@ -0,0 +1,109 @@ +import asyncio + +from jupyter_server import CallContext +from jupyter_server.auth.utils import get_anonymous_username +from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager + + +async def test_jupyter_handler_contextvar(jp_fetch, monkeypatch): + # Create some mock kernel Ids + kernel1 = "x-x-x-x-x" + kernel2 = "y-y-y-y-y" + + # We'll use this dictionary to track the current user within each request. + context_tracker = { + kernel1: {"started": "no user yet", "ended": "still no user", "user": None}, + kernel2: {"started": "no user yet", "ended": "still no user", "user": None}, + } + + # Monkeypatch the get_current_user method in Tornado's + # request handler to return a random user name for + # each request + async def get_current_user(self): + return get_anonymous_username() + + monkeypatch.setattr(JupyterHandler, "get_current_user", get_current_user) + + # Monkeypatch the kernel_model method to show that + # the current context variable is truly local and + # not contaminated by other asynchronous parallel requests. + # Note that even though the current implementation of `kernel_model()` + # is synchronous, we can convert this into an async method because the + # kernel handler wraps the call to `kernel_model()` in `ensure_async()`. + async def kernel_model(self, kernel_id): + # Get the Jupyter Handler from the current context. + current: JupyterHandler = CallContext.get(CallContext.JUPYTER_HANDLER) + # Get the current user + context_tracker[kernel_id]["user"] = current.current_user + context_tracker[kernel_id]["started"] = current.current_user + await asyncio.sleep(1.0) + # Track the current user a few seconds later. We'll + # verify that this user was unaffected by other parallel + # requests. + context_tracker[kernel_id]["ended"] = current.current_user + return {"id": kernel_id, "name": "blah"} + + monkeypatch.setattr(AsyncMappingKernelManager, "kernel_model", kernel_model) + + # Make two requests in parallel. + await asyncio.gather( + jp_fetch("api", "kernels", kernel1), + jp_fetch("api", "kernels", kernel2), + ) + + # Assert that the two requests had different users + assert context_tracker[kernel1]["user"] != context_tracker[kernel2]["user"] + # Assert that the first request started+ended with the same user + assert context_tracker[kernel1]["started"] == context_tracker[kernel1]["ended"] + # Assert that the second request started+ended with the same user + assert context_tracker[kernel2]["started"] == context_tracker[kernel2]["ended"] + + +async def test_context_variable_names(): + CallContext.set("foo", "bar") + CallContext.set("foo2", "bar2") + names = CallContext.context_variable_names() + assert len(names) == 2 + assert set(names) == {"foo", "foo2"} + + +async def test_same_context_operations(): + CallContext.set("foo", "bar") + CallContext.set("foo2", "bar2") + + foo = CallContext.get("foo") + assert foo == "bar" + + CallContext.set("foo", "bar2") + assert CallContext.get("foo") == CallContext.get("foo2") + + +async def test_multi_context_operations(): + async def context1(): + """The "slower" context. This ensures that, following the sleep, the + context variable set prior to the sleep is still the expected value. + If contexts are not managed properly, we should find that context2() has + corrupted context1(). + """ + CallContext.set("foo", "bar1") + await asyncio.sleep(1.0) + assert CallContext.get("foo") == "bar1" + context1_names = CallContext.context_variable_names() + assert len(context1_names) == 1 + + async def context2(): + """The "faster" context. This ensures that CallContext reflects the + appropriate values of THIS context. + """ + CallContext.set("foo", "bar2") + assert CallContext.get("foo") == "bar2" + CallContext.set("foo2", "bar2") + context2_names = CallContext.context_variable_names() + assert len(context2_names) == 2 + + await asyncio.gather(context1(), context2()) + + # Assert that THIS context doesn't have any variables defined. + names = CallContext.context_variable_names() + assert len(names) == 0