Skip to content

Commit

Permalink
Improve cache decorator to keep track of decorated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Jul 18, 2023
1 parent 42aa025 commit e7ee303
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
19 changes: 14 additions & 5 deletions pint/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _make_key(
return _HashedSeq(key)


def lru_cache():
class lru_cache:
"""Least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
Expand All @@ -115,15 +115,23 @@ def lru_cache():
# The internals of the lru_cache are encapsulated for thread safety and
# to allow the implementation to change (including a possible C version).

def decorating_function(user_function: Callable[..., T]) -> Callable[..., T]:
def __init__(self, user_function):
wrapper = _lru_cache_wrapper(user_function)
return update_wrapper(wrapper, user_function)
self.wrapped_fun = update_wrapper(wrapper, user_function)

return decorating_function
def __set_name__(self, owner, name):
cache_methods = getattr(owner, "cache_methods", None)
if cache_methods is None:
owner._cache_methods = cache_methods = []

cache_methods.append(self.wrapped_fun)

setattr(owner, name, self.wrapped_fun)


def _lru_cache_wrapper(user_function: Callable[..., T]) -> Callable[..., T]:
# Constants shared by all lru cache instances:

sentinel = object() # unique object used to signal cache misses
make_key = _make_key # build a key from the function arguments

Expand Down Expand Up @@ -184,6 +192,7 @@ def cache_stack_pop(self: UnitRegistry) -> dict[Any, T]:
wrapper.cache_clear = cache_clear
wrapper.cache_stack_push = cache_stack_push
wrapper.cache_stack_pop = cache_stack_pop

return wrapper


Expand All @@ -194,4 +203,4 @@ def cache_stack_pop(self: UnitRegistry) -> dict[Any, T]:

def cache(user_function: Callable[..., Any], /):
'Simple lightweight unbounded cache. Sometimes called "memoize".'
return lru_cache()(user_function)
return lru_cache(user_function)
5 changes: 5 additions & 0 deletions pint/facets/context/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def _register_definition_adders(self) -> None:
super()._register_definition_adders()
self._register_adder(ContextDefinition, self.add_context)

def _clear_cache(self):
super()._clear_cache()
for func in self._cache_methods:
func.cache_clear(self)

def add_context(self, context: Union[objects.Context, ContextDefinition]) -> None:
"""Add a context object to the registry.
Expand Down
7 changes: 6 additions & 1 deletion pint/facets/plain/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,11 @@ def load_definitions(

return parsed_project

def _clear_cache(self):
self._cache = RegistryCache()
for func in self._cache_methods:
func.cache_clear(self)

def _build_cache(self, loaded_files=None) -> None:
"""Build a cache of dimensionality and plain units."""

Expand All @@ -591,7 +596,7 @@ def _build_cache(self, loaded_files=None) -> None:
diskcache.save(self._cache, loaded_files, "build_cache")
return

self._cache = RegistryCache()
self._clear_cache()

deps: dict[str, set[str]] = {
name: set(definition.reference.keys()) if definition.reference else set()
Expand Down
28 changes: 28 additions & 0 deletions pint/testsuite/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@ def calculated_value(self, value):
return self.value * value + 0.5


def test_cache_methods():
assert hasattr(Demo, "_cache_methods")
assert isinstance(Demo._cache_methods, list)
assert Demo._cache_methods == [
Demo.calculated_value,
]

assert hasattr(DerivedDemo, "_cache_methods")
assert isinstance(DerivedDemo._cache_methods, list)
assert DerivedDemo._cache_methods == [
DerivedDemo.calculated_value,
]

d = Demo(1)
assert hasattr(d, "_cache_methods")
assert isinstance(d._cache_methods, list)
assert d._cache_methods == [
Demo.calculated_value,
]

d = DerivedDemo(2)
assert hasattr(d, "_cache_methods")
assert isinstance(d._cache_methods, list)
assert d._cache_methods == [
DerivedDemo.calculated_value,
]


def test_cache_clear():
demo = Demo(2)

Expand Down

0 comments on commit e7ee303

Please sign in to comment.