Skip to content

Commit

Permalink
Improve typing for cached tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Sep 26, 2024
1 parent 6d56543 commit 6c7fa74
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
10 changes: 6 additions & 4 deletions dynamo/extensions/cogs/events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import contextlib
from collections.abc import AsyncGenerator
from copy import copy
from typing import Any
from typing import Any, cast

import discord
from discord.ext import commands
Expand Down Expand Up @@ -62,14 +63,15 @@ async def on_timeout(self) -> None:

class InterestedDropdown(EventsDropdown[EventsView]):
async def callback(self, interaction: discord.Interaction) -> None:
event = next((e for e in self.events if e.id == int(self.values[0])), None)
await interaction.response.send_message(await get_interested(event) or "No users found", ephemeral=True)
event: discord.ScheduledEvent | None = next((e for e in self.events if e.id == int(self.values[0])), None)
response = "No users found" if event is None else await get_interested(event)
await interaction.response.send_message(response, ephemeral=True)


@async_cache(ttl=1800)
async def get_interested(event: discord.ScheduledEvent) -> str:
# https://peps.python.org/pep-0533/
async with contextlib.aclosing(event.users()) as gen:
async with contextlib.aclosing(cast(AsyncGenerator[discord.User], event.users())) as gen:
users: list[discord.User] = [u async for u in gen]
return f"`[{event.name}]({event.url}) {' '.join(u.mention for u in users) or "No users found"}`"

Expand Down
64 changes: 38 additions & 26 deletions dynamo/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from collections.abc import Callable, Coroutine, Hashable, MutableSequence, Sized
from dataclasses import dataclass
from functools import partial
from typing import Any, Concatenate, Generic, cast, overload
from typing import Any, Protocol, Self, cast, overload

from dynamo._typing import MISSING, P, S, T
from dynamo._typing import MISSING, P, T

log = logging.getLogger(__name__)

Expand All @@ -18,7 +18,7 @@
FAST_TYPES: set[type] = {int, str}


WrappedCoroutine = Callable[Concatenate[S, P], Coroutine[Any, Any, T]]
WrappedCoroutine = Callable[P, Coroutine[Any, Any, T]]


