From 0905ccc1d8424f2dafbd4e3617a0e1f1b5da6359 Mon Sep 17 00:00:00 2001 From: Nursultan Kudaibergenov Date: Thu, 25 Jul 2024 13:03:22 +0600 Subject: [PATCH] implemented granular cache invalidation based on tags --- examples/in_memory/main.py | 74 +++++++++++++++++++-- fastapi_cache/decorator.py | 56 +++++++++++++++- fastapi_cache/tag_provider.py | 93 ++++++++++++++++++++++++++ fastapi_cache/types.py | 22 ++++++- tests/test_cache_invalidation.py | 110 +++++++++++++++++++++++++++++++ 5 files changed, 348 insertions(+), 7 deletions(-) create mode 100644 fastapi_cache/tag_provider.py create mode 100644 tests/test_cache_invalidation.py diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index f4de1a0..17b3938 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -1,13 +1,15 @@ # pyright: reportGeneralTypeIssues=false +from collections import defaultdict from contextlib import asynccontextmanager -from typing import AsyncIterator, Dict, Optional +from typing import Annotated, AsyncIterator, Dict, List, Optional import pendulum import uvicorn -from fastapi import FastAPI +from fastapi import Body, FastAPI, Query from fastapi_cache import FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend -from fastapi_cache.decorator import cache +from fastapi_cache.decorator import cache, cache_invalidator +from fastapi_cache.tag_provider import TagProvider from pydantic import BaseModel from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -65,7 +67,7 @@ async def get_kwargs(name: str): @app.get("/sync-me") -@cache(namespace="test") # pyright: ignore[reportArgumentType] +@cache(namespace="test") # pyright: ignore[reportArgumentType] def sync_me(): # as per the fastapi docs, this sync function is wrapped in a thread, # thereby converted to async. fastapi-cache does the same. 
# Note: cache-invalidation examples.
# In-memory "storage" shared by the endpoints below: file id -> list of keys.
files = defaultdict(list)
files.update({1: [1, 2, 3], 2: [4, 5, 6], 3: [100]})

# All file endpoints share one provider, so they read and invalidate the same tags.
FileTagProvider = TagProvider("file")


def _file_item_from_path(data, method_args, method_kwargs):
    # The per-file response is a bare list without an "id", so the tagged
    # item is built from the ``file_id`` path argument instead.
    return [{"id": method_kwargs["file_id"]}]


def _file_id_from_kwargs(resp, kwargs):
    # The DELETE response carries no identifier; take it from the call kwargs.
    return kwargs["file_id"]


# Note: providing tags for future granular cache invalidation
@app.get("/files")
@cache(expire=10, tag_provider=FileTagProvider)
async def get_files(file_id_in: Annotated[Optional[List[int]], Query()] = None):
    selected = []
    for file_id, keys in files.items():
        # An absent/empty filter means "return every file".
        if not file_id_in or file_id in file_id_in:
            selected.append({"id": file_id, "value": keys})
    return selected


# Note: keys are retrieved by file_id here, so this entry must also be
# invalidated when the file changes
@app.get("/files/{file_id:int}")
@cache(
    expire=10,
    tag_provider=FileTagProvider,
    items_provider=_file_item_from_path,
)
async def get_file_keys(file_id: int):
    # Membership test first: indexing the defaultdict would create the entry.
    if file_id not in files:
        return Response("file id not found")
    return files[file_id]


# Note: the default invalidator is enough here, because the response contains :id:
@app.patch("/files/{file_id:int}")
@cache_invalidator(tag_provider=FileTagProvider)
async def edit_file(file_id: int, items: Annotated[List[int], Body(embed=True)]):
    files[file_id] = items
    payload = {"id": file_id, "value": files[file_id]}
    return payload


# Note: a custom :invalidator: is needed here, because the response carries no identifier
@app.delete("/files/{file_id:int}")
@cache_invalidator(tag_provider=FileTagProvider, invalidator=_file_id_from_kwargs)
async def delete_file(file_id: int):
    # Removing an unknown id is a no-op.
    files.pop(file_id, None)
