Skip to content

Commit

Permalink
convert dict entry to dataclass & typing (#239)
Browse files Browse the repository at this point in the history
* convert `dict` entry to dataclass

* linting

* Dict

* lint

* if key not in cache

* fixing

* mock

* mock

* mock

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused TypedDict import from config.py

---------

Co-authored-by: Shay Palachy-Affek <shaypal5@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 17, 2024
1 parent 9dd89b0 commit 731b66e
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 140 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
- name: Unit tests (local)
if: matrix.backend == 'local'
run: pytest -m "not mongo"
run: pytest -m "not mongo" --cov=cachier --cov-report=term --cov-report=xml:cov.xml

- name: Setup docker (missing on MacOS)
if: runner.os == 'macOS' && matrix.backend == 'db'
Expand All @@ -77,7 +77,7 @@ jobs:
docker ps -a
- name: Unit tests (DB)
if: matrix.backend == 'db'
run: pytest -m "mongo"
run: pytest -m "mongo" --cov=cachier --cov-report=term --cov-report=xml:cov.xml
- name: Speed eval
run: python tests/speed_eval.py

Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ norecursedirs = [
]
addopts = [
"--color=yes",
"--cov=cachier",
"--cov-report=term",
"--cov-report=xml:cov.xml",
"-r a",
"-v",
"-s",
Expand Down
14 changes: 13 additions & 1 deletion src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import hashlib
import os
import pickle
import threading
from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Optional, Union
from typing import Any, Optional, Union

from ._types import Backend, HashFunc, Mongetter

Expand Down Expand Up @@ -38,6 +39,17 @@ class Params:
_global_params = Params()


@dataclass
class CacheEntry:
"""Data class for cache entries."""

value: Any
time: datetime
stale: bool
being_calculated: bool
condition: Optional[threading.Condition] = None


def _update_with_defaults(
param, name: str, func_kwargs: Optional[dict] = None
):
Expand Down
14 changes: 7 additions & 7 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,17 @@ def func_wrapper(*args, **kwds):
_print("No entry found. No current calc. Calling like a boss.")
return _calc_entry(core, key, func, args, kwds)
_print("Entry found.")
if _allow_none or entry.get("value", None) is not None:
if _allow_none or entry.value is not None:
_print("Cached result found.")
now = datetime.datetime.now()
if now - entry["time"] <= _stale_after:
if now - entry.time <= _stale_after:
_print("And it is fresh!")
return entry["value"]
return entry.value
_print("But it is stale... :(")
if entry["being_calculated"]:
if entry.being_calculated:
if _next_time:
_print("Returning stale.")
return entry["value"] # return stale val
return entry.value # return stale val
_print("Already calc. Waiting on change.")
try:
return core.wait_on_entry_calc(key)
Expand All @@ -283,10 +283,10 @@ def func_wrapper(*args, **kwds):
)
finally:
core.mark_entry_not_calculated(key)
return entry["value"]
return entry.value
_print("Calling decorated function and waiting")
return _calc_entry(core, key, func, args, kwds)
if entry["being_calculated"]:
if entry.being_calculated:
_print("No value but being calculated. Waiting.")
try:
return core.wait_on_entry_calc(key)
Expand Down
20 changes: 10 additions & 10 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import abc # for the _BaseCore abstract base class
import inspect
import threading
from typing import Callable
from typing import Callable, Optional, Tuple

from .._types import HashFunc
from ..config import _update_with_defaults
from ..config import CacheEntry, _update_with_defaults


class RecalculationNeeded(Exception):
Expand Down Expand Up @@ -51,7 +51,7 @@ def get_key(self, args, kwds):
"""Return a unique key based on the arguments provided."""
return self.hash_func(args, kwds)

def get_entry(self, args, kwds):
def get_entry(self, args, kwds) -> Tuple[str, Optional[CacheEntry]]:
"""Get entry based on given arguments.
Return the result mapped to the given arguments in this core's cache,
Expand All @@ -76,7 +76,7 @@ def check_calc_timeout(self, time_spent):
raise RecalculationNeeded()

@abc.abstractmethod
def get_entry_by_key(self, key):
def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
"""Get entry based on given key.
Return the result mapped to the given key in this core's cache, if such
Expand All @@ -85,25 +85,25 @@ def get_entry_by_key(self, key):
"""

@abc.abstractmethod
def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res):
"""Map the given result to the given key in this core's cache."""

@abc.abstractmethod
def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
"""Mark the entry mapped by the given key as being calculated."""

@abc.abstractmethod
def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
"""Mark the entry mapped by the given key as not being calculated."""

@abc.abstractmethod
def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> None:
"""Wait on the entry with keys being calculated and returns result."""

@abc.abstractmethod
def clear_cache(self):
def clear_cache(self) -> None:
"""Clear the cache of this core."""

@abc.abstractmethod
def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
"""Mark all entries in this cache as not being calculated."""
76 changes: 40 additions & 36 deletions src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import threading
from datetime import datetime
from typing import Any, Optional, Tuple

from .._types import HashFunc
from ..config import CacheEntry
from .base import _BaseCore, _get_func_str


Expand All @@ -14,76 +16,78 @@ def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
super().__init__(hash_func, wait_for_calc_timeout)
self.cache = {}

def _hash_func_key(self, key):
def _hash_func_key(self, key: str) -> str:
return f"{_get_func_str(self.func)}:{key}"

def get_entry_by_key(self, key, reload=False):
def get_entry_by_key(
self, key: str, reload=False
) -> Tuple[str, Optional[CacheEntry]]:
with self.lock:
return key, self.cache.get(self._hash_func_key(key), None)

def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res: Any) -> None:
with self.lock:
try:
# we need to retain the existing condition so that
# mark_entry_not_calculated can notify all possibly-waiting
# threads about it
cond = self.cache[self._hash_func_key(key)]["condition"]
cond = self.cache[self._hash_func_key(key)].condition
except KeyError: # pragma: no cover
cond = None
self.cache[self._hash_func_key(key)] = {
"value": func_res,
"time": datetime.now(),
"stale": False,
"being_calculated": False,
"condition": cond,
}
self.cache[self._hash_func_key(key)] = CacheEntry(
value=func_res,
time=datetime.now(),
stale=False,
being_calculated=False,
condition=cond,
)

