Skip to content

Commit

Permalink
Some typing and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Sep 25, 2024
1 parent b8ec870 commit a1d6495
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 125 deletions.
256 changes: 141 additions & 115 deletions dynamo/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,26 @@
import logging
import threading
from collections import OrderedDict
from collections.abc import Callable, Coroutine, Hashable, Sized
from collections.abc import Callable, Coroutine, Hashable, MutableSequence, Sized
from dataclasses import dataclass
from functools import partial
from typing import Any, Concatenate, Generic, Protocol, Self
from typing import Any, Concatenate, Generic, cast, overload

from dynamo._typing import MISSING, P, S, S_co, T, T_co
from dynamo._typing import MISSING, P, S, T

# Module-level logger for this module.
log = logging.getLogger(__name__)

# Function metadata copied verbatim onto cache wrappers; mirrors
# functools.WRAPPER_ASSIGNMENTS plus "__type_params__" (added in 3.12).
WRAPPER_ASSIGNMENTS = ("__module__", "__name__", "__qualname__", "__doc__", "__annotations__", "__type_params__")
# Attributes merged into the wrapper (via dict.update) rather than replaced.
WRAPPER_UPDATES = ("__dict__",)
# Types whose single-argument cache keys are cheap and safe to use directly,
# without wrapping in HashedSeq.
FAST_TYPES: set[type] = {int, str}


# Shape of a coroutine function accepted by async_cache: an instance-like
# first argument S, remaining parameters P, awaitable result T.
WrappedCoroutine = Callable[Concatenate[S, P], Coroutine[Any, Any, T]]


@dataclass(slots=True)
class CacheInfo:
"""Cache info for the async_lru_cache decorator."""
"""Cache info for the async_cache decorator."""

hits: int = 0
misses: int = 0
Expand All @@ -32,99 +36,154 @@ def clear(self) -> None:
self.full = False


class HashedSeq(list[Any]):
    """A sequence that pre-computes and caches its own hash.

    Mirrors ``functools._HashedSeq``: used as a dictionary cache key so the
    (potentially expensive) hash of the flattened argument tuple is computed
    exactly once instead of on every cache lookup.

    Notes
    -----
    Subclassing ``list`` (not ``collections.abc.MutableSequence``) is
    deliberate: ``MutableSequence`` is abstract and provides no storage, so
    ``self[:] = args`` would fail and instantiation would raise ``TypeError``
    for the unimplemented abstract methods.
    """

    __slots__: tuple[str, ...] = ("hash_value",)

    def __init__(self, /, *args: Any, _hash: Callable[[object], int] = hash) -> None:
        # _hash is bound at definition time so the lookup stays local/fast.
        self[:] = args
        self.hash_value: int = _hash(args)

    def __hash__(self) -> int:  # type: ignore[override]  # list sets __hash__ = None; we restore hashability
        return self.hash_value


def _make_key(
args: tuple[Any, ...],
kwargs: dict[Any, Any],
kwargs_mark: tuple[object] = (object(),),
fast_types: set[type] = {int, str}, # noqa: B006
type: type[type] = type, # noqa: A002
len: Callable[[Sized], int] = len, # noqa: A002
fast_types: set[type] = FAST_TYPES,
_type: type[type] = type,
_len: Callable[[Sized], int] = len,
) -> Hashable:
"""
Make cache key from optionally typed positional and keyword arguments. Structure is flat and hashable.
Although efficient, it will treat `f(x=1, y=2)` and `f(y=2, x=1)` as distinct calls and will be cached
separately.
"""
key: tuple[Any, ...] = args
if kwargs:
key += kwargs_mark
for item in kwargs.items():
key += item
return key[0] if len(key) == 1 and type(key[0]) in fast_types else HashedSeq(key)

return key[0] if _len(key) == 1 and _type(key[0]) in fast_types else HashedSeq(key)

