Skip to content

Commit

Permalink
More type fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Sep 26, 2024
1 parent 6c7fa74 commit 431c674
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 26 deletions.
21 changes: 12 additions & 9 deletions dynamo/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
import threading
from collections import OrderedDict
from collections.abc import Callable, Coroutine, Hashable, MutableSequence, Sized
from collections.abc import Callable, Coroutine, Hashable, Sized
from dataclasses import dataclass
from functools import partial
from typing import Any, Protocol, Self, cast, overload
from typing import Any, Protocol, cast, overload

from dynamo._typing import MISSING, P, T

Expand Down Expand Up @@ -36,14 +36,14 @@ def clear(self) -> None:
self.full = False


class HashedSeq(MutableSequence[Any]):
class HashedSeq(list[Any]):
__slots__: tuple[str, ...] = ("hash_value",)

def __init__(self, /, *args: Any, _hash: Callable[[object], int] = hash) -> None:
self[:] = args
self.hash_value: int = _hash(args)

def __hash__(self) -> int:
def __hash__(self) -> int: # type: ignore
return self.hash_value


Expand Down Expand Up @@ -75,11 +75,7 @@ def __wrapped__(self) -> WrappedCoroutine[P, T]: ...
@__wrapped__.setter
def __wrapped__(self, value: WrappedCoroutine[P, T]) -> None: ...

