Commit

Merge pull request #1010 from rommapp/misc/add-async-cache
misc: Add Redis async cache
adamantike authored Jul 22, 2024
2 parents dd4f9e4 + 6eb8e6a commit e1cec57
Showing 6 changed files with 54 additions and 31 deletions.
4 changes: 2 additions & 2 deletions backend/endpoints/tests/test_identity.py
@@ -2,7 +2,7 @@

import pytest
from fastapi.testclient import TestClient
-from handler.redis_handler import cache
+from handler.redis_handler import sync_cache
from main import app
from models.user import Role

@@ -12,7 +12,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
    yield
-    cache.flushall()
+    sync_cache.flushall()


def test_login_logout(admin_user):
24 changes: 12 additions & 12 deletions backend/handler/metadata/base_hander.py
@@ -4,7 +4,7 @@
import unicodedata
from typing import Final

-from handler.redis_handler import cache
+from handler.redis_handler import async_cache, sync_cache
from logger.logger import log
from tasks.update_switch_titledb import (
    SWITCH_PRODUCT_ID_KEY,
@@ -18,9 +18,9 @@ def conditionally_set_cache(
    index_key: str, filename: str, parent_dir: str = os.path.dirname(__file__)
) -> None:
    fixtures_path = os.path.join(parent_dir, "fixtures")
-    if not cache.exists(index_key):
+    if not sync_cache.exists(index_key):
        index_data = json.loads(open(os.path.join(fixtures_path, filename)).read())
-        with cache.pipeline() as pipe:
+        with sync_cache.pipeline() as pipe:
            for data_batch in batched(index_data.items(), 2000):
                data_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
                pipe.hset(index_key, mapping=data_map)
@@ -99,15 +99,15 @@ def _normalize_exact_match(name: str) -> str:

    async def _ps2_opl_format(self, match: re.Match[str], search_term: str) -> str:
        serial_code = match.group(1)
-        index_entry = cache.hget(PS2_OPL_KEY, serial_code)
+        index_entry = await async_cache.hget(PS2_OPL_KEY, serial_code)
        if index_entry:
            index_entry = json.loads(index_entry)
            search_term = index_entry["Name"] # type: ignore

        return search_term

    async def _sony_serial_format(self, index_key: str, serial_code: str) -> str | None:
-        index_entry = cache.hget(index_key, serial_code)
+        index_entry = await async_cache.hget(index_key, serial_code)
        if index_entry:
            index_entry = json.loads(index_entry)
            return index_entry["title"]
@@ -140,15 +140,15 @@ async def _switch_titledb_format(
    ) -> tuple[str, dict | None]:
        title_id = match.group(1)

-        if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
+        if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
            log.warning("Fetching the Switch titleID index file...")
            await update_switch_titledb_task.run(force=True)

-        if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
+        if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
            log.error("Could not fetch the Switch titleID index file")
            return search_term, None

-        index_entry = cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
+        index_entry = await async_cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
        if index_entry:
            index_entry = json.loads(index_entry)
            return index_entry["name"], index_entry
@@ -165,15 +165,15 @@ async def _switch_productid_format(
        product_id[-3] = "0"
        product_id = "".join(product_id)

-        if not cache.exists(SWITCH_PRODUCT_ID_KEY):
+        if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
            log.warning("Fetching the Switch productID index file...")
            await update_switch_titledb_task.run(force=True)

-        if not cache.exists(SWITCH_PRODUCT_ID_KEY):
+        if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
            log.error("Could not fetch the Switch productID index file")
            return search_term, None

-        index_entry = cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
+        index_entry = await async_cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
        if index_entry:
            index_entry = json.loads(index_entry)
            return index_entry["name"], index_entry
@@ -183,7 +183,7 @@ async def _switch_productid_format(
    async def _mame_format(self, search_term: str) -> str:
        from handler.filesystem import fs_rom_handler

-        index_entry = cache.hget(MAME_XML_KEY, search_term)
+        index_entry = await async_cache.hget(MAME_XML_KEY, search_term)
        if index_entry:
            index_entry = json.loads(index_entry)
            search_term = fs_rom_handler.get_file_name_with_no_tags(
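
Every lookup in this handler follows the same pattern: a helper that was already async used to call the blocking cache.hget(...), and now awaits async_cache.hget(...) so the event loop stays free while Redis responds. A minimal, self-contained sketch of that pattern, using redis.asyncio directly with a hypothetical PS2_OPL_KEY hash name and a local Redis server rather than the actual RomM handler class:

import asyncio
import json

from redis.asyncio import Redis

# Hypothetical stand-ins for the handler's constants and shared client.
PS2_OPL_KEY = "romm:ps2_opl_index"
async_cache = Redis(host="localhost", port=6379, db=0, decode_responses=True)


async def ps2_opl_lookup(serial_code: str, search_term: str) -> str:
    # Awaiting hget keeps the event loop free while Redis answers.
    index_entry = await async_cache.hget(PS2_OPL_KEY, serial_code)
    if index_entry:
        return json.loads(index_entry)["Name"]
    # Fall back to the original search term when the serial code is unknown.
    return search_term


if __name__ == "__main__":
    print(asyncio.run(ps2_opl_lookup("SLUS-20002", "unknown game")))
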
10 changes: 5 additions & 5 deletions backend/handler/metadata/igdb_handler.py
@@ -8,7 +8,7 @@
import requests
from config import IGDB_CLIENT_ID, IGDB_CLIENT_SECRET
from fastapi import HTTPException, status
-from handler.redis_handler import cache
+from handler.redis_handler import sync_cache
from logger.logger import log
from requests.exceptions import HTTPError, Timeout
from typing_extensions import TypedDict
@@ -592,8 +592,8 @@ def _update_twitch_token(self) -> str:
            return ""

        # Set token in redis to expire in <expires_in> seconds
-        cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
-        cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]
+        sync_cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
+        sync_cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]

        log.info("Twitch token fetched!")

@@ -608,8 +608,8 @@ def get_oauth_token(self) -> str:
            return ""

        # Fetch the token cache
-        token = cache.get("romm:twitch_token") # type: ignore[attr-defined]
-        token_expires_at = cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]
+        token = sync_cache.get("romm:twitch_token") # type: ignore[attr-defined]
+        token_expires_at = sync_cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]

        if not token or time.time() > float(token_expires_at or 0):
            log.warning("Twitch token invalid: fetching a new one...")
33 changes: 28 additions & 5 deletions backend/handler/redis_handler.py
@@ -3,7 +3,8 @@

from config import REDIS_DB, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT, REDIS_USERNAME
from logger.logger import log
-from redis import Redis, StrictRedis
+from redis import Redis
+from redis.asyncio import Redis as AsyncRedis
from rq import Queue


@@ -31,12 +32,12 @@ class QueuePrio(Enum):
low_prio_queue = Queue(name=QueuePrio.LOW.value, connection=redis_client)


-def __get_cache() -> StrictRedis:
+def __get_sync_cache() -> Redis:
    if "pytest" in sys.modules:
        # Only import fakeredis when running tests, as it is a test dependency.
-        from fakeredis import FakeStrictRedis
+        from fakeredis import FakeRedis

-        return FakeStrictRedis(version=7)
+        return FakeRedis(version=7)

    log.info(f"Connecting to redis in {sys.argv[0]}...")
    # A separate client that auto-decodes responses is needed
@@ -52,4 +53,26 @@ def __get_cache() -> StrictRedis:
    return client


-cache = __get_cache()
+def __get_async_cache() -> AsyncRedis:
+    if "pytest" in sys.modules:
+        # Only import fakeredis when running tests, as it is a test dependency.
+        from fakeredis import FakeAsyncRedis
+
+        return FakeAsyncRedis(version=7)
+
+    log.info(f"Connecting to redis in {sys.argv[0]}...")
+    # A separate client that auto-decodes responses is needed
+    client = AsyncRedis(
+        host=REDIS_HOST,
+        port=REDIS_PORT,
+        password=REDIS_PASSWORD,
+        username=REDIS_USERNAME,
+        db=REDIS_DB,
+        decode_responses=True,
+    )
+    log.info(f"Redis connection established in {sys.argv[0]}!")
+    return client
+
+
+sync_cache = __get_sync_cache()
+async_cache = __get_async_cache()
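
This module is the core of the change: it now exposes two module-level clients instead of one, sync_cache (the blocking redis.Redis) and async_cache (redis.asyncio.Redis), both decoding responses and both swapped for fakeredis equivalents under pytest. A rough usage sketch, assuming a Redis server on localhost and using illustrative helper names that are not part of this PR:

import asyncio

from redis import Redis
from redis.asyncio import Redis as AsyncRedis

# Illustrative clients mirroring sync_cache / async_cache; the real module
# builds them from REDIS_HOST, REDIS_PORT, and friends in config.
sync_cache = Redis(host="localhost", port=6379, db=0, decode_responses=True)
async_cache = AsyncRedis(host="localhost", port=6379, db=0, decode_responses=True)


def warm_up() -> None:
    # Blocking call: fine at import/startup time or in synchronous code paths.
    sync_cache.set("romm:example", "warm")


async def read_back() -> str | None:
    # Awaited call: fine inside FastAPI endpoints and other async code paths.
    return await async_cache.get("romm:example")


if __name__ == "__main__":
    warm_up()
    print(asyncio.run(read_back()))
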
4 changes: 2 additions & 2 deletions backend/models/firmware.py
@@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from handler.metadata.base_hander import conditionally_set_cache
-from handler.redis_handler import cache
+from handler.redis_handler import sync_cache
from models.base import BaseModel
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -55,7 +55,7 @@ def platform_name(self) -> str:

    @cached_property
    def is_verified(self) -> bool:
-        cache_entry = cache.hget(
+        cache_entry = sync_cache.hget(
            KNOWN_BIOS_KEY, f"{self.platform_slug}:{self.file_name}"
        )
        if cache_entry:
10 changes: 5 additions & 5 deletions backend/tasks/update_switch_titledb.py
@@ -5,7 +5,7 @@
    ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
    SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
-from handler.redis_handler import cache
+from handler.redis_handler import async_cache
from logger.logger import log
from tasks.tasks import RemoteFilePullTask
from utils.iterators import batched
@@ -32,19 +32,19 @@ async def run(self, force: bool = False) -> None:
        index_json = json.loads(content)
        relevant_data = {k: v for k, v in index_json.items() if k and v}

-        with cache.pipeline() as pipe:
+        async with async_cache.pipeline() as pipe:
            for data_batch in batched(relevant_data.items(), 2000):
                titledb_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
-                pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
+                await pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
            for data_batch in batched(relevant_data.items(), 2000):
                product_map = {
                    v["id"]: json.dumps(v)
                    for v in dict(data_batch).values()
                    if v.get("id")
                }
                if product_map:
-                    pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
-            pipe.execute()
+                    await pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
+            await pipe.execute()

        log.info("Scheduled switch titledb update completed!")

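The bulk load keeps the pipeline but switches to the asyncio variant: async with async_cache.pipeline() opens it, the hset calls queue work, and the single network round trip happens when execute() is awaited. A stripped-down sketch of the same batching idea, with an illustrative key name and function, assuming a local Redis server (the real task uses SWITCH_TITLEDB_INDEX_KEY, SWITCH_PRODUCT_ID_KEY, and the shared async_cache):

import asyncio
import json

from redis.asyncio import Redis as AsyncRedis

INDEX_KEY = "romm:example_index"  # illustrative key, not the task's real one
async_cache = AsyncRedis(host="localhost", port=6379, db=0, decode_responses=True)


async def bulk_load(entries: dict[str, dict]) -> None:
    async with async_cache.pipeline() as pipe:
        # Queued on the pipeline; sent to Redis in one round trip on execute().
        await pipe.hset(INDEX_KEY, mapping={k: json.dumps(v) for k, v in entries.items()})
        await pipe.execute()


if __name__ == "__main__":
    asyncio.run(bulk_load({"0100ABCD": {"name": "Example Title"}}))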
