diff --git a/yt/utilities/on_demand_imports.py b/yt/utilities/on_demand_imports.py index 24de0206bf3..ba112a7a3ed 100644 --- a/yt/utilities/on_demand_imports.py +++ b/yt/utilities/on_demand_imports.py @@ -1,7 +1,6 @@ import sys -from functools import wraps -from importlib.util import find_spec -from typing import Type + +from packaging.version import Version class NotAModule: @@ -23,47 +22,40 @@ def __getattr__(self, item): def __call__(self, *args, **kwargs): raise self.error - def __repr__(self) -> str: - return f"NotAModule({self.pkg_name!r})" - -class OnDemand: - _default_factory: Type[NotAModule] = NotAModule - - def __init_subclass__(cls): - if not cls.__name__.endswith("_imports"): - raise TypeError(f"class {cls}'s name needs to be suffixed '_imports'") +class NotCartopy(NotAModule): + """ + A custom class to return error messages dependent on system installation + for cartopy imports. + """ - def __new__(cls): - if cls is OnDemand: - raise TypeError("The OnDemand base class cannot be instanciated.") + def __init__(self, pkg_name): + self.pkg_name = pkg_name + if any(s in sys.version for s in ("Anaconda", "Continuum")): + # the conda-based installs of cartopy don't have issues with the + # GEOS library, so the error message for users with conda can be + # relatively short. Discussion related to this is in + # yt-project/yt#1966 + self.error = ImportError( + "This functionality requires the %s " + "package to be installed." % self.pkg_name + ) else: - return object.__new__(cls) - - @property - def _name(self) -> str: - _name, _, _suffix = self.__class__.__name__.rpartition("_") - return _name - - @property - def __is_available__(self) -> bool: - # special protocol to support testing framework - return find_spec(self._name) is not None - - -def safe_import(func): - @property - @wraps(func) - def inner(self): - try: - return func(self) - except ImportError: - return self._default_factory(self._name) + self.error = ImportError( + "This functionality requires the %s " + "package to be installed. Try installing proj4 and " + "geos with your package manager and building shapely " + "and cartopy from source with: \n \n " + "python -m pip install --no-binary :all: shapely cartopy \n \n" + "For further instruction please refer to the " + "yt documentation." % self.pkg_name + ) - return inner +class netCDF4_imports: + _name = "netCDF4" + _Dataset = None -class netCDF4_imports(OnDemand): def __init__(self): # this ensures the import ordering between netcdf4 and h5py. If h5py is # imported first, can get file lock errors on some systems (including travis-ci) @@ -74,163 +66,189 @@ def __init__(self): except ImportError: pass - @safe_import + @property def Dataset(self): - from netCDF4 import Dataset - - return Dataset + if self._Dataset is None: + try: + from netCDF4 import Dataset + except ImportError: + Dataset = NotAModule(self._name) + self._Dataset = Dataset + return self._Dataset _netCDF4 = netCDF4_imports() -class astropy_imports(OnDemand): - @safe_import - def log(self): - from astropy import log - - if log.exception_logging_enabled(): - log.disable_exception_logging() - - return log +class astropy_imports: + _name = "astropy" + _pyfits = None - @safe_import + @property def pyfits(self): - from astropy.io import fits - - return fits - - @safe_import - def pywcs(self): - import astropy.wcs as pywcs - - self.log - return pywcs + if self._pyfits is None: + try: + import astropy.io.fits as pyfits - @safe_import - def units(self): - from astropy import units + self.log + except ImportError: + pyfits = NotAModule(self._name) + self._pyfits = pyfits + return self._pyfits - self.log - return units + _pywcs = None - @safe_import - def conv(self): - import astropy.convolution as conv - - self.log - return conv + @property + def pywcs(self): + if self._pywcs is None: + try: + import astropy.wcs as pywcs - @safe_import - def time(self): - import astropy.time as time + self.log + except ImportError: + pywcs = NotAModule(self._name) + self._pywcs = pywcs + return self._pywcs - self.log - return time + _log = None - @safe_import - def wcsaxes(self): - from astropy.visualization import wcsaxes + @property + def log(self): + if self._log is None: + try: + from astropy import log - self.log - return wcsaxes + if log.exception_logging_enabled(): + log.disable_exception_logging() + except ImportError: + log = NotAModule(self._name) + self._log = log + return self._log + _units = None -_astropy = astropy_imports() + @property + def units(self): + if self._units is None: + try: + from astropy import units + self.log + except ImportError: + units = NotAModule(self._name) + self._units = units + return self._units -class NotCartopy(NotAModule): - """ - A custom class to return error messages dependent on system installation - for cartopy imports. - """ + _conv = None - def __init__(self, pkg_name): - self.pkg_name = pkg_name - if any(s in sys.version for s in ("Anaconda", "Continuum")): - # the conda-based installs of cartopy don't have issues with the - # GEOS library, so the error message for users with conda can be - # relatively short. Discussion related to this is in - # yt-project/yt#1966 - self.error = ImportError( - "This functionality requires the %s " - "package to be installed." % self.pkg_name - ) - else: - self.error = ImportError( - "This functionality requires the %s " - "package to be installed. Try installing proj4 and " - "geos with your package manager and building shapely " - "and cartopy from source with: \n \n " - "python -m pip install --no-binary :all: shapely cartopy \n \n" - "For further instruction please refer to the " - "yt documentation." % self.pkg_name - ) + @property + def conv(self): + if self._conv is None: + try: + import astropy.convolution as conv + self.log + except ImportError: + conv = NotAModule(self._name) + self._conv = conv + return self._conv -class cartopy_imports(OnDemand): + _time = None - _default_factory = NotCartopy + @property + def time(self): + if self._time is None: + try: + import astropy.time as time - @safe_import - def crs(self): - import cartopy.crs as crs + self.log + except ImportError: + time = NotAModule(self._name) + self._time = time + return self._time - return crs + _wcsaxes = None + @property + def wcsaxes(self): + if self._wcsaxes is None: + try: + import astropy.visualization.wcsaxes as wcsaxes -_cartopy = cartopy_imports() + self.log + except ImportError: + wcsaxes = NotAModule(self._name) + self._wcsaxes = wcsaxes + return self._wcsaxes + _version = None -class pooch_imports(OnDemand): - @safe_import - def HTTPDownloader(self): - from pooch import HTTPDownloader + @property + def __version__(self): + if self._version is None: + try: + import astropy - return HTTPDownloader + version = astropy.__version__ + except ImportError: + version = NotAModule(self._name) + self._version = version + return self._version - @safe_import - def utils(self): - from pooch import utils - return utils +_astropy = astropy_imports() - @safe_import - def create(self): - from pooch import create - return create +class cartopy_imports: + _name = "cartopy" + _crs = None -_pooch = pooch_imports() + @property + def crs(self): + if self._crs is None: + try: + import cartopy.crs as crs + except ImportError: + crs = NotCartopy(self._name) + self._crs = crs + return self._crs + _version = None -class pyart_imports(OnDemand): - @safe_import - def io(self): - from pyart import io + @property + def __version__(self): + if self._version is None: + try: + import cartopy - return io + version = cartopy.__version__ + except ImportError: + version = NotCartopy(self._name) + self._version = version + return self._version - @safe_import - def map(self): - from pyart import map - return map +_cartopy = cartopy_imports() -_pyart = pyart_imports() +class pooch_imports: + _name = "pooch" + _module = None + def __init__(self): + try: + import pooch as myself -class xarray_imports(OnDemand): - @safe_import - def open_dataset(self): - from xarray import open_dataset + self._module = myself + except ImportError: + self._module = NotAModule(self._name) - return open_dataset + def __getattr__(self, attr): + return getattr(self._module, attr) -_xarray = xarray_imports() +_pooch = pooch_imports() class scipy_imports: @@ -331,102 +349,232 @@ def ndimage(self): self._ndimage = ndimage return self._ndimage - @property - def __is_available__(self) -> bool: - # special protocol to support testing framework - return find_spec(self._name) is not None - _scipy = scipy_imports() -class h5py_imports(OnDemand): - @safe_import +class h5py_imports: + _name = "h5py" + _err = None + + def __init__(self): + try: + import h5py + + if Version(h5py.__version__) < Version("2.4.0"): + self._err = RuntimeError( + "yt requires h5py version 2.4.0 or newer, " + "please update h5py with e.g. `python -m pip install -U h5py` " + "and try again" + ) + except ImportError: + pass + super().__init__() + + _File = None + + @property def File(self): - from h5py import File + if self._err: + raise self._err + if self._File is None: + try: + from h5py import File + except ImportError: + File = NotAModule(self._name) + self._File = File + return self._File - return File + _Group = None - @safe_import + @property def Group(self): - from h5py import Group + if self._err: + raise self._err + if self._Group is None: + try: + from h5py import Group + except ImportError: + Group = NotAModule(self._name) + self._Group = Group + return self._Group - return Group + _Dataset = None - @safe_import + @property def Dataset(self): - from h5py import Dataset + if self._err: + raise self._err + if self._Dataset is None: + try: + from h5py import Dataset + except ImportError: + Dataset = NotAModule(self._name) + self._Dataset = Dataset + return self._Dataset + + ___version__ = None + + @property + def __version__(self): + if self._err: + raise self._err + if self.___version__ is None: + try: + from h5py import __version__ + except ImportError: + __version__ = NotAModule(self._name) + self.___version__ = __version__ + return self.___version__ - return Dataset + _get_config = None - @safe_import + @property def get_config(self): - from h5py import get_config + if self._err: + raise self._err + if self._get_config is None: + try: + from h5py import get_config + except ImportError: + get_config = NotAModule(self._name) + self._get_config = get_config + return self._get_config - return get_config + _h5f = None - @safe_import + @property def h5f(self): - from h5py import h5f + if self._err: + raise self._err + if self._h5f is None: + try: + import h5py.h5f as h5f + except ImportError: + h5f = NotAModule(self._name) + self._h5f = h5f + return self._h5f - return h5f + _h5p = None - @safe_import + @property def h5p(self): - from h5py import h5p + if self._err: + raise self._err + if self._h5p is None: + try: + import h5py.h5p as h5p + except ImportError: + h5p = NotAModule(self._name) + self._h5p = h5p + return self._h5p - return h5p + _h5d = None - @safe_import + @property def h5d(self): - from h5py import h5d + if self._err: + raise self._err + if self._h5d is None: + try: + import h5py.h5d as h5d + except ImportError: + h5d = NotAModule(self._name) + self._h5d = h5d + return self._h5d - return h5d + _h5s = None - @safe_import + @property def h5s(self): - from h5py import h5s + if self._err: + raise self._err + if self._h5s is None: + try: + import h5py.h5s as h5s + except ImportError: + h5s = NotAModule(self._name) + self._h5s = h5s + return self._h5s + + _version = None - return h5s + @property + def version(self): + if self._err: + raise self._err + if self._version is None: + try: + import h5py.version as version + except ImportError: + version = NotAModule(self._name) + self._version = version + return self._version _h5py = h5py_imports() -class nose_imports(OnDemand): - @safe_import - def run(self): - from nose import run +class nose_imports: + _name = "nose" + _run = None - return run + @property + def run(self): + if self._run is None: + try: + from nose import run + except ImportError: + run = NotAModule(self._name) + self._run = run + return self._run _nose = nose_imports() -class libconf_imports(OnDemand): - @safe_import - def load(self): - from libconf import load +class libconf_imports: + _name = "libconf" + _load = None - return load + @property + def load(self): + if self._load is None: + try: + from libconf import load + except ImportError: + load = NotAModule(self._name) + self._load = load + return self._load _libconf = libconf_imports() -class yaml_imports(OnDemand): - @safe_import - def load(self): - from yaml import load +class yaml_imports: + _name = "yaml" + _load = None + _FullLoader = None - return load + @property + def load(self): + if self._load is None: + try: + from yaml import load + except ImportError: + load = NotAModule(self._name) + self._load = load + return self._load - @safe_import + @property def FullLoader(self): - from yaml import FullLoader - - return FullLoader + if self._FullLoader is None: + try: + from yaml import FullLoader + except ImportError: + FullLoader = NotAModule(self._name) + self._FullLoader = FullLoader + return self._FullLoader _yaml = yaml_imports() @@ -444,122 +592,76 @@ def __init__(self, pkg_name): self.error = ImportError(str % self.pkg_name) -class miniball_imports(OnDemand): - @safe_import - def Miniball(self): - from miniball import Miniball +class miniball_imports: + _name = "miniball" + _Miniball = None - return Miniball + @property + def Miniball(self): + if self._Miniball is None: + try: + from miniball import Miniball + except ImportError: + Miniball = NotMiniball(self._name) + self._Miniball = Miniball + return self._Miniball _miniball = miniball_imports() -class f90nml_imports(OnDemand): - @safe_import - def read(self): - from f90nml import read +class f90nml_imports: + _name = "f90nml" + _module = None - return read + def __init__(self): + try: + import f90nml as myself - @safe_import - def Namelist(self): - from f90nml import Namelist + self._module = myself + except ImportError: + self._module = NotAModule(self._name) - return Namelist + def __getattr__(self, attr): + return getattr(self._module, attr) _f90nml = f90nml_imports() -class requests_imports(OnDemand): - @safe_import - def post(self): - from requests import post - - return post +class requests_imports: + _name = "requests" + _module = None - @safe_import - def put(self): - from requests import put - - return put - - @safe_import - def codes(self): - from requests import codes - - return codes - - @safe_import - def get(self): - from requests import get - - return get + def __init__(self): + try: + import requests as myself - @safe_import - def exceptions(self): - from requests import exceptions + self._module = myself + except ImportError: + self._module = NotAModule(self._name) - return exceptions + def __getattr__(self, attr): + return getattr(self._module, attr) _requests = requests_imports() -class pandas_imports(OnDemand): - @safe_import - def NA(self): - from pandas import NA - - return NA +class pandas_imports: + _name = "pandas" + _module = None - @safe_import - def DataFrame(self): - from pandas import DataFrame - - return DataFrame + def __init__(self): + try: + import pandas as myself - @safe_import - def concat(self): - from pandas import concat + self._module = myself + except ImportError: + self._module = NotAModule(self._name) - return concat + def __getattr__(self, attr): + return getattr(self._module, attr) _pandas = pandas_imports() - - -class Firefly_imports(OnDemand): - @safe_import - def data_reader(self): - import Firefly.data_reader as data_reader - - return data_reader - - @safe_import - def server(self): - import Firefly.server as server - - return server - - -_firefly = Firefly_imports() - - -# Note: ratarmount may fail with an OSError on import if libfuse is missing -class ratarmount_imports(OnDemand): - @safe_import - def TarMount(self): - from ratarmount import TarMount - - return TarMount - - @safe_import - def fuse(self): - from ratarmount import fuse - - return fuse - - -_ratarmount = ratarmount_imports()