Skip to content

Commit

Permalink
Merge pull request #991 from adamantike/feat/use-async-requests-for-sgdb
Browse files Browse the repository at this point in the history
feat: Use async requests to retrieve SteamGridDB covers
  • Loading branch information
zurdi15 authored Jul 10, 2024
2 parents c82ae97 + b61bb11 commit 7590fd3
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 42 deletions.
9 changes: 5 additions & 4 deletions backend/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ async def search_cover(
detail="No SteamGridDB enabled",
)

return [
SearchCoverSchema.model_validate(cover)
for cover in meta_sgdb_handler.get_details(search_term)
]
covers = await meta_sgdb_handler.get_details(
requests_client=request.app.requests_client, search_term=search_term
)

return [SearchCoverSchema.model_validate(cover) for cover in covers]
91 changes: 54 additions & 37 deletions backend/handler/metadata/sgdb_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import itertools
from typing import Any, Final

import requests
import httpx
from config import STEAMGRIDDB_API_KEY
from logger.logger import log

Expand Down Expand Up @@ -32,23 +33,41 @@ def __init__(self) -> None:
"Accept": "*/*",
}

def get_details(self, search_term: str) -> list[dict[str, Any]]:
search_response = requests.get(
f"{self.search_endpoint}/{search_term}",
headers=self.headers,
timeout=120,
async def get_details(
self, requests_client: httpx.AsyncClient, search_term: str
) -> list[dict[str, Any]]:
search_response = (
await requests_client.get(
f"{self.search_endpoint}/{search_term}",
headers=self.headers,
timeout=120,
)
).json()

if len(search_response["data"]) == 0:
log.warning(f"Could not find '{search_term}' on SteamGridDB")
return []

games = []
for game in search_response["data"]:
game_covers = []
for page in itertools.count(start=0):
covers_response = requests.get(
f"{self.grid_endpoint}/{game['id']}",
tasks = [
self._get_game_covers(
requests_client=requests_client,
game_id=game["id"],
game_name=game["name"],
)
for game in search_response["data"]
]
results = await asyncio.gather(*tasks)

return list(filter(None, results))

async def _get_game_covers(
self, requests_client: httpx.AsyncClient, game_id: int, game_name: str
) -> dict[str, Any] | None:
game_covers = []
for page in itertools.count(start=0):
covers_response = (
await requests_client.get(
f"{self.grid_endpoint}/{game_id}",
headers=self.headers,
timeout=120,
params={
Expand All @@ -57,32 +76,30 @@ def get_details(self, search_term: str) -> list[dict[str, Any]]:
"limit": SGDB_API_COVER_LIMIT,
"page": page,
},
).json()
page_covers = covers_response["data"]

game_covers.extend(page_covers)
if len(page_covers) < SGDB_API_COVER_LIMIT:
break

if game_covers:
games.append(
{
"name": game["name"],
"resources": [
{
"thumb": cover["thumb"],
"url": cover["url"],
"type": (
"animated"
if cover["thumb"].endswith(".webm")
else "static"
),
}
for cover in game_covers
],
}
)
return games
).json()
page_covers = covers_response["data"]

game_covers.extend(page_covers)
if len(page_covers) < SGDB_API_COVER_LIMIT:
break

if not game_covers:
return None

return {
"name": game_name,
"resources": [
{
"thumb": cover["thumb"],
"url": cover["url"],
"type": (
"animated" if cover["thumb"].endswith(".webm") else "static"
),
}
for cover in game_covers
],
}


sgdb_handler = SGDBBaseHandler()
13 changes: 12 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import re
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import alembic.config
import endpoints.sockets.scan # noqa
import httpx
import uvicorn
from config import DEV_HOST, DEV_PORT, DISABLE_CSRF_PROTECTION, ROMM_AUTH_SECRET_KEY
from endpoints import (
Expand Down Expand Up @@ -32,7 +35,15 @@
from starlette.middleware.authentication import AuthenticationMiddleware
from utils import get_version

app = FastAPI(title="RomM API", version=get_version())

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with httpx.AsyncClient() as client:
app.requests_client = client # type: ignore[attr-defined]
yield


app = FastAPI(title="RomM API", version=get_version(), lifespan=lifespan)

app.add_middleware(
CORSMiddleware,
Expand Down

0 comments on commit 7590fd3

Please sign in to comment.