Skip to content

Commit

Permalink
Backends entrypoints (#4577)
Browse files Browse the repository at this point in the history
* Define _get_backends_cls function inside apiv2.py to read engines from plugins.py

* Read open_backends_dataset_* from entrypoints.

* Add backend entrypoints in setup.cfg

* Pass apiv2.py isort and black formatting tests.

* add dependencies

* add backend entrypoints and check on conflicts

* black

* removed global variable ENGINES
add class for entrypoints

* black isort

* add detect_engines in __all__ init.py

* removed entrypoints in py36-bare-minimum.yml and py36-min-all-deps.yml

* add entrypoints in IGNORE_DEPS

* Plugins test (#20)

- replace entrypoints with pkg_resources
- add tests

* fix typo

Co-authored-by: keewis <keewis@users.noreply.github.com>

* style

Co-authored-by: keewis <keewis@users.noreply.github.com>

* style

* Code style

* Code style

* fix: updated plugins.ENGINES with plugins.list_engines()

* fix

* One more correctness fix of the latest merge from master

Co-authored-by: TheRed86 <m.rossetti@bopen.eu>
Co-authored-by: keewis <keewis@users.noreply.github.com>
Co-authored-by: Alessandro Amici <a.amici@bopen.eu>
  • Loading branch information
4 people authored Dec 10, 2020
1 parent 8915058 commit 74dffff
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 45 deletions.
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ setup_requires =
setuptools >= 38.4
setuptools_scm

[options.entry_points]
xarray.backends =
zarr = xarray.backends.zarr:zarr_backend
h5netcdf = xarray.backends.h5netcdf_:h5netcdf_backend
cfgrib = xarray.backends.cfgrib_:cfgrib_backend

[options.extras_require]
io =
netCDF4
Expand Down
2 changes: 2 additions & 0 deletions xarray/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .h5netcdf_ import H5NetCDFStore
from .memory import InMemoryDataStore
from .netCDF4_ import NetCDF4DataStore
from .plugins import list_engines
from .pseudonetcdf_ import PseudoNetCDFDataStore
from .pydap_ import PydapDataStore
from .pynio_ import NioDataStore
Expand All @@ -29,4 +30,5 @@
"H5NetCDFStore",
"ZarrStore",
"PseudoNetCDFDataStore",
"list_engines",
]
2 changes: 1 addition & 1 deletion xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def open_dataset(
kwargs = locals().copy()
from . import apiv2, plugins

if engine in plugins.ENGINES:
if engine in plugins.list_engines():
return apiv2.open_dataset(**kwargs)

if autoclose is not None:
Expand Down
25 changes: 12 additions & 13 deletions xarray/backends/apiv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def _dataset_from_backend_dataset(
return ds


def _resolve_decoders_kwargs(decode_cf, engine, **decoders):
signature = plugins.ENGINES[engine]["signature"]
if decode_cf is False:
for d in decoders:
if d in signature:
decoders[d] = False
return {k: v for k, v in decoders.items() if v is not None}
def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders):
for d in list(decoders):
if decode_cf is False and d in open_backend_dataset_parameters:
decoders[d] = False
if decoders[d] is None:
decoders.pop(d)
return decoders


def open_dataset(
Expand Down Expand Up @@ -252,9 +252,12 @@ def open_dataset(
if engine is None:
engine = _autodetect_engine(filename_or_obj)

engines = plugins.list_engines()
backend = _get_backend_cls(engine, engines=engines)

decoders = _resolve_decoders_kwargs(
decode_cf,
engine=engine,
open_backend_dataset_parameters=backend.open_dataset_parameters,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
Expand All @@ -265,11 +268,7 @@ def open_dataset(

backend_kwargs = backend_kwargs.copy()
overwrite_encoded_chunks = backend_kwargs.pop("overwrite_encoded_chunks", None)

open_backend_dataset = _get_backend_cls(engine, engines=plugins.ENGINES)[
"open_dataset"
]
backend_ds = open_backend_dataset(
backend_ds = backend.open_dataset(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
Expand Down
4 changes: 4 additions & 0 deletions xarray/backends/cfgrib_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..core.variable import Variable
from .common import AbstractDataStore, BackendArray
from .locks import SerializableLock, ensure_lock
from .plugins import BackendEntrypoint

# FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe
# in most circumstances. See:
Expand Down Expand Up @@ -126,3 +127,6 @@ def open_backend_dataset_cfgrib(
ds.encoding = encoding

return ds


cfgrib_backend = BackendEntrypoint(open_dataset=open_backend_dataset_cfgrib)
6 changes: 5 additions & 1 deletion xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_get_datatype,
_nc4_require_group,
)
from .plugins import BackendEntrypoint


class H5NetCDFArrayWrapper(BaseNetCDF4Array):
Expand Down Expand Up @@ -325,7 +326,7 @@ def close(self, **kwargs):
self._manager.close(**kwargs)


def open_backend_dataset_h5necdf(
def open_backend_dataset_h5netcdf(
filename_or_obj,
*,
mask_and_scale=True,
Expand Down Expand Up @@ -374,3 +375,6 @@ def open_backend_dataset_h5necdf(
ds.encoding = encoding

return ds


h5netcdf_backend = BackendEntrypoint(open_dataset=open_backend_dataset_h5netcdf)
107 changes: 77 additions & 30 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,78 @@
import inspect
import typing as T

from . import cfgrib_, h5netcdf_, zarr

ENGINES: T.Dict[str, T.Dict[str, T.Any]] = {
"h5netcdf": {
"open_dataset": h5netcdf_.open_backend_dataset_h5necdf,
},
"zarr": {
"open_dataset": zarr.open_backend_dataset_zarr,
},
"cfgrib": {
"open_dataset": cfgrib_.open_backend_dataset_cfgrib,
},
}


for engine in ENGINES.values():
if "signature" not in engine:
parameters = inspect.signature(engine["open_dataset"]).parameters
for name, param in parameters.items():
if param.kind in (
inspect.Parameter.VAR_KEYWORD,
inspect.Parameter.VAR_POSITIONAL,
):
raise TypeError(
f'All the parameters in {engine["open_dataset"]!r} signature should be explicit. '
"*args and **kwargs is not supported"
)
engine["signature"] = set(parameters)
import itertools
import warnings
from functools import lru_cache

import pkg_resources


class BackendEntrypoint:
    """Pair a backend's ``open_dataset`` function with its parameter names.

    ``open_dataset_parameters`` left as None means "detect later by
    inspecting the function's signature" (see ``set_missing_parameters``).
    """

    __slots__ = ("open_dataset", "open_dataset_parameters")

    def __init__(self, open_dataset, open_dataset_parameters=None):
        self.open_dataset = open_dataset
        self.open_dataset_parameters = open_dataset_parameters


def remove_duplicates(backend_entrypoints):
    """Return one entrypoint per engine name, warning on duplicates.

    Entrypoints are grouped by name; within a group the first element of the
    sorted (stable) order wins, and a RuntimeWarning lists the competing
    module names.
    """
    ordered = sorted(backend_entrypoints, key=lambda ep: ep.name)
    unique_backend_entrypoints = []
    for name, group in itertools.groupby(ordered, key=lambda ep: ep.name):
        group = list(group)
        winner = group[0]
        unique_backend_entrypoints.append(winner)
        if len(group) > 1:
            all_module_names = [ep.module_name for ep in group]
            warnings.warn(
                f"\nFound {len(group)} entrypoints for the engine name {name}:"
                f"\n {all_module_names}.\n It will be used: {winner.module_name}.",
                RuntimeWarning,
            )
    return unique_backend_entrypoints


def detect_parameters(open_dataset):
    """Return the explicit parameter names of *open_dataset* as a tuple.

    Raises TypeError if the signature uses ``*args`` or ``**kwargs``: the
    plugin machinery needs every accepted decoder to be named explicitly.
    """
    parameters = inspect.signature(open_dataset).parameters
    variadic = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
    for param in parameters.values():
        if param.kind in variadic:
            raise TypeError(
                f"All the parameters in {open_dataset!r} signature should be explicit. "
                "*args and **kwargs is not supported"
            )
    return tuple(parameters)


def create_engines_dict(backend_entrypoints):
    """Load each entrypoint and map engine name -> backend object."""
    return {ep.name: ep.load() for ep in backend_entrypoints}


def set_missing_parameters(engines):
    """Fill in ``open_dataset_parameters`` for backends that omitted them.

    A backend whose parameter tuple is None gets it detected by inspecting
    the signature of its ``open_dataset`` function; backends that declared
    parameters explicitly are left untouched.
    """
    for backend in engines.values():
        if backend.open_dataset_parameters is None:
            backend.open_dataset_parameters = detect_parameters(backend.open_dataset)


@lru_cache(maxsize=1)
def list_engines():
    """Return the mapping of installed backend engines (name -> backend).

    Scans the ``xarray.backends`` entrypoint group once, drops duplicate
    registrations, loads each backend and completes any missing parameter
    lists.  The result is cached, so repeated calls are cheap.
    """
    found = pkg_resources.iter_entry_points("xarray.backends")
    unique = remove_duplicates(found)
    engines = create_engines_dict(unique)
    set_missing_parameters(engines)
    return engines
4 changes: 4 additions & 0 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error
from ..core.variable import Variable
from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name
from .plugins import BackendEntrypoint

# need some special secret attributes to tell us the dimensions
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
Expand Down Expand Up @@ -700,3 +701,6 @@ def open_backend_dataset_zarr(
ds.encoding = encoding

return ds


zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr)
93 changes: 93 additions & 0 deletions xarray/tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from unittest import mock

import pkg_resources
import pytest

from xarray.backends import plugins


def dummy_open_dataset_args(filename_or_obj, *args):
    # Stub with *args: detect_parameters must reject VAR_POSITIONAL.
    pass


def dummy_open_dataset_kwargs(filename_or_obj, **kwargs):
    # Stub with **kwargs: detect_parameters must reject VAR_KEYWORD.
    pass


def dummy_open_dataset(filename_or_obj, *, decoder):
    # Well-formed stub: all parameters explicit, so detection succeeds.
    pass


@pytest.fixture
def dummy_duplicated_entrypoints():
    # Two engine names, each registered twice, to exercise duplicate handling.
    specs = [
        "engine1 = xarray.tests.test_plugins:backend_1",
        "engine1 = xarray.tests.test_plugins:backend_2",
        "engine2 = xarray.tests.test_plugins:backend_1",
        "engine2 = xarray.tests.test_plugins:backend_2",
    ]
    return [pkg_resources.EntryPoint.parse(spec) for spec in specs]


def test_remove_duplicates(dummy_duplicated_entrypoints):
    # Four entrypoints over two names collapse down to two.
    deduplicated = plugins.remove_duplicates(dummy_duplicated_entrypoints)
    assert len(deduplicated) == 2


def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):
    # One RuntimeWarning is emitted per duplicated engine name.
    with pytest.warns(RuntimeWarning) as record:
        _ = plugins.remove_duplicates(dummy_duplicated_entrypoints)

    assert len(record) == 2
    for warning in record:
        assert "entrypoints" in str(warning.message)


@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
def test_create_engines_dict():
    # EntryPoint.load is mocked out, so only the name -> backend mapping
    # shape is checked here.
    specs = [
        "engine1 = xarray.tests.test_plugins:backend_1",
        "engine2 = xarray.tests.test_plugins:backend_2",
    ]
    entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs]
    engines = plugins.create_engines_dict(entrypoints)
    assert len(engines) == 2
    assert engines.keys() == {"engine1", "engine2"}


def test_set_missing_parameters():
    # First backend declares no parameters (detected by inspection);
    # second declares them explicitly and must be left untouched.
    detected = plugins.BackendEntrypoint(dummy_open_dataset)
    declared = plugins.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",))
    engines = {"engine_1": detected, "engine_2": declared}
    plugins.set_missing_parameters(engines)

    assert len(engines) == 2
    assert engines["engine_1"].open_dataset_parameters == ("filename_or_obj", "decoder")
    assert engines["engine_2"].open_dataset_parameters == ("filename_or_obj",)


def test_set_missing_parameters_raise_error():
    # Signature detection must reject *args and **kwargs, but explicitly
    # declared parameters skip detection entirely, so no error is raised.
    for dummy in (dummy_open_dataset_args, dummy_open_dataset_kwargs):
        backend = plugins.BackendEntrypoint(dummy)
        with pytest.raises(TypeError):
            plugins.set_missing_parameters({"engine": backend})

        backend = plugins.BackendEntrypoint(dummy, ("filename_or_obj", "decoder"))
        plugins.set_missing_parameters({"engine": backend})

0 comments on commit 74dffff

Please sign in to comment.