Skip to content

Commit

Permalink
Use future instead of task for cache
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Sep 30, 2024
1 parent 41df285 commit 5f4f24a
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 73 deletions.
3 changes: 0 additions & 3 deletions dynamo/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from discord import app_commands
from discord.ext import commands
from discord.ui import View

P = ParamSpec("P")

Expand All @@ -19,8 +18,6 @@
CommandT = TypeVar("CommandT", bound=commands.Command[Any, ..., Any])
ContextT = TypeVar("ContextT", bound=commands.Context[Any], covariant=True)

V = TypeVar("V", bound="View", covariant=True)


class WrappedCoroutine[**P, T](Protocol):
"""A coroutine that has been wrapped by a decorator."""
Expand Down
5 changes: 1 addition & 4 deletions dynamo/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def __init__(self, connector: aiohttp.TCPConnector, session: aiohttp.ClientSessi
intents=intents,
enable_debug_events=True,
tree_cls=VersionableTree,
activity=discord.Activity(
name="The Cursed Apple",
type=discord.ActivityType.watching,
),
activity=discord.Activity(name="The Cursed Apple", type=discord.ActivityType.watching),
)

async def setup_hook(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions dynamo/core/logging_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

from dynamo.utils.helper import platformdir, resolve_path_with_links

known_messages: tuple[str, ...] = ("referencing an unknown", "PyNaCl is not installed, voice will NOT be supported")

class RemoveNoise(logging.Filter):
known_messages: tuple[str, ...] = ("referencing an unknown", "PyNaCl is not installed, voice will NOT be supported")

class RemoveNoise(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool | logging.LogRecord:
return not any(message in record.msg for message in self.known_messages)
return not any(message in record.msg for message in known_messages)


@contextmanager
Expand Down
11 changes: 6 additions & 5 deletions dynamo/extensions/cogs/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

def event_to_option(event: discord.ScheduledEvent) -> discord.SelectOption:
"""Convert a ScheduledEvent to a SelectOption to be used in a dropdown menu"""
description = event.description if event.description is not None else ""
return discord.SelectOption(label=event.name, value=str(event.id), description=shorten_string(description))
description = shorten_string(event.description or "")
return discord.SelectOption(label=event.name, value=str(event.id), description=description)


class EventsDropdown[V: discord.ui.View](discord.ui.Select[V]):
Expand Down Expand Up @@ -60,7 +60,7 @@ async def on_timeout(self) -> None:

class InterestedDropdown(EventsDropdown[EventsView]):
async def callback(self, interaction: discord.Interaction) -> None:
event: discord.ScheduledEvent | None = next(filter(lambda e: e.id == int(self.values[0]), self.events), None)
event: discord.ScheduledEvent | None = next((e for e in self.events if e.id == int(self.values[0])), None)
response = "No users found" if event is None else await get_interested(event)
await interaction.response.send_message(response, ephemeral=True)

Expand Down Expand Up @@ -152,8 +152,9 @@ async def event(self, ctx: Context, event: int | None = None) -> None:
# Prevent invokation when a view is already active by invoking user
self.active_users.add(ctx.author.id)

guild_not_cached = event_check.get_containing(ctx.guild, event) is None
fetch_message = "Events not cached, fetching..." if guild_not_cached else "Fetching events..."
# Message for when the events are cached or not
guild_cached = event_check.get_containing(ctx.guild, event) is not None
fetch_message = "Fetching events..." if guild_cached else "Events not cached, fetching..."
message = await ctx.send(f"{self.bot.app_emojis.get("loading2", "⏳")}\t{fetch_message}")

event_exists: str | list[discord.ScheduledEvent] = await event_check(ctx.guild, event)
Expand Down
2 changes: 1 addition & 1 deletion dynamo/extensions/cogs/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def generate_identicon(
description = (
f"**Generate this identicon:**\n" f"> {cmd_mention} {display_name}\n" f"> {prefix}identicon {display_name}"
)
e = discord.Embed(title=display_name, description=description, color=discord.Color.from_rgb(*fg.as_tuple()))
e = discord.Embed(title=display_name, description=description, color=fg.as_discord_color())
e.set_image(url="attachment://identicon.png")
return e, file

Expand Down
56 changes: 37 additions & 19 deletions dynamo/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Callable, Hashable, Sized
from dataclasses import dataclass
from functools import partial
from typing import Any, ParamSpec, Protocol, TypeVar, cast, overload
from typing import Any, ParamSpec, Protocol, TypeVar, cast, final, overload

from dynamo._types import MISSING, WrappedCoroutine

Expand All @@ -21,6 +21,7 @@
FAST_TYPES: set[type] = {int, str}


@final
class CachedTask[**P, T](Protocol):
__wrapped__: Callable[P, WrappedCoroutine[P, T]]
__call__: Callable[..., asyncio.Task[T]]
Expand All @@ -34,6 +35,7 @@ class CachedTask[**P, T](Protocol):
DecoratedCoroutine = Callable[[WrappedCoroutine[P, T]], CachedTask[P, T]]


@final
@dataclass(slots=True)
class CacheInfo:
"""Cache info for the async_cache decorator."""
Expand Down Expand Up @@ -201,50 +203,52 @@ def update_wrapper[**P, T](
if hasattr(wrapper, attr) and hasattr(wrapped, attr):
getattr(wrapper, attr).update(getattr(wrapped, attr))

wrapper.__wrapped__ = cast(Callable[..., WrappedCoroutine[P, T]], wrapped)
wrapper.__wrapped__ = cast(Callable[P, WrappedCoroutine[P, T]], wrapped)
return wrapper


def _cache_wrapper[**P, T](coro: WrappedCoroutine[P, T], maxsize: int | None, ttl: float | None) -> CachedTask[P, T]:
sentinel = MISSING
make_key = _make_key

internal_cache: OrderedDict[Hashable, asyncio.Task[T]] = OrderedDict()
internal_cache: OrderedDict[Hashable, asyncio.Future[T]] = OrderedDict()
cache_get = internal_cache.get
cache_len = internal_cache.__len__
lock = threading.Lock()
_cache_info = CacheInfo()

def wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
key: Hashable = make_key(args, kwargs)
result = cache_get(key, sentinel)
future = cache_get(key, sentinel)

# Mitigate lock contention on cache hit
if result is not sentinel:
if future is not sentinel:
log.debug("Cache hit for %s", args)
_cache_info.hits += 1
return result
return asyncio.create_task(_wrap_future(future))

with lock:
result = cache_get(key, sentinel)
if result is not sentinel:
future = cache_get(key, sentinel)
if future is not sentinel:
log.debug("Cache hit for %s", args)
_cache_info.hits += 1
return result
return asyncio.create_task(_wrap_future(future))
log.debug("Cache miss for %s", args)
_cache_info.misses += 1

task: asyncio.Task[T] = asyncio.create_task(coro(*args, **kwargs))
if maxsize is not None:
with lock:
future = asyncio.get_running_loop().create_future()

with lock:
if maxsize is not None:
if key not in internal_cache and _cache_info.full:
log.debug("Eviction: LRU cache is full")
internal_cache.popitem(last=False)
internal_cache[key] = task
internal_cache[key] = future
internal_cache.move_to_end(key)
_cache_info.full = cache_len() >= maxsize
_cache_info.currsize = cache_len()
else:
internal_cache[key] = task
else:
internal_cache[key] = future

if ttl is not None:

Expand All @@ -254,8 +258,21 @@ def evict(k: Hashable, default: Any = MISSING) -> None:
internal_cache.pop(k, default)

call_after_ttl = partial(asyncio.get_running_loop().call_later, ttl, evict, key)
task.add_done_callback(call_after_ttl)
return task
future.add_done_callback(call_after_ttl)

async def run_coro():
try:
result = await coro(*args, **kwargs)
future.set_result(result)
except Exception as e:
future.set_exception(e)
raise
return result

return asyncio.create_task(run_coro())

async def _wrap_future(future: asyncio.Future[T]) -> T:
return await future

def cache_info() -> CacheInfo:
return _cache_info
Expand All @@ -267,8 +284,9 @@ def cache_clear() -> None:

def get_containing(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T] | None:
key = make_key(args, kwargs)
result = cache_get(key, sentinel)
return result if result is not sentinel else None
with lock:
future = cache_get(key, sentinel)
return asyncio.create_task(_wrap_future(future)) if future is not sentinel else None

_wrapper = cast(CachedTask[P, T], wrapper)
_wrapper.cache_info = cache_info
Expand Down
15 changes: 4 additions & 11 deletions dynamo/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from discord.ext import commands
from discord.ui import View

from dynamo._types import V

if TYPE_CHECKING:
from dynamo.core.bot import Dynamo

Expand Down Expand Up @@ -41,19 +39,14 @@ async def interaction_check(self, interaction: discord.Interaction) -> bool:
return bool(interaction.user and interaction.user.id == self.author_id)

async def _defer_and_stop(self, interaction: discord.Interaction[Dynamo]) -> None:
"""Defer the interaction and stop the view.
Parameters
----------
interaction : discord.Interaction
The interaction to defer.
"""
"""Defer the interaction and stop the view."""
await interaction.response.defer()
if self.delete_after and self.message:
await interaction.delete_original_response()
self.stop()

async def on_timeout(self) -> None:
"""Disable the buttons and delete the message"""
for i in self.children:
item = cast(discord.ui.Button[ConfirmationView], i)
item.disabled = True
Expand All @@ -62,13 +55,13 @@ async def on_timeout(self) -> None:
await self.message.delete()

@discord.ui.button(label="Confirm", style=discord.ButtonStyle.green)
async def confirm(self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
async def confirm[V: View](self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
"""Confirm the action"""
self.value = True
await self._defer_and_stop(interaction)

@discord.ui.button(label="Cancel", style=discord.ButtonStyle.red)
async def cancel(self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
async def cancel[V: View](self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
"""Cancel the action"""
await self._defer_and_stop(interaction)

Expand Down
2 changes: 1 addition & 1 deletion dynamo/utils/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def convert(self, ctx: commands.Context[BotT], argument: str) -> SeedLike:
return cast(SeedLike, result)

@override
async def transform(self, interaction: discord.Interaction, value: str) -> discord.Member | str:
async def transform(self, interaction: discord.Interaction, value: str) -> SeedLike:
# No need to reinvent the wheel, just run it through the commands.MemberConverter method.
ctx = await Context.from_interaction(cast(discord.Interaction[Dynamo], interaction))
return await self.convert(ctx, value)
Expand Down
6 changes: 4 additions & 2 deletions dynamo/utils/identicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from dataclasses import dataclass
from io import BytesIO

import discord
import numpy as np
from numpy.typing import NDArray
from PIL import Image

from dynamo.utils.cache import async_cache
from dynamo.utils.wrappers import timer

# 0.0 = same color | 1.0 = different color
COLOR_THRESHOLD = 0.4
Expand Down Expand Up @@ -66,6 +66,9 @@ def perceived_distance(self, other: RGB) -> float:
def as_tuple(self) -> tuple[int, int, int]:
return self.r, self.g, self.b

def as_discord_color(self) -> discord.Color:
return discord.Color.from_rgb(*self.as_tuple())


def make_color(rng: np.random.Generator) -> RGB:
return RGB(*map(int, rng.integers(low=0, high=256, size=3, dtype=int)))
Expand Down Expand Up @@ -121,7 +124,6 @@ async def get_identicon(idt: Identicon, size: int = 256) -> bytes:
Get an identicon as bytes
"""

@timer
def _buffer(idt: Identicon, size: int) -> bytes:
buffer = BytesIO()
Image.fromarray(idt.icon.astype("uint8")).convert("RGB").resize((size, size), Image.Resampling.NEAREST).save(
Expand Down
8 changes: 4 additions & 4 deletions dynamo/utils/spotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def track_duration(seconds: int) -> str:
"""
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
return f"{hours}:{minutes:02d}:{seconds:02d}" if hours else f"{minutes:02d}:{seconds:02d}"
return f"{f'{hours}:' if hours else ''}{minutes:02d}:{seconds:02d}"


def get_progress(end: datetime.datetime, duration: datetime.timedelta) -> float:
Expand Down Expand Up @@ -276,9 +276,9 @@ class StaticDrawArgs:

def draw_static_elements(args: StaticDrawArgs) -> None:
# Draw artist name
draw.text( # type: ignore
(CONTENT_START_X, TITLE_START_Y + TITLE_FONT_SIZE + 5),
text=", ".join(args.artists),
args.draw.text( # type: ignore
xy=(CONTENT_START_X, TITLE_START_Y + TITLE_FONT_SIZE + 5),
text=str(", ".join(args.artists)),
fill=TEXT_COLOR,
font=args.artist_font,
)
Expand Down
12 changes: 4 additions & 8 deletions dynamo/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@ def timer[**P, T](func: Callable[P, T]) -> Callable[P, T]: ...


def timer[**P, T](func: Callable[P, T] | WrappedCoroutine[P, T]) -> Callable[P, T] | WrappedCoroutine[P, T]:
if asyncio.iscoroutinefunction(func):

async def async_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
with time_it(func.__name__):
return await func(*args, **kwargs)

return async_wrap
async def async_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
with time_it(func.__name__):
return await func(*args, **kwargs)

def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with time_it(func.__name__):
return cast(T, func(*args, **kwargs))

return wrapper
return async_wrap if asyncio.iscoroutinefunction(func) else wrapper
Loading

0 comments on commit 5f4f24a

Please sign in to comment.