Skip to content

Commit

Permalink
misc: Add Redis async cache
Browse files Browse the repository at this point in the history
Introduce an asynchronous Redis instance to be used in async functions.
Also, this change migrates most of the sync cache usage to the new async
cache.
  • Loading branch information
adamantike committed Jul 21, 2024
1 parent dd4f9e4 commit d9a30e6
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 32 deletions.
7 changes: 4 additions & 3 deletions backend/endpoints/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

import pytest
from fastapi.testclient import TestClient
from handler.redis_handler import cache
from handler.redis_handler import async_cache, sync_cache
from main import app
from models.user import Role

client = TestClient(app)


@pytest.fixture(autouse=True)
def clear_cache():
async def clear_cache():
yield
cache.flushall()
sync_cache.flushall()
await async_cache.flushall()


def test_login_logout(admin_user):
Expand Down
24 changes: 12 additions & 12 deletions backend/handler/metadata/base_hander.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unicodedata
from typing import Final

from handler.redis_handler import cache
from handler.redis_handler import async_cache, sync_cache
from logger.logger import log
from tasks.update_switch_titledb import (
SWITCH_PRODUCT_ID_KEY,
Expand All @@ -18,9 +18,9 @@ def conditionally_set_cache(
index_key: str, filename: str, parent_dir: str = os.path.dirname(__file__)
) -> None:
fixtures_path = os.path.join(parent_dir, "fixtures")
if not cache.exists(index_key):
if not sync_cache.exists(index_key):
index_data = json.loads(open(os.path.join(fixtures_path, filename)).read())
with cache.pipeline() as pipe:
with sync_cache.pipeline() as pipe:
for data_batch in batched(index_data.items(), 2000):
data_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
pipe.hset(index_key, mapping=data_map)
Expand Down Expand Up @@ -99,15 +99,15 @@ def _normalize_exact_match(name: str) -> str:

async def _ps2_opl_format(self, match: re.Match[str], search_term: str) -> str:
serial_code = match.group(1)
index_entry = cache.hget(PS2_OPL_KEY, serial_code)
index_entry = await async_cache.hget(PS2_OPL_KEY, serial_code)
if index_entry:
index_entry = json.loads(index_entry)
search_term = index_entry["Name"] # type: ignore

return search_term

async def _sony_serial_format(self, index_key: str, serial_code: str) -> str | None:
index_entry = cache.hget(index_key, serial_code)
index_entry = await async_cache.hget(index_key, serial_code)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["title"]
Expand Down Expand Up @@ -140,15 +140,15 @@ async def _switch_titledb_format(
) -> tuple[str, dict | None]:
title_id = match.group(1)

if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
log.warning("Fetching the Switch titleID index file...")
await update_switch_titledb_task.run(force=True)

if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
log.error("Could not fetch the Switch titleID index file")
return search_term, None

index_entry = cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
index_entry = await async_cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["name"], index_entry
Expand All @@ -165,15 +165,15 @@ async def _switch_productid_format(
product_id[-3] = "0"
product_id = "".join(product_id)

if not cache.exists(SWITCH_PRODUCT_ID_KEY):
if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
log.warning("Fetching the Switch productID index file...")
await update_switch_titledb_task.run(force=True)

if not cache.exists(SWITCH_PRODUCT_ID_KEY):
if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
log.error("Could not fetch the Switch productID index file")
return search_term, None

index_entry = cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
index_entry = await async_cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["name"], index_entry
Expand All @@ -183,7 +183,7 @@ async def _switch_productid_format(
async def _mame_format(self, search_term: str) -> str:
from handler.filesystem import fs_rom_handler

index_entry = cache.hget(MAME_XML_KEY, search_term)
index_entry = await async_cache.hget(MAME_XML_KEY, search_term)
if index_entry:
index_entry = json.loads(index_entry)
search_term = fs_rom_handler.get_file_name_with_no_tags(
Expand Down
10 changes: 5 additions & 5 deletions backend/handler/metadata/igdb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests
from config import IGDB_CLIENT_ID, IGDB_CLIENT_SECRET
from fastapi import HTTPException, status
from handler.redis_handler import cache
from handler.redis_handler import sync_cache
from logger.logger import log
from requests.exceptions import HTTPError, Timeout
from typing_extensions import TypedDict
Expand Down Expand Up @@ -592,8 +592,8 @@ def _update_twitch_token(self) -> str:
return ""

# Set token in redis to expire in <expires_in> seconds
cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]
sync_cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
sync_cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]

log.info("Twitch token fetched!")

Expand All @@ -608,8 +608,8 @@ def get_oauth_token(self) -> str:
return ""

# Fetch the token cache
token = cache.get("romm:twitch_token") # type: ignore[attr-defined]
token_expires_at = cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]
token = sync_cache.get("romm:twitch_token") # type: ignore[attr-defined]
token_expires_at = sync_cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]