class CachedTask(Generic[S, P, T]):
    """Typing facade for the callable returned by :func:`async_cache`.

    The concrete object is a plain wrapper function produced by
    ``_async_cache_wrapper`` and ``cast`` to this type; this class exists
    only to declare the attributes type checkers should see on it.

    NOTE(review): declaring ``__slots__`` naming ``__wrapped__`` /
    ``cache_info`` / ``cache_clear`` / ``cache_parameters`` alongside the
    property and method definitions below raises
    ``ValueError: '...' in __slots__ conflicts with class variable`` at
    class-creation time, so no ``__slots__`` is declared here.
    """

    @property
    def __wrapped__(self) -> WrappedCoroutine[S, P, T]: ...

    @__wrapped__.setter
    def __wrapped__(self, value: WrappedCoroutine[S, P, T]) -> None: ...

    # __self is positional-only so it cannot clash with a keyword named
    # "__self" in the wrapped coroutine's signature.
    def __call__(self, /, __self: S = None, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: ...
    def cache_info(self) -> CacheInfo: ...
    def cache_clear(self) -> None: ...
    def cache_parameters(self) -> dict[str, int | float | None]: ...


def update_wrapper(
    wrapper: CachedTask[S, P, T],
    wrapped: WrappedCoroutine[S, P, T],
    assigned: tuple[str, ...] = WRAPPER_ASSIGNMENTS,
    updated: tuple[str, ...] = WRAPPER_UPDATES,
) -> CachedTask[S, P, T]:
    """
    Update a wrapper function to look more like the wrapped function.

    Parameters
    ----------
    wrapper : CachedTask[S, P, T]
        The wrapper function to be updated in place.
    wrapped : WrappedCoroutine[S, P, T]
        The original coroutine function being wrapped.
    assigned : tuple of str, optional
        Attribute names assigned directly from the wrapped function.
        Default is ``WRAPPER_ASSIGNMENTS``.
    updated : tuple of str, optional
        Attribute names whose contents are merged (``dict.update``) from the
        wrapped function. Default is ``WRAPPER_UPDATES``.

    Returns
    -------
    CachedTask[S, P, T]
        The same ``wrapper`` object, returned for decorator chaining.

    Notes
    -----
    Typically used in decorators to ensure the wrapper function retains the
    metadata of the wrapped function.

    See Also
    --------
    functools.update_wrapper : The synchronous equivalent this mirrors.
    """
    for attr in assigned:
        # EAFP: some attributes may legitimately be missing from the wrapped
        # function (e.g. __type_params__ on Python < 3.12); skip those.
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        # Merge rather than replace; default to {} so a missing attribute on
        # the wrapped function is a no-op instead of an AttributeError.
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))

    # Set last so a stale "__wrapped__" copied in via __dict__ above cannot
    # clobber the real reference to the wrapped coroutine.
    wrapper.__wrapped__ = wrapped
    return wrapper


@overload
def async_cache(
maxsize: int
| Callable[P, Coroutine[Any, Any, T]]
| Callable[Concatenate[S, P], Coroutine[Any, Any, T]]
| None = 128,
*, maxsize: int | None = 128, ttl: float | None = None
) -> Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]: ...


@overload
def async_cache(coro: WrappedCoroutine[S, P, T], /) -> CachedTask[S, P, T]: ...


