Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented granular cache invalidation based on tags #435

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions examples/in_memory/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# pyright: reportGeneralTypeIssues=false
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Optional
from typing import Annotated, AsyncIterator, Dict, List, Optional

import pendulum
import uvicorn
from fastapi import FastAPI
from fastapi import Body, FastAPI, Query
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
from fastapi_cache.decorator import cache, cache_invalidator
from fastapi_cache.tag_provider import TagProvider
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand Down Expand Up @@ -65,7 +67,7 @@ async def get_kwargs(name: str):


@app.get("/sync-me")
@cache(namespace="test") # pyright: ignore[reportArgumentType]
@cache(namespace="test") # pyright: ignore[reportArgumentType]
def sync_me():
# as per the fastapi docs, this sync function is wrapped in a thread,
# thereby converted to async. fastapi-cache does the same.
Expand Down Expand Up @@ -115,8 +117,10 @@ async def uncached_put():
put_ret = put_ret + 1
return {"value": put_ret}


put_ret2 = 0


@app.get("/cached_put")
@cache(namespace="test", expire=5)
async def cached_put():
Expand All @@ -126,7 +130,7 @@ async def cached_put():


@app.get("/namespaced_injection")
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python") # pyright: ignore[reportArgumentType]
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python") # pyright: ignore[reportArgumentType]
def namespaced_injection(
__fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17
) -> Dict[str, int]:
Expand All @@ -136,5 +140,65 @@ def namespaced_injection(
}


# Note: examples with cache invalidation
files = defaultdict(
list,
{
1: [1, 2, 3],
2: [4, 5, 6],
3: [100],
},
)

FileTagProvider = TagProvider("file")


# Note: providing tags for future granular cache invalidation
@app.get("/files")
@cache(expire=10, tag_provider=FileTagProvider)
async def get_files(file_id_in: Annotated[Optional[List[int]], Query()] = None):
return [
{"id": k, "value": v}
for k, v in files.items()
if (True if not file_id_in else k in file_id_in)
]


# Note: here we're retrieving keys by file_id, so we also need to invalidate this, when file changes
@app.get("/files/{file_id:int}")
@cache(
expire=10,
tag_provider=FileTagProvider,
items_provider=lambda data, method_args, method_kwargs: [
{"id": method_kwargs["file_id"]}
],
)
async def get_file_keys(file_id: int):
if file_id in files:
return files[file_id]
return Response("file id not found")


# Note: here we can use default invalidator, because in response we have :id:
@app.patch("/files/{file_id:int}")
@cache_invalidator(tag_provider=FileTagProvider)
async def edit_file(file_id: int, items: Annotated[List[int], Body(embed=True)]):
files[file_id] = items
return {
"id": file_id,
"value": files[file_id]
}


# Note: here we need to use custom :invalidator: because we don't have access to identifier in response
@app.delete("/files/{file_id:int}")
@cache_invalidator(
tag_provider=FileTagProvider, invalidator=lambda resp, kwargs: kwargs["file_id"]
)
async def delete_file(file_id: int):
if file_id in files:
del files[file_id]


if __name__ == "__main__":
uvicorn.run("main:app", reload=True)
56 changes: 55 additions & 1 deletion fastapi_cache/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
from fastapi_cache.tag_provider import TagProvider
from fastapi_cache.types import ItemsProviderProtocol, KeyBuilder

logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -90,6 +91,8 @@ def cache(
key_builder: Optional[KeyBuilder] = None,
namespace: str = "",
injected_dependency_namespace: str = "__fastapi_cache",
tag_provider: Optional[TagProvider] = None,
items_provider: Optional[ItemsProviderProtocol] = None,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
"""
cache all function
Expand All @@ -98,6 +101,8 @@ def cache(
:param expire:
:param coder:
:param key_builder:
:param tag_provider:
:param items_provider:

:return:
"""
Expand Down Expand Up @@ -194,6 +199,22 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
f"Error setting cache key '{cache_key}' in backend:",
exc_info=True,
)
else:
if tag_provider:
decoded = coder.decode(to_cache)
try:
await tag_provider.provide(
data=decoded,
parent_key=cache_key,
expire=expire,
items_provider=items_provider,
method_args=args,
method_kwargs=kwargs,
)
except Exception:
logger.warning(
f"Error while providing tags: {cache_key}", exc_info=True
)

