From 74dffffbfea2ba9aea18ce194fe868f2cb00907d Mon Sep 17 00:00:00 2001 From: aurghs <35919497+aurghs@users.noreply.github.com> Date: Thu, 10 Dec 2020 10:56:12 +0100 Subject: [PATCH] Backends entrypoints (#4577) * Define _get_backends_cls function inside apiv2.py to read engines from plugins.py * Read open_backends_dataset_* from entrypoints. * Add backend entrypoints in setup.cfg * Pass apiv2.py isort and black formatting tests. * add dependencies * add backend entrypoints and check on conflicts * black * removed global variable ENGINES add class for entrypoints * black isort * add detect_engines in __all__ init.py * removed entrypoints in py36-bare-minimum.yml and py36-min-all-deps.yml * add entrypoints in IGNORE_DEPS * Plugins test (#20) - replace entrypoints with pkg_resources - add tests * fix typo Co-authored-by: keewis * style Co-authored-by: keewis * style * Code style * Code style * fix: updated plugins.ENGINES with plugins.list_engines() * fix * One more correctness fix of the latest merge from master Co-authored-by: TheRed86 Co-authored-by: keewis Co-authored-by: Alessandro Amici --- setup.cfg | 6 ++ xarray/backends/__init__.py | 2 + xarray/backends/api.py | 2 +- xarray/backends/apiv2.py | 25 ++++---- xarray/backends/cfgrib_.py | 4 ++ xarray/backends/h5netcdf_.py | 6 +- xarray/backends/plugins.py | 107 +++++++++++++++++++++++++---------- xarray/backends/zarr.py | 4 ++ xarray/tests/test_plugins.py | 93 ++++++++++++++++++++++++++++++ 9 files changed, 204 insertions(+), 45 deletions(-) create mode 100644 xarray/tests/test_plugins.py diff --git a/setup.cfg b/setup.cfg index 2a7762fb9f5..0a82f80ebd6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,6 +81,12 @@ setup_requires = setuptools >= 38.4 setuptools_scm +[options.entry_points] +xarray.backends = + zarr = xarray.backends.zarr:zarr_backend + h5netcdf = xarray.backends.h5netcdf_:h5netcdf_backend + cfgrib = xarray.backends.cfgrib_:cfgrib_backend + [options.extras_require] io = netCDF4 diff --git
a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 2a769b1335e..1500ea5061f 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -9,6 +9,7 @@ from .h5netcdf_ import H5NetCDFStore from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore +from .plugins import list_engines from .pseudonetcdf_ import PseudoNetCDFDataStore from .pydap_ import PydapDataStore from .pynio_ import NioDataStore @@ -29,4 +30,5 @@ "H5NetCDFStore", "ZarrStore", "PseudoNetCDFDataStore", + "list_engines", ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 70ca42c43b5..5d8e553e9de 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -439,7 +439,7 @@ def open_dataset( kwargs = locals().copy() from . import apiv2, plugins - if engine in plugins.ENGINES: + if engine in plugins.list_engines(): return apiv2.open_dataset(**kwargs) if autoclose is not None: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 6fcd97f9de0..4efce991490 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -103,13 +103,13 @@ def _dataset_from_backend_dataset( return ds -def _resolve_decoders_kwargs(decode_cf, engine, **decoders): - signature = plugins.ENGINES[engine]["signature"] - if decode_cf is False: - for d in decoders: - if d in signature: - decoders[d] = False - return {k: v for k, v in decoders.items() if v is not None} +def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders): + for d in list(decoders): + if decode_cf is False and d in open_backend_dataset_parameters: + decoders[d] = False + if decoders[d] is None: + decoders.pop(d) + return decoders def open_dataset( @@ -252,9 +252,12 @@ def open_dataset( if engine is None: engine = _autodetect_engine(filename_or_obj) + engines = plugins.list_engines() + backend = _get_backend_cls(engine, engines=engines) + decoders = _resolve_decoders_kwargs( decode_cf, - engine=engine, + 
open_backend_dataset_parameters=backend.open_dataset_parameters, mask_and_scale=mask_and_scale, decode_times=decode_times, decode_timedelta=decode_timedelta, @@ -265,11 +268,7 @@ def open_dataset( backend_kwargs = backend_kwargs.copy() overwrite_encoded_chunks = backend_kwargs.pop("overwrite_encoded_chunks", None) - - open_backend_dataset = _get_backend_cls(engine, engines=plugins.ENGINES)[ - "open_dataset" - ] - backend_ds = open_backend_dataset( + backend_ds = backend.open_dataset( filename_or_obj, drop_variables=drop_variables, **decoders, diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 80b3183f197..cf506d67e96 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -7,6 +7,7 @@ from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .locks import SerializableLock, ensure_lock +from .plugins import BackendEntrypoint # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe # in most circumstances. 
See: @@ -126,3 +127,6 @@ def open_backend_dataset_cfgrib( ds.encoding = encoding return ds + + +cfgrib_backend = BackendEntrypoint(open_dataset=open_backend_dataset_cfgrib) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 9d013e3e46e..e3539a05fb1 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -18,6 +18,7 @@ _get_datatype, _nc4_require_group, ) +from .plugins import BackendEntrypoint class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -325,7 +326,7 @@ def close(self, **kwargs): self._manager.close(**kwargs) -def open_backend_dataset_h5necdf( +def open_backend_dataset_h5netcdf( filename_or_obj, *, mask_and_scale=True, @@ -374,3 +375,6 @@ def open_backend_dataset_h5necdf( ds.encoding = encoding return ds + + +h5netcdf_backend = BackendEntrypoint(open_dataset=open_backend_dataset_h5netcdf) diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index a0b1a825600..9e78a3b71e8 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -1,31 +1,78 @@ import inspect -import typing as T - -from . import cfgrib_, h5netcdf_, zarr - -ENGINES: T.Dict[str, T.Dict[str, T.Any]] = { - "h5netcdf": { - "open_dataset": h5netcdf_.open_backend_dataset_h5necdf, - }, - "zarr": { - "open_dataset": zarr.open_backend_dataset_zarr, - }, - "cfgrib": { - "open_dataset": cfgrib_.open_backend_dataset_cfgrib, - }, -} - - -for engine in ENGINES.values(): - if "signature" not in engine: - parameters = inspect.signature(engine["open_dataset"]).parameters - for name, param in parameters.items(): - if param.kind in ( - inspect.Parameter.VAR_KEYWORD, - inspect.Parameter.VAR_POSITIONAL, - ): - raise TypeError( - f'All the parameters in {engine["open_dataset"]!r} signature should be explicit. 
' - "*args and **kwargs is not supported" - ) - engine["signature"] = set(parameters) +import itertools +import warnings +from functools import lru_cache + +import pkg_resources + + +class BackendEntrypoint: + __slots__ = ("open_dataset", "open_dataset_parameters") + + def __init__(self, open_dataset, open_dataset_parameters=None): + self.open_dataset = open_dataset + self.open_dataset_parameters = open_dataset_parameters + + +def remove_duplicates(backend_entrypoints): + + # sort and group entrypoints by name + backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name) + backend_entrypoints_grouped = itertools.groupby( + backend_entrypoints, key=lambda ep: ep.name + ) + # check if there are multiple entrypoints for the same name + unique_backend_entrypoints = [] + for name, matches in backend_entrypoints_grouped: + matches = list(matches) + unique_backend_entrypoints.append(matches[0]) + matches_len = len(matches) + if matches_len > 1: + selected_module_name = matches[0].module_name + all_module_names = [e.module_name for e in matches] + warnings.warn( + f"\nFound {matches_len} entrypoints for the engine name {name}:" + f"\n {all_module_names}.\n It will be used: {selected_module_name}.", + RuntimeWarning, + ) + return unique_backend_entrypoints + + +def detect_parameters(open_dataset): + signature = inspect.signature(open_dataset) + parameters = signature.parameters + for name, param in parameters.items(): + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + raise TypeError( + f"All the parameters in {open_dataset!r} signature should be explicit. 
" + "*args and **kwargs is not supported" + ) + return tuple(parameters) + + +def create_engines_dict(backend_entrypoints): + engines = {} + for backend_ep in backend_entrypoints: + name = backend_ep.name + backend = backend_ep.load() + engines[name] = backend + return engines + + +def set_missing_parameters(engines): + for name, backend in engines.items(): + if backend.open_dataset_parameters is None: + open_dataset = backend.open_dataset + backend.open_dataset_parameters = detect_parameters(open_dataset) + + +@lru_cache(maxsize=1) +def list_engines(): + entrypoints = pkg_resources.iter_entry_points("xarray.backends") + backend_entrypoints = remove_duplicates(entrypoints) + engines = create_engines_dict(backend_entrypoints) + set_missing_parameters(engines) + return engines diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index abf0c7372eb..1c0ebb1dbc9 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -7,6 +7,7 @@ from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error from ..core.variable import Variable from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name +from .plugins import BackendEntrypoint # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -700,3 +701,6 @@ def open_backend_dataset_zarr( ds.encoding = encoding return ds + + +zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr) diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py new file mode 100644 index 00000000000..7e9bb58f140 --- /dev/null +++ b/xarray/tests/test_plugins.py @@ -0,0 +1,93 @@ +from unittest import mock + +import pkg_resources +import pytest + +from xarray.backends import plugins + + +def dummy_open_dataset_args(filename_or_obj, *args): + pass + + +def dummy_open_dataset_kwargs(filename_or_obj, **kwargs): + pass + + +def dummy_open_dataset(filename_or_obj, *, decoder): + pass + + +@pytest.fixture +def 
dummy_duplicated_entrypoints(): + specs = [ + "engine1 = xarray.tests.test_plugins:backend_1", + "engine1 = xarray.tests.test_plugins:backend_2", + "engine2 = xarray.tests.test_plugins:backend_1", + "engine2 = xarray.tests.test_plugins:backend_2", + ] + eps = [pkg_resources.EntryPoint.parse(spec) for spec in specs] + return eps + + +def test_remove_duplicates(dummy_duplicated_entrypoints): + entrypoints = plugins.remove_duplicates(dummy_duplicated_entrypoints) + assert len(entrypoints) == 2 + + +def test_remove_duplicates_warnings(dummy_duplicated_entrypoints): + + with pytest.warns(RuntimeWarning) as record: + _ = plugins.remove_duplicates(dummy_duplicated_entrypoints) + + assert len(record) == 2 + message0 = str(record[0].message) + message1 = str(record[1].message) + assert "entrypoints" in message0 + assert "entrypoints" in message1 + + +@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None)) +def test_create_engines_dict(): + specs = [ + "engine1 = xarray.tests.test_plugins:backend_1", + "engine2 = xarray.tests.test_plugins:backend_2", + ] + entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs] + engines = plugins.create_engines_dict(entrypoints) + assert len(engines) == 2 + assert engines.keys() == set(("engine1", "engine2")) + + +def test_set_missing_parameters(): + backend_1 = plugins.BackendEntrypoint(dummy_open_dataset) + backend_2 = plugins.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",)) + engines = {"engine_1": backend_1, "engine_2": backend_2} + plugins.set_missing_parameters(engines) + + assert len(engines) == 2 + engine_1 = engines["engine_1"] + assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder") + engine_2 = engines["engine_2"] + assert engine_2.open_dataset_parameters == ("filename_or_obj",) + + +def test_set_missing_parameters_raise_error(): + + backend = plugins.BackendEntrypoint(dummy_open_dataset_args) + with pytest.raises(TypeError): + 
plugins.set_missing_parameters({"engine": backend}) + + backend = plugins.BackendEntrypoint( + dummy_open_dataset_args, ("filename_or_obj", "decoder") + ) + plugins.set_missing_parameters({"engine": backend}) + + backend = plugins.BackendEntrypoint(dummy_open_dataset_kwargs) + with pytest.raises(TypeError): + plugins.set_missing_parameters({"engine": backend}) + + backend = plugins.BackendEntrypoint( + dummy_open_dataset_kwargs, ("filename_or_obj", "decoder") + ) + plugins.set_missing_parameters({"engine": backend})