def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
with self.lock:
condition = threading.Condition()
# condition.acquire()
try:
self.cache[self._hash_func_key(key)]["being_calculated"] = True
self.cache[self._hash_func_key(key)]["condition"] = condition
self.cache[self._hash_func_key(key)].being_calculated = True
self.cache[self._hash_func_key(key)].condition = condition
except KeyError:
self.cache[self._hash_func_key(key)] = {
"value": None,
"time": datetime.now(),
"stale": False,
"being_calculated": True,
"condition": condition,
}
self.cache[self._hash_func_key(key)] = CacheEntry(
value=None,
time=datetime.now(),
stale=False,
being_calculated=True,
condition=condition,
)

def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
with self.lock:
try:
entry = self.cache[self._hash_func_key(key)]
except KeyError: # pragma: no cover
return # that's ok, we don't need an entry in that case
entry["being_calculated"] = False
cond = entry["condition"]
entry.being_calculated = False
cond = entry.condition
if cond:
cond.acquire()
cond.notify_all()
cond.release()
entry["condition"] = None
entry.condition = None

def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> Any:
with self.lock: # pragma: no cover
entry = self.cache[self._hash_func_key(key)]
if not entry["being_calculated"]:
return entry["value"]
entry["condition"].acquire()
entry["condition"].wait()
entry["condition"].release()
return self.cache[self._hash_func_key(key)]["value"]
if not entry.being_calculated:
return entry.value
entry.condition.acquire()
entry.condition.wait()
entry.condition.release()
return self.cache[self._hash_func_key(key)].value

def clear_cache(self):
def clear_cache(self) -> None:
with self.lock:
self.cache.clear()

def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
with self.lock:
for entry in self.cache.values():
entry["being_calculated"] = False
entry["condition"] = None
entry.being_calculated = False
entry.condition = None
44 changes: 23 additions & 21 deletions src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import warnings # to warn if pymongo is missing
from contextlib import suppress
from datetime import datetime
from typing import Any, Optional, Tuple

from .._types import HashFunc, Mongetter
from ..config import CacheEntry

with suppress(ImportError):
from bson.binary import Binary # to save binary data to mongodb
Expand Down Expand Up @@ -65,29 +67,29 @@ def __init__(
def _func_str(self) -> str:
return _get_func_str(self.func)

def get_entry_by_key(self, key):
def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
res = self.mongo_collection.find_one(
{"func": self._func_str, "key": key}
)
if not res:
return key, None
try:
entry = {
"value": pickle.loads(res["value"]), # noqa: S301
"time": res.get("time", None),
"stale": res.get("stale", False),
"being_calculated": res.get("being_calculated", False),
}
entry = CacheEntry(
value=pickle.loads(res["value"]), # noqa: S301
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
except KeyError:
entry = {
"value": None,
"time": res.get("time", None),
"stale": res.get("stale", False),
"being_calculated": res.get("being_calculated", False),
}
entry = CacheEntry(
value=None,
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
return key, entry

def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res: Any) -> None:
thebytes = pickle.dumps(func_res)
self.mongo_collection.update_one(
filter={"func": self._func_str, "key": key},
Expand All @@ -104,14 +106,14 @@ def set_entry(self, key, func_res):
upsert=True,
)

def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
self.mongo_collection.update_one(
filter={"func": self._func_str, "key": key},
update={"$set": {"being_calculated": True}},
upsert=True,
)

def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
with suppress(OperationFailure): # don't care in this case
self.mongo_collection.update_one(
filter={
Expand All @@ -122,22 +124,22 @@ def mark_entry_not_calculated(self, key):
upsert=False, # should not insert in this case
)

def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> Any:
time_spent = 0
while True:
time.sleep(MONGO_SLEEP_DURATION_IN_SEC)
time_spent += MONGO_SLEEP_DURATION_IN_SEC
key, entry = self.get_entry_by_key(key)
if entry is None:
raise RecalculationNeeded()
if not entry["being_calculated"]:
return entry["value"]
if not entry.being_calculated:
return entry.value
self.check_calc_timeout(time_spent)

def clear_cache(self):
def clear_cache(self) -> None:
self.mongo_collection.delete_many(filter={"func": self._func_str})

def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
self.mongo_collection.update_many(
filter={
"func": self._func_str,
Expand Down
Loading

0 comments on commit 731b66e

Please sign in to comment.