Skip to content

Commit

Permalink
Apply cache decorator to get_dimesionality
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Jul 18, 2023
1 parent e7ee303 commit b50c9c1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pint/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,6 @@ def cache_stack_pop(self: UnitRegistry) -> dict[Any, T]:
################################################################################


def cache(user_function: Callable[..., Any], /):
def cache(user_function: T, /) -> T:
'Simple lightweight unbounded cache. Sometimes called "memoize".'
return lru_cache(user_function)
1 change: 0 additions & 1 deletion pint/facets/context/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class ContextCacheOverlay:
def __init__(self, registry_cache: RegistryCache) -> None:
self.dimensional_equivalents = registry_cache.dimensional_equivalents
self.root_units = {}
self.dimensionality = registry_cache.dimensionality
self.parse_unit = registry_cache.parse_unit


Expand Down
19 changes: 2 additions & 17 deletions pint/facets/plain/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ def __init__(self) -> None:
# TODO: this description is not right.
self.root_units: dict[UnitsContainer, tuple[Scalar, UnitsContainer]] = {}

#: Maps dimensionality (UnitsContainer) to Units (UnitsContainer)
self.dimensionality: dict[UnitsContainer, UnitsContainer] = {}

#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
self.parse_unit: dict[str, UnitsContainer] = {}

Expand All @@ -140,7 +137,6 @@ def __eq__(self, other: Any):
attrs = (
"dimensional_equivalents",
"root_units",
"dimensionality",
"parse_unit",
)
return all(getattr(self, attr) == getattr(other, attr) for attr in attrs)
Expand Down Expand Up @@ -620,7 +616,6 @@ def _build_cache(self, loaded_files=None) -> None:
di = self._get_dimensionality(uc)

self._cache.root_units[uc] = bu
self._cache.dimensionality[uc] = di

if not prefix:
dimeq_set = self._cache.dimensional_equivalents.setdefault(
Expand Down Expand Up @@ -702,20 +697,12 @@ def get_dimensionality(self, input_units: UnitLike) -> UnitsContainer:

return self._get_dimensionality(input_units)

def _get_dimensionality(
self, input_units: Optional[UnitsContainer]
) -> UnitsContainer:
@methodcache
def _get_dimensionality(self, input_units: UnitsContainer) -> UnitsContainer:
"""Convert a UnitsContainer to plain dimensions."""
if not input_units:
return self.UnitsContainer()

cache = self._cache.dimensionality

try:
return cache[input_units]
except KeyError:
pass

accumulator: dict[str, int] = defaultdict(int)
self._get_dimensionality_recurse(input_units, 1, accumulator)

Expand All @@ -724,8 +711,6 @@ def _get_dimensionality(

dims = self.UnitsContainer({k: v for k, v in accumulator.items() if v != 0})

cache[input_units] = dims

return dims

def _get_dimensionality_recurse(
Expand Down

0 comments on commit b50c9c1

Please sign in to comment.