From e7ee303bbb593b7a4a923574e268bab52727ecf9 Mon Sep 17 00:00:00 2001 From: Hernan Grecco Date: Mon, 17 Jul 2023 22:33:38 -0300 Subject: [PATCH] Improve cache decorator to keep track of decorated functions --- pint/cache.py | 19 ++++++++++++++----- pint/facets/context/registry.py | 5 +++++ pint/facets/plain/registry.py | 7 ++++++- pint/testsuite/test_cache.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/pint/cache.py b/pint/cache.py index 2959355ea..d66fd34c0 100644 --- a/pint/cache.py +++ b/pint/cache.py @@ -89,7 +89,7 @@ def _make_key( return _HashedSeq(key) -def lru_cache(): +class lru_cache: """Least-recently-used cache decorator. If *maxsize* is set to None, the LRU features are disabled and the cache @@ -115,15 +115,23 @@ def lru_cache(): # The internals of the lru_cache are encapsulated for thread safety and # to allow the implementation to change (including a possible C version). - def decorating_function(user_function: Callable[..., T]) -> Callable[..., T]: + def __init__(self, user_function): wrapper = _lru_cache_wrapper(user_function) - return update_wrapper(wrapper, user_function) + self.wrapped_fun = update_wrapper(wrapper, user_function) - return decorating_function + def __set_name__(self, owner, name): + cache_methods = getattr(owner, "cache_methods", None) + if cache_methods is None: + owner._cache_methods = cache_methods = [] + + cache_methods.append(self.wrapped_fun) + + setattr(owner, name, self.wrapped_fun) def _lru_cache_wrapper(user_function: Callable[..., T]) -> Callable[..., T]: # Constants shared by all lru cache instances: + sentinel = object() # unique object used to signal cache misses make_key = _make_key # build a key from the function arguments @@ -184,6 +192,7 @@ def cache_stack_pop(self: UnitRegistry) -> dict[Any, T]: wrapper.cache_clear = cache_clear wrapper.cache_stack_push = cache_stack_push wrapper.cache_stack_pop = cache_stack_pop + return wrapper @@ -194,4 +203,4 @@ def cache_stack_pop(self: UnitRegistry) -> dict[Any, T]: def cache(user_function: Callable[..., Any], /): 'Simple lightweight unbounded cache. Sometimes called "memoize".' - return lru_cache()(user_function) + return lru_cache(user_function) diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index 02b77ce03..46ff566ff 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -80,6 +80,11 @@ def _register_definition_adders(self) -> None: super()._register_definition_adders() self._register_adder(ContextDefinition, self.add_context) + def _clear_cache(self): + super()._clear_cache() + for func in self._cache_methods: + func.cache_clear(self) + def add_context(self, context: Union[objects.Context, ContextDefinition]) -> None: """Add a context object to the registry. diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py index 684052844..6ad059698 100644 --- a/pint/facets/plain/registry.py +++ b/pint/facets/plain/registry.py @@ -580,6 +580,11 @@ def load_definitions( return parsed_project + def _clear_cache(self): + self._cache = RegistryCache() + for func in self._cache_methods: + func.cache_clear(self) + def _build_cache(self, loaded_files=None) -> None: """Build a cache of dimensionality and plain units.""" @@ -591,7 +596,7 @@ def _build_cache(self, loaded_files=None) -> None: diskcache.save(self._cache, loaded_files, "build_cache") return - self._cache = RegistryCache() + self._clear_cache() deps: dict[str, set[str]] = { name: set(definition.reference.keys()) if definition.reference else set() diff --git a/pint/testsuite/test_cache.py b/pint/testsuite/test_cache.py index 8c0a65341..3f7b6913d 100644 --- a/pint/testsuite/test_cache.py +++ b/pint/testsuite/test_cache.py @@ -22,6 +22,34 @@ def calculated_value(self, value): return self.value * value + 0.5 +def test_cache_methods(): + assert hasattr(Demo, "_cache_methods") + assert isinstance(Demo._cache_methods, list) + assert Demo._cache_methods == [ + Demo.calculated_value, + ] + + assert hasattr(DerivedDemo, "_cache_methods") + assert isinstance(DerivedDemo._cache_methods, list) + assert DerivedDemo._cache_methods == [ + DerivedDemo.calculated_value, + ] + + d = Demo(1) + assert hasattr(d, "_cache_methods") + assert isinstance(d._cache_methods, list) + assert d._cache_methods == [ + Demo.calculated_value, + ] + + d = DerivedDemo(2) + assert hasattr(d, "_cache_methods") + assert isinstance(d._cache_methods, list) + assert d._cache_methods == [ + DerivedDemo.calculated_value, + ] + + def test_cache_clear(): demo = Demo(2)