@overload
def __call__(self: Self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]: ...
@overload
def __call__(self: CachedTask[P, T], *args: Any, **kwargs: Any) -> asyncio.Task[T]: ...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]: ...
def __call__(self, *args: Any, **kwargs: Any) -> asyncio.Task[T]: ...
def cache_info(self) -> CacheInfo: ...
def cache_clear(self) -> None: ...
def cache_parameters(self) -> dict[str, int | float | None]: ...
Expand Down Expand Up @@ -221,6 +217,13 @@ def _async_cache_wrapper(

def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task[T]:
key: Hashable = make_key(args, kwargs)
result = cache_get(key, sentinel)

# Mitigate lock contention on cache hit
if result is not sentinel:
log.debug("Cache hit for %s", args)
_cache_info.hits += 1
return result
with lock:
result = cache_get(key, sentinel)
if result is not sentinel:
Expand Down
13 changes: 7 additions & 6 deletions dynamo/utils/converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import discord
from discord import app_commands
Expand All @@ -16,11 +16,12 @@ class GuildConverter(commands.GuildConverter):
"""Convert an argument to a guild. If not found, return the current guild. If there's no guild at all,
return the argument."""

async def convert(self, ctx: Context, argument: Any) -> discord.Guild | Any:
async def convert(self, ctx: Context, argument: str) -> discord.Guild | str: # type: ignore
try:
return await commands.GuildConverter().convert(ctx, argument)
result: discord.Guild = await commands.GuildConverter().convert(ctx, argument)
except commands.GuildNotFound:
return ctx.guild or argument
return argument if ctx.guild is None else ctx.guild
return result


class SeedConverter(commands.Converter[discord.Member | str], app_commands.Transformer):
Expand All @@ -31,13 +32,13 @@ class SeedConverter(commands.Converter[discord.Member | str], app_commands.Trans
:func:`discord.ext.commands.MemberConverter.convert`
"""

async def convert(self, ctx: Context, argument: str) -> discord.Member | str:
async def convert(self, ctx: Context, argument: str) -> discord.Member | str: # type: ignore
try:
return await commands.MemberConverter().convert(ctx, argument)
except commands.MemberNotFound:
return argument

async def transform(self, interaction: discord.Interaction[Dynamo], value: str) -> discord.Member | str:
async def transform(self, interaction: discord.Interaction[Dynamo], value: str) -> discord.Member | str: # type: ignore
# No need to reinvent the wheel, just run it through the commands.MemberConverter method.
ctx = await Context.from_interaction(interaction)
return await self.convert(ctx, value)
Expand Down
20 changes: 11 additions & 9 deletions dynamo/utils/spotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
BACKGROUND_COLOR: tuple[int, int, int] = (5, 5, 25)

# White
TEXT_COLOR = PROGRESS_BAR_COLOR = (255, 255, 255)
TEXT_COLOR: tuple[int, int, int] = (255, 255, 255)
PROGRESS_BAR_COLOR: tuple[int, int, int] = (255, 255, 255)

# Light gray
LENGTH_BAR_COLOR: tuple[int, int, int] = (64, 64, 64)
Expand Down Expand Up @@ -197,7 +198,7 @@ def _draw(

# Draw only one frame if the title fits
if title_width <= available_width:
base_draw.text((CONTENT_START_X, TITLE_START_Y), text=name, fill=TEXT_COLOR, font=title_font)
base_draw.text((CONTENT_START_X, TITLE_START_Y), text=name, fill=TEXT_COLOR, font=title_font) # type: ignore

draw_static_elements(base_draw, base, artists, artist_font, duration, end, spotify_logo)

Expand All @@ -208,7 +209,7 @@ def _draw(

# Generate scrolling frames for the title
title_frames = draw_text_scroll(title_font, name, available_width)
num_frames = len(title_frames)
num_frames = len(list(title_frames))

frames: list[Image.Image] = []
for _ in range(num_frames):
Expand All @@ -234,7 +235,7 @@ def draw_static_elements(
spotify_logo: Image.Image,
) -> None:
# Draw artist name
draw.text(
draw.text( # type: ignore
(CONTENT_START_X, TITLE_START_Y + TITLE_FONT_SIZE + 5),
text=", ".join(artists),
fill=TEXT_COLOR,
Expand Down Expand Up @@ -281,7 +282,7 @@ def _draw_track_bar(draw: ImageDraw.ImageDraw, progress: float, duration: dateti
progress_text = f"{played} / {track_duration_str}"
progress_font = get_font(progress_text, bold=False, size=PROGRESS_FONT_SIZE)

draw.text((x, PROGRESS_TEXT_Y), text=progress_text, fill=TEXT_COLOR, font=progress_font)
draw.text((x, PROGRESS_TEXT_Y), text=progress_text, fill=TEXT_COLOR, font=progress_font) # type: ignore


def draw_text_scroll(font: ImageFont.FreeTypeFont, text: str, width: int) -> Generator[Image.Image, None, None]:
Expand All @@ -302,19 +303,20 @@ def draw_text_scroll(font: ImageFont.FreeTypeFont, text: str, width: int) -> Gen
A frame of the text scrolling
"""
text_bbox = font.getbbox(text)
text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
text_width = int(text_bbox[2] - text_bbox[0])
text_height = int(text_bbox[3] - text_bbox[1])

if text_width <= width:
frame = Image.new("RGBA", (width, text_height))
frame_draw = ImageDraw.Draw(frame)
frame_draw.text((0, 0), text, fill=TEXT_COLOR, font=font)
frame_draw.text((0, 0), text, fill=TEXT_COLOR, font=font) # type: ignore
yield frame
return

# Add space between end and start for continuous scrolling
full_text = text + " " + text
full_text_bbox = font.getbbox(full_text)
full_text_width = full_text_bbox[2] - full_text_bbox[0]
full_text_width = int(full_text_bbox[2] - full_text_bbox[0])

pause_frames = 30
total_frames = pause_frames + (full_text_width // SLIDING_SPEED)
Expand All @@ -325,7 +327,7 @@ def draw_text_scroll(font: ImageFont.FreeTypeFont, text: str, width: int) -> Gen

x_pos = 0 if i < pause_frames else -((i - pause_frames) * SLIDING_SPEED) % full_text_width

frame_draw.text((x_pos, 0), full_text, fill=TEXT_COLOR, font=font)
frame_draw.text((x_pos, 0), full_text, fill=TEXT_COLOR, font=font) # type: ignore
yield frame


Expand Down
2 changes: 1 addition & 1 deletion dynamo/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def human_timedelta(

attrs = [("year", "y"), ("month", "mo"), ("day", "d"), ("hour", "h"), ("minute", "m"), ("second", "s")]

output = []
output: list[str] = []
for attr, brief_attr in attrs:
if not (elem := getattr(delta, attr + "s")):
continue
Expand Down
19 changes: 19 additions & 0 deletions tests/utils/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,25 @@ async def async_cacheable_sized(x: int) -> int:
assert async_cacheable_sized.cache_info().currsize == 0


@pytest.mark.asyncio
@settings(deadline=None)
@given(inputs=st.lists(st.integers(min_value=1, max_value=100), min_size=10, max_size=10))
async def test_maxsize_enforcement(inputs: list[int]) -> None:
"""Test that the cache enforces the maxsize."""

@dynamo.utils.cache.async_cache(maxsize=5)
async def async_cacheable_sized(x: int) -> int:
await asyncio.sleep(0.01)
return x * 2

for i in inputs:
await async_cacheable_sized(i)

cache_info = async_cacheable_sized.cache_info()
assert cache_info.currsize <= 5
assert cache_info.currsize == min(len(set(inputs)), 5)


@pytest.mark.asyncio
@settings(deadline=None)
@given(inputs=st.lists(st.integers(min_value=1, max_value=100), min_size=10, max_size=50))
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_plural_value_immutable(value: int) -> None:
"""Test the plural value is immutable"""
p = dynamo.utils.format.plural(value)
with pytest.raises(AttributeError):
p.value = value + 1
p.value = value + 1 # type: ignore


@pytest.mark.parametrize(
Expand Down

0 comments on commit 431c674

Please sign in to comment.