Skip to content

Commit

Permalink
Accept and manage cookies when requesting gateways (#969)
Browse files Browse the repository at this point in the history
* Add support for stickiness cookies on load balancers

* Add case

* Simplify arguments

* Add tests for arguments

* Fix according to comments
  • Loading branch information
wjsi authored Sep 13, 2022
1 parent 4c7bbfa commit 9a2708e
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 1 deletion.
87 changes: 87 additions & 0 deletions jupyter_server/gateway/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import logging
import os
import typing as ty
from datetime import datetime
from email.utils import parsedate_to_datetime
from http.cookies import Morsel, SimpleCookie
from socket import gaierror

from tornado import web
Expand Down Expand Up @@ -276,6 +279,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self._static_args = {} # initialized on first use

# store of cookies with store time
self._cookies = {} # type: ty.Dict[str, ty.Tuple[Morsel, datetime]]

env_whitelist_default_value = ""
env_whitelist_env = "JUPYTER_GATEWAY_ENV_WHITELIST"
env_whitelist = Unicode(
Expand Down Expand Up @@ -363,6 +369,23 @@ def launch_timeout_pad_default(self):
)
)

accept_cookies_value = False
accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES"
accept_cookies = Bool(
default_value=accept_cookies_value,
config=True,
help="""Accept and manage cookies sent by the service side. This is often useful
for load balancers to decide which backend node to use.
(JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""",
)

@default("accept_cookies")
def accept_cookies_default(self):
return bool(
os.environ.get(self.accept_cookies_env, str(self.accept_cookies_value).lower())
not in ["no", "false"]
)

@property
def gateway_enabled(self):
return bool(self.url is not None and len(self.url) > 0)
Expand Down Expand Up @@ -424,8 +447,65 @@ def load_connection_args(self, **kwargs):
else:
kwargs[arg] = static_value

if self.accept_cookies:
self._update_cookie_header(kwargs)

return kwargs

def update_cookies(self, cookie: SimpleCookie) -> None:
"""Update cookies from existing requests for load balancers"""
if not self.accept_cookies:
return

store_time = datetime.now()
for key, item in cookie.items():
# Convert "expires" arg into "max-age" to facilitate expiration management.
# As "max-age" has precedence, ignore "expires" when "max-age" exists.
if item.get("expires") and not item.get("max-age"):
expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
item["max-age"] = str(expire_timedelta.total_seconds())

self._cookies[key] = (item, store_time)

def _clear_expired_cookies(self) -> None:
check_time = datetime.now()
expired_keys = []

for key, (morsel, store_time) in self._cookies.items():
cookie_max_age = morsel.get("max-age")
if not cookie_max_age:
continue
expired_timedelta = check_time - store_time
if expired_timedelta.total_seconds() > float(cookie_max_age):
expired_keys.append(key)

for key in expired_keys:
self._cookies.pop(key)

def _update_cookie_header(self, connection_args: dict) -> None:
self._clear_expired_cookies()

gateway_cookie_values = "; ".join(
f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items()
)
if gateway_cookie_values:
headers = connection_args.get("headers", {})

# As headers are case-insensitive, we get existing name of cookie header,
# or use "Cookie" by default.
cookie_header_name = next(
(header_key for header_key in headers if header_key.lower() == "cookie"),
"Cookie",
)
existing_cookie = headers.get(cookie_header_name)

# merge gateway-managed cookies with cookies already in arguments
if existing_cookie:
gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values
headers[cookie_header_name] = gateway_cookie_values

connection_args["headers"] = headers


class RetryableHTTPClient:
"""
Expand Down Expand Up @@ -524,4 +604,11 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
f"appear to be valid. Ensure gateway url is valid and the Gateway instance is running.",
) from e

if GatewayClient.instance().accept_cookies:
# Update cookies on GatewayClient from server if configured.
cookie_values = response.headers.get("Set-Cookie")
if cookie_values:
cookie: SimpleCookie = SimpleCookie()
cookie.load(cookie_values)
GatewayClient.instance().update_cookies(cookie)
return response
53 changes: 52 additions & 1 deletion tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import logging
import os
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from email.utils import format_datetime
from http.cookies import SimpleCookie
from io import BytesIO
from queue import Empty
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -187,6 +189,7 @@ def init_gateway(monkeypatch):
monkeypatch.setenv("JUPYTER_GATEWAY_REQUEST_TIMEOUT", "44.4")
monkeypatch.setenv("JUPYTER_GATEWAY_CONNECT_TIMEOUT", "44.4")
monkeypatch.setenv("JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD", "1.1")
monkeypatch.setenv("JUPYTER_GATEWAY_ACCEPT_COOKIES", "false")
yield
GatewayClient.clear_instance()

Expand All @@ -200,6 +203,7 @@ async def test_gateway_env_options(init_gateway, jp_serverapp):
)
assert jp_serverapp.gateway_config.connect_timeout == 44.4
assert jp_serverapp.gateway_config.launch_timeout_pad == 1.1
assert jp_serverapp.gateway_config.accept_cookies is False

GatewayClient.instance().init_static_args()
assert GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == 43
Expand Down Expand Up @@ -259,6 +263,53 @@ async def test_gateway_request_timeout_pad_option(
GatewayClient.clear_instance()


cookie_expire_time = format_datetime(datetime.now() + timedelta(seconds=180))


@pytest.mark.parametrize(
"accept_cookies,expire_arg,expire_param,existing_cookies,cookie_exists",
[
(False, None, None, "EXISTING=1", False),
(True, None, None, "EXISTING=1", True),
(True, "Expires", cookie_expire_time, None, True),
(True, "Max-Age", "-360", "EXISTING=1", False),
],
)
async def test_gateway_request_with_expiring_cookies(
jp_configurable_serverapp,
accept_cookies,
expire_arg,
expire_param,
existing_cookies,
cookie_exists,
):
argv = [f"--GatewayClient.accept_cookies={accept_cookies}"]

GatewayClient.clear_instance()
jp_configurable_serverapp(argv=argv)

cookie: SimpleCookie = SimpleCookie()
cookie.load("SERVERID=1234567; Path=/")
if expire_arg:
cookie["SERVERID"][expire_arg] = expire_param

GatewayClient.instance().update_cookies(cookie)

args = {}
if existing_cookies:
args["headers"] = {"Cookie": existing_cookies}
connection_args = GatewayClient.instance().load_connection_args(**args)

if not cookie_exists:
assert "SERVERID" not in (connection_args["headers"].get("Cookie") or "")
else:
assert "SERVERID" in connection_args["headers"].get("Cookie")
if existing_cookies:
assert "EXISTING" in connection_args["headers"].get("Cookie")

GatewayClient.clear_instance()


async def test_gateway_class_mappings(init_gateway, jp_serverapp):
# Ensure appropriate class mappings are in place.
assert jp_serverapp.kernel_manager_class.__name__ == "GatewayMappingKernelManager"
Expand Down

0 comments on commit 9a2708e

Please sign in to comment.