# --- fastapi_cache/decorator.py ---------------------------------------------


def default_invalidator(response: dict, kwargs: dict) -> str:
    """Return the id to invalidate, read from the ``"id"`` key of the response.

    :param response: decoded endpoint response; must contain ``"id"``
    :param kwargs: keyword arguments of the endpoint call (unused here)
    :return: the identifier formatted as a string
    :raises KeyError: if the response has no ``"id"`` key
    """
    return f"{response['id']}"


def cache_invalidator(
    tag_provider: "TagProvider",
    invalidator: Callable[[dict, dict], str] = default_invalidator,
):
    """Decorate an endpoint so that a successful call invalidates tagged cache entries.

    After the wrapped coroutine returns, its result is round-tripped through
    the configured coder, handed to ``invalidator`` together with the call's
    keyword arguments, and the resulting object id is invalidated through
    ``tag_provider``.

    Invalidation is best effort: any failure is logged as a warning and the
    endpoint's own result is still returned.

    :param tag_provider: provider whose tags are invalidated
    :param invalidator: callable ``(decoded_response, kwargs) -> object_id``
    :return: a decorator for async endpoint functions
    """

    def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
            response = await func(*args, **kwargs)

            try:
                # The encode/decode round trip lives inside the ``try`` so a
                # response the coder cannot handle does not fail a request
                # whose handler already succeeded.
                coder = FastAPICache.get_coder()
                data = coder.decode(coder.encode(response))
                object_id = invalidator(data, kwargs)
                await tag_provider.invalidate(object_id)
            except Exception as e:
                logger.warning(
                    f"Exception occurred while invalidating cache: {e}",
                    exc_info=True,
                )

            return response

        return inner

    return wrapper


# --- fastapi_cache/tag_provider.py ------------------------------------------


