Skip to content

Commit

Permalink
Use timestamp instead of datetime to achieve faster cookie expiration… (#7837)
Browse files (browse the repository at this point in the history)

(cherry picked from commit 8ae650b)

Co-authored-by: Rongrong <i@rong.moe>
  • Loading branch information
Dreamsorcerer and Rongronggg9 authored Nov 14, 2023
1 parent 53476df commit e07a1bd
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGES/7824.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use timestamp instead of ``datetime`` to achieve faster cookie expiration in ``CookieJar``.
67 changes: 33 additions & 34 deletions aiohttp/cookiejar.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import calendar
import contextlib
import datetime
import os # noqa
import pathlib
import pickle
import re
import time
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from math import ceil
from typing import ( # noqa
DefaultDict,
Dict,
Expand All @@ -24,7 +27,7 @@
from yarl import URL

from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address, next_whole_second
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL

__all__ = ("CookieJar", "DummyCookieJar")
Expand Down Expand Up @@ -52,9 +55,22 @@ class CookieJar(AbstractCookieJar):

DATE_YEAR_RE = re.compile(r"(\d{2,4})")

MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

MAX_32BIT_TIME = datetime.datetime.fromtimestamp(2**31 - 1, datetime.timezone.utc)
# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except OSError:
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
# Avoid minuses in the future, 3x faster
SUB_MAX_TIME = MAX_TIME - 1

def __init__(
self,
Expand Down Expand Up @@ -83,14 +99,8 @@ def __init__(
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._next_expiration = next_whole_second()
self._expirations: Dict[Tuple[str, str, str], datetime.datetime] = {}
# #4515: datetime.max may not be representable on 32-bit platforms
self._max_time = self.MAX_TIME
try:
self._max_time.timestamp()
except OverflowError:
self._max_time = self.MAX_32BIT_TIME
self._next_expiration: float = ceil(time.time())
self._expirations: Dict[Tuple[str, str, str], float] = {}

def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
Expand All @@ -104,14 +114,14 @@ def load(self, file_path: PathLike) -> None:

def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
self._next_expiration = next_whole_second()
self._next_expiration = ceil(time.time())
self._cookies.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return

to_del = []
now = datetime.datetime.now(datetime.timezone.utc)
now = time.time()
for (domain, path), cookie in self._cookies.items():
for name, morsel in cookie.items():
key = (domain, path, name)
Expand All @@ -127,13 +137,11 @@ def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
del self._expirations[(domain, path, name)]
self._cookies[(domain, path)].pop(name, None)

next_expiration = min(self._expirations.values(), default=self._max_time)
try:
self._next_expiration = next_expiration.replace(
microsecond=0
) + datetime.timedelta(seconds=1)
except OverflowError:
self._next_expiration = self._max_time
self._next_expiration = (
min(*self._expirations.values(), self.SUB_MAX_TIME) + 1
if self._expirations
else self.MAX_TIME
)

def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
Expand All @@ -149,9 +157,7 @@ def __len__(self) -> int:
def _do_expiration(self) -> None:
self.clear(lambda x: False)

def _expire_cookie(
self, when: datetime.datetime, domain: str, path: str, name: str
) -> None:
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, path, name)] = when

Expand Down Expand Up @@ -209,12 +215,7 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No
if max_age:
try:
delta_seconds = int(max_age)
try:
max_age_expiration = datetime.datetime.now(
datetime.timezone.utc
) + datetime.timedelta(seconds=delta_seconds)
except OverflowError:
max_age_expiration = self._max_time
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
Expand Down Expand Up @@ -323,7 +324,7 @@ def _is_path_match(req_path: str, cookie_path: str) -> bool:
return non_matching.startswith("/")

@classmethod
def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
def _parse_date(cls, date_str: str) -> Optional[int]:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
Expand Down Expand Up @@ -384,9 +385,7 @@ def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None

return datetime.datetime(
year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
)
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))


class DummyCookieJar(AbstractCookieJar):
Expand Down
7 changes: 0 additions & 7 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,13 +545,6 @@ def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> b
return is_ipv4_address(host) or is_ipv6_address(host)


def next_whole_second() -> datetime.datetime:
"""Return current time rounded up to the next whole second."""
return datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
) + datetime.timedelta(seconds=0)


_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""

Expand Down
20 changes: 12 additions & 8 deletions tests/test_cookiejar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,27 @@ def test_date_parsing() -> None:
assert parse_func("") is None

# 70 -> 1970
assert parse_func("Tue, 1 Jan 70 00:00:00 GMT") == datetime.datetime(
1970, 1, 1, tzinfo=utc
assert (
parse_func("Tue, 1 Jan 70 00:00:00 GMT")
== datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp()
)

# 10 -> 2010
assert parse_func("Tue, 1 Jan 10 00:00:00 GMT") == datetime.datetime(
2010, 1, 1, tzinfo=utc
assert (
parse_func("Tue, 1 Jan 10 00:00:00 GMT")
== datetime.datetime(2010, 1, 1, tzinfo=utc).timestamp()
)

# No day of week string
assert parse_func("1 Jan 1970 00:00:00 GMT") == datetime.datetime(
1970, 1, 1, tzinfo=utc
assert (
parse_func("1 Jan 1970 00:00:00 GMT")
== datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp()
)

# No timezone string
assert parse_func("Tue, 1 Jan 1970 00:00:00") == datetime.datetime(
1970, 1, 1, tzinfo=utc
assert (
parse_func("Tue, 1 Jan 1970 00:00:00")
== datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp()
)

# No year
Expand Down

0 comments on commit e07a1bd

Please sign in to comment.