@dataclass(slots=True)
Expand Down Expand Up @@ -68,33 +68,37 @@ def _make_key(
return key[0] if _len(key) == 1 and _type(key[0]) in fast_types else HashedSeq(key)


class CachedTask(Generic[S, P, T]):
class CachedTask(Protocol[P, T]):
@property
def __wrapped__(self) -> WrappedCoroutine[S, P, T]: ...
def __wrapped__(self) -> WrappedCoroutine[P, T]: ...

@__wrapped__.setter
def __wrapped__(self, value: WrappedCoroutine[S, P, T]) -> None: ...
def __wrapped__(self, value: WrappedCoroutine[P, T]) -> None: ...

def __call__(self, /, __self: S = None, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: ...
@overload
def __call__(self: Self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]: ...
@overload
def __call__(self: CachedTask[P, T], *args: Any, **kwargs: Any) -> asyncio.Task[T]: ...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]: ...
def cache_info(self) -> CacheInfo: ...
def cache_clear(self) -> None: ...
def cache_parameters(self) -> dict[str, int | float | None]: ...


def update_wrapper(
wrapper: CachedTask[S, P, T],
wrapped: WrappedCoroutine[S, P, T],
wrapper: CachedTask[P, T],
wrapped: WrappedCoroutine[P, T],
assigned: tuple[str, ...] = WRAPPER_ASSIGNMENTS,
updated: tuple[str, ...] = WRAPPER_UPDATES,
) -> CachedTask[S, P, T]:
) -> CachedTask[P, T]:
"""
Update a wrapper function to look more like the wrapped function.
Parameters
----------
wrapper : CachedTask[S, P, T]
wrapper : CachedTask[P, T]
The wrapper function to be updated.
wrapped : WrappedCoroutine[S, P, T]
wrapped : WrappedCoroutine[P, T]
The original function being wrapped.
assigned : tuple of str, optional
Attribute names to assign from the wrapped function. Default is WRAPPER_ASSIGNMENTS.
Expand All @@ -103,7 +107,7 @@ def update_wrapper(
Returns
-------
CachedTask[S, P, T]
CachedTask[P, T]
The updated wrapper function.
Notes
Expand All @@ -129,17 +133,17 @@ def update_wrapper(
@overload
def async_cache(
*, maxsize: int | None = 128, ttl: float | None = None
) -> Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]: ...
) -> Callable[[WrappedCoroutine[P, T]], CachedTask[P, T]]: ...


@overload
def async_cache(coro: WrappedCoroutine[S, P, T], /) -> CachedTask[S, P, T]: ...
def async_cache(coro: WrappedCoroutine[P, T], /) -> CachedTask[P, T]: ...


def async_cache(
maxsize: int | WrappedCoroutine[S, P, T] | None = 128,
maxsize: int | WrappedCoroutine[P, T] | None = 128,
ttl: float | None = None,
) -> CachedTask[S, P, T] | Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]:
) -> CachedTask[P, T] | Callable[[WrappedCoroutine[P, T]], CachedTask[P, T]]:
"""
Decorator to cache the result of an asynchronous function.
Expand All @@ -149,7 +153,7 @@ def async_cache(
Parameters
----------
maxsize : int | WrappedCoroutine[S, P, T] | None, optional
maxsize : int | WrappedCoroutine[P, T] | None, optional
The maximum number of items to cache. If a coroutine function is provided directly,
it is assumed to be the function to be decorated, and `maxsize` defaults to 128.
If `None`, the cache can grow without bound. Default is 128.
Expand All @@ -159,7 +163,7 @@ def async_cache(
Returns
-------
CachedTask[S, P, T] | Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]
CachedTask[P, T] | Callable[[WrappedCoroutine[P, T]], CachedTask[P, T]]
If a coroutine function is provided directly, returns the cached task.
Otherwise, returns a decorator that can be applied to a coroutine function.
Expand Down Expand Up @@ -193,7 +197,7 @@ def async_cache(
error = "Expected first argument to be an integer, a callable, or None"
raise TypeError(error)

def decorator(coro: WrappedCoroutine[S, P, T]) -> CachedTask[S, P, T]:
def decorator(coro: WrappedCoroutine[P, T]) -> CachedTask[P, T]:
wrapper = _async_cache_wrapper(coro, maxsize, ttl)
wrapper.cache_parameters = lambda: {"maxsize": maxsize, "ttl": ttl}
return update_wrapper(wrapper, coro)
Expand All @@ -202,10 +206,10 @@ def decorator(coro: WrappedCoroutine[S, P, T]) -> CachedTask[S, P, T]:


def _async_cache_wrapper(
coro: WrappedCoroutine[S, P, T],
coro: WrappedCoroutine[P, T],
maxsize: int | None,
ttl: float | None,
) -> CachedTask[S, P, T]:
) -> CachedTask[P, T]:
sentinel = MISSING
make_key = _make_key

Expand All @@ -230,6 +234,7 @@ def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
if maxsize is not None:
with lock:
if key not in internal_cache and _cache_info.full:
log.debug("Eviction: LRU cache is full")
internal_cache.popitem(last=False)
internal_cache[key] = task
internal_cache.move_to_end(key)
Expand All @@ -239,10 +244,16 @@ def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
internal_cache[key] = task

if ttl is not None:

def evict(k: Hashable) -> None:
with lock:
log.debug("Eviction: TTL expired for %s", k)
internal_cache.pop(k, sentinel)

call_after_ttl = partial(
asyncio.get_running_loop().call_later,
ttl,
internal_cache.pop,
evict,
key,
)
task.add_done_callback(call_after_ttl)
Expand All @@ -252,10 +263,11 @@ def cache_info() -> CacheInfo:
return _cache_info

def cache_clear() -> None:
internal_cache.clear()
_cache_info.clear()
with lock:
internal_cache.clear()
_cache_info.clear()

_wrapper = cast(CachedTask[S, P, T], wrapper)
_wrapper = cast(CachedTask[P, T], wrapper)
_wrapper.cache_info = cache_info
_wrapper.cache_clear = cache_clear
return _wrapper

0 comments on commit 6c7fa74

Please sign in to comment.