From c501b8cc746b4504a9cdcb6d48b85138d6bda390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 20 May 2022 15:44:49 +0200 Subject: [PATCH 1/2] ENH: avoid unnecessary startup overhead for available optional dependencies, reduce boilerplate code in on_demand_import.py --- yt/testing.py | 7 +- yt/utilities/on_demand_imports.py | 773 ++++++++++-------------------- 2 files changed, 258 insertions(+), 522 deletions(-) diff --git a/yt/testing.py b/yt/testing.py index 40f6abb99f2..4827354bd1a 100644 --- a/yt/testing.py +++ b/yt/testing.py @@ -909,13 +909,10 @@ def requires_module_pytest(*module_names): from yt.utilities import on_demand_imports as odi def deco(func): - required_modules = { - name: getattr(odi, f"_{name}")._module for name in module_names - } missing = [ name - for name, mod in required_modules.items() - if isinstance(mod, odi.NotAModule) + for name in module_names + if not getattr(odi, f"_{name}").__is_available__ ] # note that order between these two decorators matters diff --git a/yt/utilities/on_demand_imports.py b/yt/utilities/on_demand_imports.py index e1eaa3c6823..68dac17a25b 100644 --- a/yt/utilities/on_demand_imports.py +++ b/yt/utilities/on_demand_imports.py @@ -1,4 +1,7 @@ import sys +from functools import wraps +from importlib.util import find_spec +from typing import Type class NotAModule: @@ -20,40 +23,47 @@ def __getattr__(self, item): def __call__(self, *args, **kwargs): raise self.error + def __repr__(self) -> str: + return f"NotAModule({self.pkg_name!r})" -class NotCartopy(NotAModule): - """ - A custom class to return error messages dependent on system installation - for cartopy imports. - """ - def __init__(self, pkg_name): - self.pkg_name = pkg_name - if any(s in sys.version for s in ("Anaconda", "Continuum")): - # the conda-based installs of cartopy don't have issues with the - # GEOS library, so the error message for users with conda can be - # relatively short.
Discussion related to this is in - # yt-project/yt#1966 - self.error = ImportError( - "This functionality requires the %s " - "package to be installed." % self.pkg_name - ) +class OnDemand: + _default_factory: Type[NotAModule] = NotAModule + + def __init_subclass__(cls): + if not cls.__name__.endswith("_imports"): + raise TypeError(f"class {cls}'s name needs to be suffixed '_imports'") + + def __new__(cls): + if cls is OnDemand: + raise TypeError("The OnDemand base class cannot be instanciated.") else: - self.error = ImportError( - "This functionality requires the %s " - "package to be installed. Try installing proj4 and " - "geos with your package manager and building shapely " - "and cartopy from source with: \n \n " - "python -m pip install --no-binary :all: shapely cartopy \n \n" - "For further instruction please refer to the " - "yt documentation." % self.pkg_name - ) + return object.__new__(cls) + @property + def _name(self) -> str: + _name, _, _suffix = self.__class__.__name__.rpartition("_") + return _name + + @property + def __is_available__(self) -> bool: + # special protocol to support testing framework + return find_spec(self._name) is not None -class netCDF4_imports: - _name = "netCDF4" - _Dataset = None +def safe_import(func): + @property + @wraps(func) + def inner(self): + try: + return func(self) + except ImportError: + return self._default_factory(self._name) + + return inner + + +class netCDF4_imports(OnDemand): def __init__(self): # this ensures the import ordering between netcdf4 and h5py. 
If h5py is # imported first, can get file lock errors on some systems (including travis-ci) @@ -64,500 +74,247 @@ def __init__(self): except ImportError: pass - @property + @safe_import def Dataset(self): - if self._Dataset is None: - try: - from netCDF4 import Dataset - except ImportError: - Dataset = NotAModule(self._name) - self._Dataset = Dataset - return self._Dataset - - -_netCDF4 = netCDF4_imports() + from netCDF4 import Dataset + return Dataset -class astropy_imports: - _name = "astropy" - _pyfits = None - @property - def pyfits(self): - if self._pyfits is None: - try: - import astropy.io.fits as pyfits +_netCDF4 = netCDF4_imports() - self.log - except ImportError: - pyfits = NotAModule(self._name) - self._pyfits = pyfits - return self._pyfits - _pywcs = None +class astropy_imports(OnDemand): + @safe_import + def log(self): + from astropy import log - @property - def pywcs(self): - if self._pywcs is None: - try: - import astropy.wcs as pywcs + if log.exception_logging_enabled(): + log.disable_exception_logging() - self.log - except ImportError: - pywcs = NotAModule(self._name) - self._pywcs = pywcs - return self._pywcs + return log - _log = None + @safe_import + def pyfits(self): + from astropy.io import fits - @property - def log(self): - if self._log is None: - try: - from astropy import log + return fits - if log.exception_logging_enabled(): - log.disable_exception_logging() - except ImportError: - log = NotAModule(self._name) - self._log = log - return self._log + @safe_import + def pywcs(self): + import astropy.wcs as pywcs - _units = None + self.log + return pywcs - @property + @safe_import def units(self): - if self._units is None: - try: - from astropy import units - - self.log - except ImportError: - units = NotAModule(self._name) - self._units = units - return self._units + from astropy import units - _conv = None + self.log + return units - @property + @safe_import def conv(self): - if self._conv is None: - try: - import astropy.convolution as 
conv - - self.log - except ImportError: - conv = NotAModule(self._name) - self._conv = conv - return self._conv + import astropy.convolution as conv - _time = None + self.log + return conv - @property + @safe_import def time(self): - if self._time is None: - try: - import astropy.time as time - - self.log - except ImportError: - time = NotAModule(self._name) - self._time = time - return self._time + import astropy.time as time - _wcsaxes = None + self.log + return time - @property + @safe_import def wcsaxes(self): - if self._wcsaxes is None: - try: - import astropy.visualization.wcsaxes as wcsaxes + from astropy.visualization import wcsaxes - self.log - except ImportError: - wcsaxes = NotAModule(self._name) - self._wcsaxes = wcsaxes - return self._wcsaxes + self.log + return wcsaxes - _version = None - @property - def __version__(self): - if self._version is None: - try: - import astropy +_astropy = astropy_imports() - version = astropy.__version__ - except ImportError: - version = NotAModule(self._name) - self._version = version - return self._version +class NotCartopy(NotAModule): + """ + A custom class to return error messages dependent on system installation + for cartopy imports. + """ -_astropy = astropy_imports() + def __init__(self, pkg_name): + self.pkg_name = pkg_name + if any(s in sys.version for s in ("Anaconda", "Continuum")): + # the conda-based installs of cartopy don't have issues with the + # GEOS library, so the error message for users with conda can be + # relatively short. Discussion related to this is in + # yt-project/yt#1966 + self.error = ImportError( + "This functionality requires the %s " + "package to be installed." % self.pkg_name + ) + else: + self.error = ImportError( + "This functionality requires the %s " + "package to be installed. 
Try installing proj4 and " + "geos with your package manager and building shapely " + "and cartopy from source with: \n \n " + "python -m pip install --no-binary :all: shapely cartopy \n \n" + "For further instruction please refer to the " + "yt documentation." % self.pkg_name + ) -class cartopy_imports: - _name = "cartopy" +class cartopy_imports(OnDemand): - _crs = None + _default_factory = NotCartopy - @property + @safe_import def crs(self): - if self._crs is None: - try: - import cartopy.crs as crs - except ImportError: - crs = NotCartopy(self._name) - self._crs = crs - return self._crs - - _version = None - - @property - def __version__(self): - if self._version is None: - try: - import cartopy + import cartopy.crs as crs - version = cartopy.__version__ - except ImportError: - version = NotCartopy(self._name) - self._version = version - return self._version + return crs _cartopy = cartopy_imports() -class pooch_imports: - _name = "pooch" - _module = None - - def __init__(self): - try: - import pooch as myself - - self._module = myself - except ImportError: - self._module = NotAModule(self._name) - - def __getattr__(self, attr): - return getattr(self._module, attr) - +class pooch_imports(OnDemand): + @safe_import + def HTTPDownloader(self): + from pooch import HTTPDownloader -_pooch = pooch_imports() - - -class scipy_imports: - _name = "scipy" - _integrate = None + return HTTPDownloader - @property - def integrate(self): - if self._integrate is None: - try: - import scipy.integrate as integrate - except ImportError: - integrate = NotAModule(self._name) - self._integrate = integrate - return self._integrate + @safe_import + def utils(self): + from pooch import utils - _stats = None + return utils - @property - def stats(self): - if self._stats is None: - try: - import scipy.stats as stats - except ImportError: - stats = NotAModule(self._name) - self._stats = stats - return self._stats + @safe_import + def create(self): + from pooch import create - _optimize = 
None + return create - @property - def optimize(self): - if self._optimize is None: - try: - import scipy.optimize as optimize - except ImportError: - optimize = NotAModule(self._name) - self._optimize = optimize - return self._optimize - _interpolate = None - - @property - def interpolate(self): - if self._interpolate is None: - try: - import scipy.interpolate as interpolate - except ImportError: - interpolate = NotAModule(self._name) - self._interpolate = interpolate - return self._interpolate - - _special = None - - @property - def special(self): - if self._special is None: - try: - import scipy.special as special - except ImportError: - special = NotAModule(self._name) - self._special = special - return self._special +_pooch = pooch_imports() - _signal = None - @property +class scipy_imports(OnDemand): + @safe_import def signal(self): - if self._signal is None: - try: - import scipy.signal as signal - except ImportError: - signal = NotAModule(self._name) - self._signal = signal - return self._signal + from scipy import signal - _spatial = None + return signal - @property + @safe_import def spatial(self): - if self._spatial is None: - try: - import scipy.spatial as spatial - except ImportError: - spatial = NotAModule(self._name) - self._spatial = spatial - return self._spatial + from scipy import spatial - _ndimage = None + return spatial - @property + @safe_import def ndimage(self): - if self._ndimage is None: - try: - import scipy.ndimage as ndimage - except ImportError: - ndimage = NotAModule(self._name) - self._ndimage = ndimage - return self._ndimage + from scipy import ndimage + return ndimage -_scipy = scipy_imports() +_scipy = scipy_imports() -class h5py_imports: - _name = "h5py" - _err = None - _File = None - @property +class h5py_imports(OnDemand): + @safe_import def File(self): - if self._err: - raise self._err - if self._File is None: - try: - from h5py import File - except ImportError: - File = NotAModule(self._name) - self._File = File - return 
self._File - - _Group = None + from h5py import File - @property + return File + + @safe_import def Group(self): - if self._err: - raise self._err - if self._Group is None: - try: - from h5py import Group - except ImportError: - Group = NotAModule(self._name) - self._Group = Group - return self._Group - - _Dataset = None + from h5py import Group - @property + return Group + + @safe_import def Dataset(self): - if self._err: - raise self._err - if self._Dataset is None: - try: - from h5py import Dataset - except ImportError: - Dataset = NotAModule(self._name) - self._Dataset = Dataset - return self._Dataset - - ___version__ = None + from h5py import Dataset - @property - def __version__(self): - if self._err: - raise self._err - if self.___version__ is None: - try: - from h5py import __version__ - except ImportError: - __version__ = NotAModule(self._name) - self.___version__ = __version__ - return self.___version__ - - _get_config = None + return Dataset - @property + @safe_import def get_config(self): - if self._err: - raise self._err - if self._get_config is None: - try: - from h5py import get_config - except ImportError: - get_config = NotAModule(self._name) - self._get_config = get_config - return self._get_config - - _h5f = None + from h5py import get_config - @property + return get_config + + @safe_import def h5f(self): - if self._err: - raise self._err - if self._h5f is None: - try: - import h5py.h5f as h5f - except ImportError: - h5f = NotAModule(self._name) - self._h5f = h5f - return self._h5f - - _h5p = None + from h5py import h5f - @property + return h5f + + @safe_import def h5p(self): - if self._err: - raise self._err - if self._h5p is None: - try: - import h5py.h5p as h5p - except ImportError: - h5p = NotAModule(self._name) - self._h5p = h5p - return self._h5p - - _h5d = None + from h5py import h5p - @property + return h5p + + @safe_import def h5d(self): - if self._err: - raise self._err - if self._h5d is None: - try: - import h5py.h5d as h5d - except 
ImportError: - h5d = NotAModule(self._name) - self._h5d = h5d - return self._h5d - - _h5s = None + from h5py import h5d - @property + return h5d + + @safe_import def h5s(self): - if self._err: - raise self._err - if self._h5s is None: - try: - import h5py.h5s as h5s - except ImportError: - h5s = NotAModule(self._name) - self._h5s = h5s - return self._h5s - - _version = None + from h5py import h5s - @property - def version(self): - if self._err: - raise self._err - if self._version is None: - try: - import h5py.version as version - except ImportError: - version = NotAModule(self._name) - self._version = version - return self._version + return h5s _h5py = h5py_imports() -class nose_imports: - _name = "nose" - _run = None - - @property +class nose_imports(OnDemand): + @safe_import def run(self): - if self._run is None: - try: - from nose import run - except ImportError: - run = NotAModule(self._name) - self._run = run - return self._run + from nose import run + return run -_nose = nose_imports() +_nose = nose_imports() -class libconf_imports: - _name = "libconf" - _load = None - @property +class libconf_imports(OnDemand): + @safe_import def load(self): - if self._load is None: - try: - from libconf import load - except ImportError: - load = NotAModule(self._name) - self._load = load - return self._load + from libconf import load + return load -_libconf = libconf_imports() +_libconf = libconf_imports() -class yaml_imports: - _name = "yaml" - _load = None - _FullLoader = None - @property +class yaml_imports(OnDemand): + @safe_import def load(self): - if self._load is None: - try: - from yaml import load - except ImportError: - load = NotAModule(self._name) - self._load = load - return self._load + from yaml import load - @property + return load + + @safe_import def FullLoader(self): - if self._FullLoader is None: - try: - from yaml import FullLoader - except ImportError: - FullLoader = NotAModule(self._name) - self._FullLoader = FullLoader - return self._FullLoader + 
from yaml import FullLoader + + return FullLoader _yaml = yaml_imports() @@ -575,140 +332,122 @@ def __init__(self, pkg_name): self.error = ImportError(str % self.pkg_name) -class miniball_imports: - _name = "miniball" - _Miniball = None - - @property +class miniball_imports(OnDemand): + @safe_import def Miniball(self): - if self._Miniball is None: - try: - from miniball import Miniball - except ImportError: - Miniball = NotMiniball(self._name) - self._Miniball = Miniball - return self._Miniball + from miniball import Miniball + + return Miniball _miniball = miniball_imports() -class f90nml_imports: - _name = "f90nml" - _module = None +class f90nml_imports(OnDemand): + @safe_import + def read(self): + from f90nml import read - def __init__(self): - try: - import f90nml as myself + return read - self._module = myself - except ImportError: - self._module = NotAModule(self._name) + @safe_import + def Namelist(self): + from f90nml import Namelist - def __getattr__(self, attr): - return getattr(self._module, attr) + return Namelist _f90nml = f90nml_imports() -class requests_imports: - _name = "requests" - _module = None +class requests_imports(OnDemand): + @safe_import + def post(self): + from requests import post - def __init__(self): - try: - import requests as myself + return post - self._module = myself - except ImportError: - self._module = NotAModule(self._name) + @safe_import + def put(self): + from requests import put - def __getattr__(self, attr): - return getattr(self._module, attr) + return put + + @safe_import + def codes(self): + from requests import codes + + return codes + + @safe_import + def get(self): + from requests import get + + return get + + @safe_import + def exceptions(self): + from requests import exceptions + + return exceptions _requests = requests_imports() -class pandas_imports: - _name = "pandas" - _module = None +class pandas_imports(OnDemand): + @safe_import + def NA(self): + from pandas import NA - def __init__(self): - try: - import 
pandas as myself + return NA - self._module = myself - except ImportError: - self._module = NotAModule(self._name) + @safe_import + def DataFrame(self): + from pandas import DataFrame - def __getattr__(self, attr): - return getattr(self._module, attr) + return DataFrame + @safe_import + def concat(self): + from pandas import concat -_pandas = pandas_imports() + return concat -class firefly_imports: - _name = "firefly" - _data_reader = None - _server = None +_pandas = pandas_imports() - @property + +class Firefly_imports(OnDemand): + @safe_import def data_reader(self): - if self._data_reader is None: - try: - import Firefly.data_reader as data_reader - except ImportError: - data_reader = NotAModule(self._name) - self._data_reader = data_reader - return self._data_reader + import Firefly.data_reader as data_reader - @property - def server(self): - if self._server is None: - try: - import Firefly.server as server - except ImportError: - server = NotAModule(self._name) - self._server = server - return self._server + return data_reader + @safe_import + def server(self): + import Firefly.server as server -_firefly = firefly_imports() + return server -# Note: ratarmount may fail with an OSError on import if libfuse is missing -# In this case, we want the on-demand-import to fail _where_ ratarmount -# is being used, rather than at startup. -# We could catch the OSError and throw it again when we try to access -# ratarmount. Instead here, we delay as much as possible the actual import of -# the package which thus raises an exception where expected. -# -# Note 2: we need to store the imported module in __module, as _module plays -# a special role in on-demand-imports (e.g. used for testing purposes to know -# if the package has been installed). 
-class ratarmount_imports: - _name = "ratarmount" - __module = None +_firefly = Firefly_imports() - @property - def _module(self): - if self.__module is not None: - return self.__module - try: - import ratarmount as myself +# Note: ratarmount may fail with an OSError on import if libfuse is missing +class ratarmount_imports(OnDemand): + @safe_import + def TarMount(self): + from ratarmount import TarMount - self.__module = myself - except ImportError: - self.__module = NotAModule(self._name) + return TarMount - return self.__module + @safe_import + def fuse(self): + from ratarmount import fuse - def __getattr__(self, attr): - return getattr(self._module, attr) + return fuse _ratarmount = ratarmount_imports() From defb07d47df8ab3eb2b41f0090316c96ac95ecdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 20 May 2022 23:57:22 +0200 Subject: [PATCH 2/2] TST: add meta tests for on_demand_imports.py --- nose_unit.cfg | 2 +- tests/tests.yaml | 1 + yt/utilities/tests/test_on_demand_imports.py | 48 ++++++++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 yt/utilities/tests/test_on_demand_imports.py diff --git a/nose_unit.cfg b/nose_unit.cfg index 8256e7baf41..58124fbd08c 100644 --- a/nose_unit.cfg +++ b/nose_unit.cfg @@ -6,5 +6,5 @@ nologcapture=1 verbosity=2 where=yt with-timer=1 -ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py) 
+ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|test_on_demand_imports\.py) exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF diff --git a/tests/tests.yaml b/tests/tests.yaml index 6455995fd01..6028cb33506 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -198,6 +198,7 @@ other_tests: - "--ignore-files=test_registration.py" - "--ignore-files=test_invalid_origin.py" - "--ignore-files=test_load_archive\\.py" + - "--ignore-file=test_on_demand_imports\\.py" - "--ignore-files=test_outputs_pytest\\.py" - "--ignore-files=test_normal_plot_api\\.py" - "--ignore-file=test_file_sanitizer\\.py" diff --git a/yt/utilities/tests/test_on_demand_imports.py b/yt/utilities/tests/test_on_demand_imports.py new file mode 100644 index 00000000000..549ccfee145 --- /dev/null +++ b/yt/utilities/tests/test_on_demand_imports.py @@ -0,0 +1,48 @@ +import pytest + +from yt.utilities.on_demand_imports import OnDemand, safe_import + + +def test_access_available_module(): + class os_imports(OnDemand): + @safe_import + def path(self): + from os import path + + return path + + _os = os_imports() + + _os.path.join("eggs", "saussage") + + +def test_access_unavailable_module(): + class Bacon_imports(OnDemand): + @safe_import + def spam(self): + from Bacon import spam + + return spam + + _bacon = Bacon_imports() + with pytest.raises( + ImportError, + match=r"This functionality requires the Bacon package to be installed\.", + ): + _bacon.spam() + + +def test_class_invalidation(): + with pytest.raises( + TypeError, match="class .*'s name needs to be suffixed '_imports'" + ): + + class Bacon(OnDemand): + pass + + +def test_base_class_instanciation(): + with pytest.raises( +
TypeError, match="The OnDemand base class cannot be instanciated." + ): + OnDemand()