if not token or time.time() > float(token_expires_at or 0):
log.warning("Twitch token invalid: fetching a new one...")
Expand Down
33 changes: 28 additions & 5 deletions backend/handler/redis_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from config import REDIS_DB, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT, REDIS_USERNAME
from logger.logger import log
from redis import Redis, StrictRedis
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from rq import Queue


Expand Down Expand Up @@ -31,12 +32,12 @@ class QueuePrio(Enum):
low_prio_queue = Queue(name=QueuePrio.LOW.value, connection=redis_client)


def __get_cache() -> StrictRedis:
def __get_sync_cache() -> Redis:
if "pytest" in sys.modules:
# Only import fakeredis when running tests, as it is a test dependency.
from fakeredis import FakeStrictRedis
from fakeredis import FakeRedis

return FakeStrictRedis(version=7)
return FakeRedis(version=7)

log.info(f"Connecting to redis in {sys.argv[0]}...")
# A separate client that auto-decodes responses is needed
Expand All @@ -52,4 +53,26 @@ def __get_cache() -> StrictRedis:
return client


cache = __get_cache()
def __get_async_cache() -> AsyncRedis:
if "pytest" in sys.modules:
# Only import fakeredis when running tests, as it is a test dependency.
from fakeredis import FakeAsyncRedis

return FakeAsyncRedis(version=7)

log.info(f"Connecting to redis in {sys.argv[0]}...")
# A separate client that auto-decodes responses is needed
client = AsyncRedis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
username=REDIS_USERNAME,
db=REDIS_DB,
decode_responses=True,
)
log.info(f"Redis connection established in {sys.argv[0]}!")
return client


sync_cache = __get_sync_cache()
async_cache = __get_async_cache()
4 changes: 2 additions & 2 deletions backend/models/firmware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from handler.metadata.base_hander import conditionally_set_cache
from handler.redis_handler import cache
from handler.redis_handler import sync_cache
from models.base import BaseModel
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
Expand Down Expand Up @@ -55,7 +55,7 @@ def platform_name(self) -> str:

@cached_property
def is_verified(self) -> bool:
cache_entry = cache.hget(
cache_entry = sync_cache.hget(
KNOWN_BIOS_KEY, f"{self.platform_slug}:{self.file_name}"
)
if cache_entry:
Expand Down
10 changes: 5 additions & 5 deletions backend/tasks/update_switch_titledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
from handler.redis_handler import cache
from handler.redis_handler import async_cache
from logger.logger import log
from tasks.tasks import RemoteFilePullTask
from utils.iterators import batched
Expand All @@ -32,19 +32,19 @@ async def run(self, force: bool = False) -> None:
index_json = json.loads(content)
relevant_data = {k: v for k, v in index_json.items() if k and v}

with cache.pipeline() as pipe:
async with async_cache.pipeline() as pipe:
for data_batch in batched(relevant_data.items(), 2000):
titledb_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
await pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
for data_batch in batched(relevant_data.items(), 2000):
product_map = {
v["id"]: json.dumps(v)
for v in dict(data_batch).values()
if v.get("id")
}
if product_map:
pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
pipe.execute()
await pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
await pipe.execute()

log.info("Scheduled switch titledb update completed!")

Expand Down

0 comments on commit d9a30e6

Please sign in to comment.