From 6941e32dbb9114109e935ab190115ca6d47c9a27 Mon Sep 17 00:00:00 2001 From: Isac Byeonghoon Yoo Date: Thu, 8 Apr 2021 13:48:27 +0900 Subject: [PATCH] Utilize __init_subclass__ instead of inspect on Locale mapping to improve type checking (#920) * Utilize __init_subclass__ * Add Unit Test for duplicated locale name Co-authored-by: Jad Chaar --- arrow/locales.py | 28 +++++++++++----------------- tests/conftest.py | 2 +- tests/test_locales.py | 10 ++++++++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/arrow/locales.py b/arrow/locales.py index 8ceba33e2..90ca2b5c7 100644 --- a/arrow/locales.py +++ b/arrow/locales.py @@ -1,4 +1,3 @@ -import inspect import sys from math import trunc from typing import ( @@ -48,6 +47,9 @@ ] +_locale_map: Dict[str, Type["Locale"]] = dict() + + def get_locale(name: str) -> "Locale": """Returns an appropriate :class:`Locale ` corresponding to an input locale name. @@ -56,7 +58,7 @@ def get_locale(name: str) -> "Locale": """ - locale_cls = _locales.get(name.lower()) + locale_cls = _locale_map.get(name.lower()) if locale_cls is None: raise ValueError(f"Unsupported locale {name!r}.") @@ -121,6 +123,13 @@ class Locale: _month_name_to_ordinal: Optional[Dict[str, int]] + def __init_subclass__(cls, **kwargs: Any) -> None: + for locale_name in cls.names: + if locale_name in _locale_map: + raise LookupError(f"Duplicated locale name: {locale_name}") + + _locale_map[locale_name] = cls + def __init__(self) -> None: self._month_name_to_ordinal = None @@ -3409,18 +3418,6 @@ class MarathiLocale(Locale): day_abbreviations = ["", "सोम", "मंगळ", "बुध", "गुरु", "शुक्र", "शनि", "रवि"] -def _map_locales() -> Dict[str, Type[Locale]]: - - locales: Dict[str, Type[Locale]] = {} - - for _, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass): - if issubclass(cls, Locale): # pragma: no branch - for name in cls.names: - locales[name.lower()] = cls - - return locales - - class CatalanLocale(Locale): names = ["ca", "ca_es", "ca_ad", "ca_fr", "ca_it"] past = "Fa {0}" @@ -4644,6 +4641,3 @@ class LithuanianLocale(Locale): "še", "se", ] - - -_locales: Dict[str, Type[Locale]] = _map_locales() diff --git a/tests/conftest.py b/tests/conftest.py index 4043bc3b1..5d5b9980e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,7 +49,7 @@ def arrow_factory(request): @pytest.fixture(scope="class") def lang_locales(request): - request.cls.locales = locales._locales + request.cls.locales = locales._locale_map @pytest.fixture(scope="class") diff --git a/tests/test_locales.py b/tests/test_locales.py index 721d22bb1..6c984ecd6 100644 --- a/tests/test_locales.py +++ b/tests/test_locales.py @@ -33,6 +33,12 @@ def test_locale_validation(self): assert locale_cls.past is not None assert locale_cls.future is not None + def test_duplicated_locale_name(self): + with pytest.raises(LookupError): + + class Locale1(locales.Locale): + names = ["en_us"] + class TestModule: def test_get_locale(self, mocker): @@ -43,7 +49,7 @@ def test_get_locale(self, mocker): with pytest.raises(ValueError): arrow.locales.get_locale("locale_name") - cls_dict = arrow.locales._locales + cls_dict = arrow.locales._locale_map mocker.patch.dict(cls_dict, {"locale_name": mock_locale_cls}) result = arrow.locales.get_locale("locale_name") @@ -68,7 +74,7 @@ def test_get_locale_by_class_name(self, mocker): def test_locales(self): - assert len(locales._locales) > 0 + assert len(locales._locale_map) > 0 @pytest.mark.usefixtures("lang_locale")