class TagProvider:
    """Maintain "tag" entries that map an object id to the cache keys built from it.

    A tag is itself a cache entry, stored under
    ``<prefix>:invalidation:<object_type>:<object_id>``, whose value is the
    encoded list of cache keys that contain that object. Invalidating an
    object clears every listed key and then the tag itself.
    """

    def __init__(
        self,
        object_type: str,
        object_id_provider: Optional[Callable[[dict], str]] = None,
    ) -> None:
        """
        :param object_type: name used as the type segment of every tag
        :param object_id_provider: extracts an object id from an item;
            defaults to reading ``item["id"]``
        """
        self.object_type = object_type
        self.object_id_provider = object_id_provider or self.default_object_id_provider

    @staticmethod
    def default_object_id_provider(item: dict) -> str:
        """Return ``item["id"]`` formatted as a string."""
        return f"{item['id']}"

    @staticmethod
    def default_items_provider(
        data: Union[dict, list],
        method_args: Optional[tuple] = None,
        method_kwargs: Optional[dict] = None,
    ) -> List[dict]:
        """Return the items contained in a decoded response.

        A list is returned unchanged. A single dict is wrapped in a list,
        since iterating it directly would yield its keys instead of the item.
        ``None`` yields no items.
        """
        if data is None:
            return []
        if isinstance(data, dict):
            return [data]
        return data

    @staticmethod
    def _call_items_provider(
        provider: Callable[..., list],
        data: Union[dict, list, None],
        method_args: Optional[tuple],
        method_kwargs: Optional[dict],
    ) -> list:
        """Call ``provider`` with as many arguments as its signature accepts.

        Both ``provider(data)`` and
        ``provider(data, method_args, method_kwargs)`` are supported. When
        the signature cannot be inspected the three-argument form is used.
        """
        import inspect  # local import: used only by this helper

        try:
            params = list(inspect.signature(provider).parameters.values())
        except (TypeError, ValueError):
            return provider(data, method_args, method_kwargs)

        takes_varargs = any(
            p.kind is inspect.Parameter.VAR_POSITIONAL for p in params
        )
        positional = [
            p
            for p in params
            if p.kind
            in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
        ]
        if takes_varargs or len(positional) >= 3:
            return provider(data, method_args, method_kwargs)
        return provider(data)

    def _resolve_object_id(self, item: Optional[dict], object_id: Optional[str]) -> str:
        """Return ``object_id`` as a string, deriving it from ``item`` when it is ``None``.

        The check is ``is None`` rather than truthiness so that falsy ids
        such as ``0`` are used as given.
        """
        if object_id is None:
            object_id = self.object_id_provider(item)
        return str(object_id)

    def get_tag(self, item: Optional[dict] = None, object_id: Optional[str] = None) -> str:
        """Build the tag key for an item or for an explicit object id.

        :param item: item whose id is extracted with ``object_id_provider``
        :param object_id: explicit id; takes precedence over ``item``
        :return: ``<prefix>:invalidation:<object_type>:<object_id>``
        """
        prefix = FastAPICache.get_prefix()
        resolved_id = self._resolve_object_id(item, object_id)
        return f"{prefix}:invalidation:{self.object_type}:{resolved_id}"

    @staticmethod
    async def _append_value(key: str, parent_key: str, expire: int):
        """Record ``parent_key`` in the list stored under tag ``key``.

        The tag is always rewritten, which also applies ``expire`` to it
        again, but ``parent_key`` is listed at most once.
        """
        backend = FastAPICache.get_backend()
        coder = FastAPICache.get_coder()
        # NOTE(review): this get/modify/set sequence is not atomic; concurrent
        # writers to the same tag may overwrite each other's update.
        stored = await backend.get(key)
        members = coder.decode(stored) if stored else []
        if parent_key not in members:
            members.append(parent_key)
        await backend.set(key=key, value=coder.encode(members), expire=expire)

    async def provide(
        self,
        data: Union[dict, list],
        parent_key: str,
        expire: Optional[int] = None,
        items_provider: Optional["ItemsProviderProtocol"] = None,
        method_args: Optional[tuple] = None,
        method_kwargs: Optional[dict] = None,
    ) -> None:
        """
        Provides tags for endpoint.

        Every item obtained from ``data`` gets ``parent_key`` recorded under
        its tag, so invalidating the item later clears ``parent_key``.

        :param data: decoded response of the cached endpoint
        :param parent_key: cache key under which the response was stored
        :param expire: lifetime of the tag entries; falls back to the
            configured default expiry when falsy
        :param items_provider: extracts the items from ``data``; defaults to
            ``default_items_provider``
        :param method_args: positional arguments of the endpoint call
        :param method_kwargs: keyword arguments of the endpoint call
        """
        provider = items_provider or self.default_items_provider
        items = list(
            self._call_items_provider(provider, data, method_args, method_kwargs)
        )
        if not items:
            return

        ttl = expire or FastAPICache.get_expire()
        # De-duplicate tags so one tag is never updated twice concurrently.
        tags = dict.fromkeys(self.get_tag(item) for item in items)
        await asyncio.gather(
            *(
                self._append_value(key=tag, parent_key=parent_key, expire=ttl)
                for tag in tags
            )
        )

    async def invalidate(self, object_id: str) -> None:
        """
        Invalidate tags with given object_id

        Clears every cache key recorded under the object's tag, then the tag
        itself. Does nothing when the tag is missing or empty.

        :param object_id: object_id to invalidate
        """
        backend = FastAPICache.get_backend()
        coder = FastAPICache.get_coder()
        tag = self.get_tag(object_id=object_id)

        value = await backend.get(tag)
        if not value:
            return

        keys: List[str] = coder.decode(value)
        # ``dict.fromkeys`` drops duplicates while keeping order; the tag
        # entry is cleared together with the keys it lists.
        targets = dict.fromkeys([*keys, tag])
        await asyncio.gather(*(backend.clear(key=key) for key in targets))
# --- fastapi_cache/types.py --------------------------------------------------


class _ItemsProviderProtocol(Protocol):
    """Items provider that only needs the decoded response."""

    def __call__(self, data: Union[dict, list]) -> list[dict]:
        ...


class _ItemsProviderProtocolWithParams(Protocol):
    """Items provider that also receives the endpoint's call arguments."""

    def __call__(
        self,
        data: Union[dict, list],
        method_args: Optional[tuple] = None,
        method_kwargs: Optional[dict] = None,
    ) -> list[dict]:
        ...