def async_cache(
maxsize: int | WrappedCoroutine[S, P, T] | None = 128,
ttl: float | None = None,
) -> Any:
"""Decorator to cache the result of an asynchronous function.
) -> CachedTask[S, P, T] | Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]:
"""
Decorator to cache the result of an asynchronous function.
Functionally similar to `functools.cache` & `functools.lru_cache` but non-blocking.
This decorator caches the result of an asynchronous function to improve performance
by avoiding redundant computations. It is functionally similar to :func:`functools.cache`
and :func:`functools.lru_cache` but designed for asynchronous functions.
Parameters
----------
maxsize : int | None, optional
Set the maximum number of items to cache.
ttl : int | None, optional
Set the time to live for cached items in seconds.
See
---
- https://github.com/mikeshardmind/async-utils/blob/main/async_utils/task_cache.py
- https://asyncstdlib.readthedocs.io/en/stable
"""
maxsize : int | WrappedCoroutine[S, P, T] | None, optional
The maximum number of items to cache. If a coroutine function is provided directly,
it is assumed to be the function to be decorated, and `maxsize` defaults to 128.
If `None`, the cache can grow without bound. Default is 128.
ttl : float | None, optional
The time-to-live for cached items in seconds. If `None`, items do not expire.
Default is None.
Returns
-------
CachedTask[S, P, T] | Callable[[WrappedCoroutine[S, P, T]], CachedTask[S, P, T]]
If a coroutine function is provided directly, returns the cached task.
Otherwise, returns a decorator that can be applied to a coroutine function.
Examples
--------
Using the decorator with default parameters:
>>> @async_cache
... async def fetch_data(url: str) -> str:
... # Simulate a network request
... await asyncio.sleep(1)
... return f"Data from {url}"
Using the decorator with custom parameters:
>>> @async_cache(maxsize=256, ttl=60.0)
... async def fetch_data(url: str) -> str:
... # Simulate a network request
... await asyncio.sleep(1)
... return f"Data from {url}"
"""
if isinstance(maxsize, int):
maxsize = max(maxsize, 0)
elif callable(maxsize):
Expand All @@ -136,9 +195,7 @@ def async_cache(
error = "Expected first argument to be an integer, a callable, or None"
raise TypeError(error)

def decorator(
coro: Callable[P, Coroutine[Any, Any, T]] | Callable[Concatenate[S, P], Coroutine[Any, Any, T]],
) -> LRUAsyncCallable[P, T] | LRUAsyncMethod[S, P, T]:
def decorator(coro: WrappedCoroutine[S, P, T]) -> CachedTask[S, P, T]:
wrapper = _async_cache_wrapper(coro, maxsize, ttl)
wrapper.cache_parameters = lambda: {"maxsize": maxsize, "ttl": ttl}
return update_wrapper(wrapper, coro)
Expand All @@ -147,10 +204,10 @@ def decorator(


def _async_cache_wrapper(
coro: Callable[P, Coroutine[Any, Any, T]] | Callable[Concatenate[S, P], Coroutine[Any, Any, T]],
coro: WrappedCoroutine[S, P, T],
maxsize: int | None,
ttl: float | None,
) -> LRUAsyncCallable[P, T] | LRUAsyncMethod[S, P, T]:
) -> CachedTask[S, P, T]:
sentinel = MISSING
make_key = _make_key

Expand All @@ -160,70 +217,38 @@ def _async_cache_wrapper(
lock = threading.Lock()
_cache_info = CacheInfo()

if maxsize == 0:

def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
log.debug("Cache miss for %s", args)
_cache_info.misses += 1
task: asyncio.Task[T] = asyncio.create_task(coro(*args, **kwargs))
if ttl is not None:
call_after_ttl = partial(
asyncio.get_running_loop().call_later,
ttl,
internal_cache.pop,
MISSING,
)
task.add_done_callback(call_after_ttl)
return task

elif maxsize is None:

def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
key = make_key(args, kwargs)
def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
key: Hashable = make_key(args, kwargs)
with lock:
result = cache_get(key, sentinel)
if result is not sentinel:
log.debug("Cache hit for %s", args)
_cache_info.hits += 1
return result
task: asyncio.Task[T] = asyncio.create_task(coro(*args, **kwargs))
internal_cache[key] = task
log.debug("Cache miss for %s", args)
_cache_info.misses += 1
if ttl is not None:
call_after_ttl = partial(
asyncio.get_running_loop().call_later,
ttl,
internal_cache.pop,
key,
)
task.add_done_callback(call_after_ttl)
return task

else:

def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
key = make_key(args, kwargs)
with lock:
link = cache_get(key)
if link is not None:
log.debug("Cache hit for %s", args)
_cache_info.hits += 1
return link
log.debug("Cache miss for %s", args)
_cache_info.misses += 1
task: asyncio.Task[T] = asyncio.create_task(coro(*args, **kwargs))
task: asyncio.Task[T] = asyncio.create_task(coro(*args, **kwargs))
if maxsize is not None:
with lock:
if key in internal_cache:
pass
elif _cache_info.full:
if key not in internal_cache and _cache_info.full:
internal_cache.popitem(last=False)
internal_cache[key] = task
internal_cache.move_to_end(key)
else:
internal_cache[key] = task
_cache_info.full = cache_len() >= maxsize
internal_cache[key] = task
internal_cache.move_to_end(key)
_cache_info.full = cache_len() >= maxsize
_cache_info.currsize = cache_len()
return task
else:
internal_cache[key] = task

if ttl is not None:
call_after_ttl = partial(
asyncio.get_running_loop().call_later,
ttl,
internal_cache.pop,
key,
)
task.add_done_callback(call_after_ttl)
return task

def cache_info() -> CacheInfo:
return _cache_info
Expand All @@ -232,6 +257,7 @@ def cache_clear() -> None:
internal_cache.clear()
_cache_info.clear()

wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper
_wrapper = cast(CachedTask[S, P, T], wrapper)
_wrapper.cache_info = cache_info
_wrapper.cache_clear = cache_clear
return _wrapper
8 changes: 5 additions & 3 deletions dynamo/utils/context.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

from enum import StrEnum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import aiohttp
import discord
from discord.ext import commands
from discord.ui import View

from dynamo._typing import V

if TYPE_CHECKING:
from dynamo.core.bot import Dynamo


class ConfirmationView(discord.ui.View):
class ConfirmationView(View):
"""A view for confirming an action"""

value: bool
Expand Down Expand Up @@ -53,7 +54,8 @@ async def _defer_and_stop(self, interaction: discord.Interaction[Dynamo]) -> Non
self.stop()

async def on_timeout(self) -> None:
for item in self.children:
for i in self.children:
item = cast(discord.ui.Button[ConfirmationView], i)
item.disabled = True

if self.message:
Expand Down
Loading

0 comments on commit a1d6495

Please sign in to comment.