if response:
response.headers.update(
Expand Down Expand Up @@ -229,3 +250,36 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
return inner

return wrapper


def default_invalidator(response: dict, kwargs: dict) -> str:
return f"{response['id']}"


def cache_invalidator(
tag_provider: TagProvider,
invalidator: Callable[[dict, dict], str] = default_invalidator,
):
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
coder = FastAPICache.get_coder()

response = await func(*args, **kwargs)

data = coder.decode(coder.encode(response))

try:
object_id = invalidator(data, kwargs)
await tag_provider.invalidate(object_id)
except Exception as e:
logger.warning(
f"Exception occurred while invalidating cache: {e}",
exc_info=True,
)

return response

return inner

return wrapper
93 changes: 93 additions & 0 deletions fastapi_cache/tag_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
from typing import Callable, List, Optional, Union

from fastapi_cache import FastAPICache
from fastapi_cache.types import ItemsProviderProtocol


class TagProvider:
def __init__(
self,
object_type: str,
object_id_provider: Optional[Callable[[dict], str]] = None,
) -> None:
self.object_type = object_type
self.object_id_provider = object_id_provider or self.default_object_id_provider

@staticmethod
def default_object_id_provider(item: dict) -> str:
return f"{item['id']}"

@staticmethod
def default_items_provider(
data: Union[dict, list],
method_args: Optional[tuple] = None,
method_kwargs: Optional[dict] = None,
) -> list[dict]:
return data

def get_tag(self, item: Optional[dict] = None, object_id: Optional[str] = None) -> str:
prefix = FastAPICache.get_prefix()
object_id = object_id or self.object_id_provider(item)
return f"{prefix}:invalidation:{self.object_type}:{object_id}"

@staticmethod
async def _append_value(key: str, parent_key: str, expire: int):
backend = FastAPICache.get_backend()
coder = FastAPICache.get_coder()
value = await backend.get(key)
if value:
value = coder.decode(value)
value.append(parent_key)
else:
value = [parent_key]
await backend.set(key=key, value=coder.encode(value), expire=expire)

async def provide(
self,
data: Union[dict, list],
parent_key: str,
expire: Optional[int] = None,
items_provider: Optional[ItemsProviderProtocol] = None,
method_args: Optional[tuple] = None,
method_kwargs: Optional[dict] = None,
) -> None:
"""
Provides tags for endpoint.

:param data:
:param parent_key:
:param expire:
:param items_provider:
:param method_args:
:param method_kwargs:
"""
provider = items_provider or self.default_items_provider
tasks = [
self._append_value(
key=self.get_tag(item),
parent_key=parent_key,
expire=expire or FastAPICache.get_expire(),
)
for item in provider(data, method_args, method_kwargs)
]
await asyncio.gather(*tasks)

async def invalidate(self, object_id: str) -> None:
"""
Invalidate tags with given object_id

:param object_id: object_id to invalidate
"""
backend = FastAPICache.get_backend()
coder = FastAPICache.get_coder()
tag = self.get_tag(object_id=object_id)

value = await backend.get(tag)
if not value:
return

keys: List[str] = coder.decode(value)
tasks = [backend.clear(key=key) for key in keys]
tasks.append(backend.clear(key=tag))
await asyncio.gather(*tasks)
22 changes: 21 additions & 1 deletion fastapi_cache/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, TypeAlias, Union

from starlette.requests import Request
from starlette.responses import Response
Expand Down Expand Up @@ -38,3 +38,23 @@ async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> Non
@abc.abstractmethod
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError


class _ItemsProviderProtocol(Protocol):
def __call__(self, data: Union[dict, list]):
pass


class _ItemsProviderProtocolWithParams(Protocol):
def __call__(
self,
data: Union[dict, list],
method_args: Optional[tuple] = None,
method_kwargs: Optional[dict] = None,
) -> list[dict]:
pass


ItemsProviderProtocol: TypeAlias = (
_ItemsProviderProtocol | _ItemsProviderProtocolWithParams
)
Loading