# Either call shape is accepted. ``Union`` matches the rest of this module and
# avoids applying ``|`` to classes at import time.
ItemsProviderProtocol: TypeAlias = Union[
    _ItemsProviderProtocol, _ItemsProviderProtocolWithParams
]


# --- tests/test_cache_invalidation.py ----------------------------------------


@pytest.fixture(autouse=True)
def _init_cache() -> Generator[Any, Any, None]:  # pyright: ignore[reportUnusedFunction]
    """Initialise the cache with an in-memory backend for each test and reset it afterwards."""
    FastAPICache.init(InMemoryBackend())
    yield
    FastAPICache.reset()


@pytest.fixture()
def test_client():
    """Yield a ``TestClient`` bound to the in-memory example application."""
    # Imported here, as in the original, so the example module is loaded only
    # when a test requests the client.
    from examples.in_memory.main import app

    with TestClient(app=app) as client:
        yield client


@pytest.fixture(autouse=True)
def set_initial_value_for_files():
    """Replace the example's ``files`` storage with fresh data for each test."""
    files = defaultdict(
        list,
        {
            1: [1, 2, 3],
            2: [4, 5, 6],
            3: [100],
        },
    )
    with patch("examples.in_memory.main.files", files):
        yield


class TestCacheInvalidation:
    def test_cache_invalidation(self, test_client) -> None:
        """Editing a file invalidates the cached listing that contains it."""
        response = test_client.get("/files")

        assert response.headers.get("X-FastAPI-Cache") == "MISS"
        assert response.json() == [
            {'id': 1, 'value': [1, 2, 3]},
            {'id': 2, 'value': [4, 5, 6]},
            {'id': 3, 'value': [100]},
        ]

        response = test_client.get("/files")

        assert response.headers.get("X-FastAPI-Cache") == "HIT"
        assert response.json() == [
            {'id': 1, 'value': [1, 2, 3]},
            {'id': 2, 'value': [4, 5, 6]},
            {'id': 3, 'value': [100]},
        ]

        # changing file and invalidating first request
        change_response = test_client.patch("/files/1", json={"items": [42]})
        assert change_response.status_code == 200, change_response.json()
        assert change_response.json() == {
            "id": 1,
            "value": [42],
        }

        # this was invalidated
        response = test_client.get("/files")
        assert response.headers.get("X-FastAPI-Cache") == "MISS"
        assert response.json() == [
            {'id': 1, 'value': [42]},
            {'id': 2, 'value': [4, 5, 6]},
            {'id': 3, 'value': [100]},
        ]

    def test_partial_invalidation(self, test_client) -> None:
        """Editing a file invalidates only the cached responses that include it."""
        response = test_client.get("/files", params={"file_id_in": [1, 2]})

        assert response.json() == [
            {'id': 1, 'value': [1, 2, 3]},
            {'id': 2, 'value': [4, 5, 6]},
        ]

        response = test_client.get("/files", params={"file_id_in": [2, 3]})

        assert response.json() == [
            {'id': 2, 'value': [4, 5, 6]},
            {'id': 3, 'value': [100]},
        ]

        # changing file 1 must invalidate only the first request, not the second
        change_response = test_client.patch("/files/1", json={"items": [42]})
        # A failed edit would otherwise show up only as a confusing HIT/MISS
        # mismatch further down.
        assert change_response.status_code == 200, change_response.json()

        # this was invalidated
        response = test_client.get("/files", params={"file_id_in": [1, 2]})
        assert response.headers.get("X-FastAPI-Cache") == "MISS"
        assert response.json() == [
            {'id': 1, 'value': [42]},
            {'id': 2, 'value': [4, 5, 6]},
        ]

        # this is not
        response = test_client.get("/files", params={"file_id_in": [2, 3]})
        assert response.headers.get("X-FastAPI-Cache") == "HIT"
        assert response.json() == [
            {'id': 2, 'value': [4, 5, 6]},
            {'id': 3, 'value': [100]},
        ]