
Backends entrypoints #4577

Merged (22 commits, Dec 10, 2020)
Changes from 16 commits
6 changes: 6 additions & 0 deletions setup.cfg
@@ -81,6 +81,12 @@ setup_requires =
    setuptools >= 38.4
    setuptools_scm

[options.entry_points]
xarray.backends =
    zarr = xarray.backends.zarr:zarr_backend
    h5netcdf = xarray.backends.h5netcdf_:h5netcdf_backend
    cfgrib = xarray.backends.cfgrib_:cfgrib_backend

[options.extras_require]
io =
    netCDF4
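
The entry points registered above are what make backends discoverable at runtime (see list_engines below). As a rough sketch of how a third-party package could plug into the same mechanism (not part of this PR; the distribution, module, and attribute names are invented for illustration), its setup.py would register an engine in the same xarray.backends group:

# Hypothetical setup.py for an external backend distribution "xarray-foo".
# "xarray_foo.backend" and "foo_backend" are illustrative names only.
from setuptools import setup

setup(
    name="xarray-foo",
    packages=["xarray_foo"],
    entry_points={
        "xarray.backends": [
            # engine name = module:attribute, mirroring the setup.cfg lines above
            "foo = xarray_foo.backend:foo_backend",
        ],
    },
)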
2 changes: 2 additions & 0 deletions xarray/backends/__init__.py
@@ -9,6 +9,7 @@
from .h5netcdf_ import H5NetCDFStore
from .memory import InMemoryDataStore
from .netCDF4_ import NetCDF4DataStore
from .plugins import list_engines
from .pseudonetcdf_ import PseudoNetCDFDataStore
from .pydap_ import PydapDataStore
from .pynio_ import NioDataStore
@@ -29,4 +30,5 @@
    "H5NetCDFStore",
    "ZarrStore",
    "PseudoNetCDFDataStore",
    "list_engines",
]
2 changes: 1 addition & 1 deletion xarray/backends/api.py
@@ -437,7 +437,7 @@ def open_dataset(
    kwargs = locals().copy()
    from . import apiv2, plugins

-    if engine in plugins.ENGINES:
+    if engine in plugins.list_engines():
        return apiv2.open_dataset(**kwargs)

    if autoclose is not None:
24 changes: 12 additions & 12 deletions xarray/backends/apiv2.py
@@ -72,13 +72,13 @@ def dataset_from_backend_dataset(
    return ds2


-def resolve_decoders_kwargs(decode_cf, engine, **decoders):
-    signature = plugins.ENGINES[engine]["signature"]
-    if decode_cf is False:
-        for d in decoders:
-            if d in signature:
-                decoders[d] = False
-    return {k: v for k, v in decoders.items() if v is not None}
+def resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders):
+    for d in list(decoders):
+        if decode_cf is False and d in open_backend_dataset_parameters:
+            decoders[d] = False
+        if decoders[d] is None:
+            decoders.pop(d)
+    return decoders
Comment on lines +107 to +112
@keewis (Collaborator) commented on Dec 2, 2020:

for this sort of filtering and transformation I would usually rely on comprehensions:

Suggested change:
-    for d in list(decoders):
-        if decode_cf is False and d in open_backend_dataset_parameters:
-            decoders[d] = False
-        if decoders[d] is None:
-            decoders.pop(d)
-    return decoders
+    def choose_decoder(name):
+        if decode_cf is False and name in open_backend_dataset_parameters:
+            return False
+        return decoder
+    return {
+        choose_decoder(name)
+        for name, decoder in decoders.items()
+        if decoder is not None
+    }

but there might be a better (more readable) way to express the conditional transformation of decoder. Maybe use a local function with a descriptive name? Edit: I added the local function, but I'm still searching for a better name
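
For reference, an illustrative sketch (not part of the diff) of what the rewritten resolve_decoders_kwargs returns, assuming a backend whose open_dataset accepts mask_and_scale and decode_times:

# Hypothetical parameter tuple, as detect_parameters would produce it.
params = ("filename_or_obj", "mask_and_scale", "decode_times")

# None means "use the backend default", so those entries are dropped:
resolve_decoders_kwargs(True, params, mask_and_scale=None, decode_times=True)
# -> {"decode_times": True}

# decode_cf=False forces off every decoder the backend actually supports:
resolve_decoders_kwargs(False, params, mask_and_scale=None, decode_times=True)
# -> {"mask_and_scale": False, "decode_times": False}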



def open_dataset(
@@ -219,9 +219,12 @@ def open_dataset(
    if engine is None:
        engine = _autodetect_engine(filename_or_obj)

+    engines = plugins.list_engines()
+    backend = _get_backend_cls(engine, engines=engines)

    decoders = resolve_decoders_kwargs(
        decode_cf,
-        engine=engine,
+        open_backend_dataset_parameters=backend.open_dataset_parameters,
        mask_and_scale=mask_and_scale,
        decode_times=decode_times,
        decode_timedelta=decode_timedelta,
@@ -233,10 +236,7 @@
    backend_kwargs = backend_kwargs.copy()
    overwrite_encoded_chunks = backend_kwargs.pop("overwrite_encoded_chunks", None)

-    open_backend_dataset = _get_backend_cls(engine, engines=plugins.ENGINES)[
-        "open_dataset"
-    ]
-    backend_ds = open_backend_dataset(
+    backend_ds = backend.open_dataset(
        filename_or_obj,
        drop_variables=drop_variables,
        **decoders,
4 changes: 4 additions & 0 deletions xarray/backends/cfgrib_.py
@@ -7,6 +7,7 @@
from ..core.variable import Variable
from .common import AbstractDataStore, BackendArray
from .locks import SerializableLock, ensure_lock
from .plugins import BackendEntrypoint

# FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe
# in most circumstances. See:
@@ -126,3 +127,6 @@ def open_backend_dataset_cfgrib(
    ds.encoding = encoding

    return ds


cfgrib_backend = BackendEntrypoint(open_dataset=open_backend_dataset_cfgrib)
6 changes: 5 additions & 1 deletion xarray/backends/h5netcdf_.py
@@ -18,6 +18,7 @@
    _get_datatype,
    _nc4_require_group,
)
from .plugins import BackendEntrypoint


class H5NetCDFArrayWrapper(BaseNetCDF4Array):
@@ -325,7 +326,7 @@ def close(self, **kwargs):
        self._manager.close(**kwargs)


-def open_backend_dataset_h5necdf(
+def open_backend_dataset_h5netcdf(
    filename_or_obj,
    *,
    mask_and_scale=True,
@@ -374,3 +375,6 @@ def open_backend_dataset_h5necdf(
    ds.encoding = encoding

    return ds


h5netcdf_backend = BackendEntrypoint(open_dataset=open_backend_dataset_h5netcdf)
107 changes: 77 additions & 30 deletions xarray/backends/plugins.py
@@ -1,31 +1,78 @@
import inspect
-import typing as T

-from . import cfgrib_, h5netcdf_, zarr

-ENGINES: T.Dict[str, T.Dict[str, T.Any]] = {
-    "h5netcdf": {
-        "open_dataset": h5netcdf_.open_backend_dataset_h5necdf,
-    },
-    "zarr": {
-        "open_dataset": zarr.open_backend_dataset_zarr,
-    },
-    "cfgrib": {
-        "open_dataset": cfgrib_.open_backend_dataset_cfgrib,
-    },
-}


-for engine in ENGINES.values():
-    if "signature" not in engine:
-        parameters = inspect.signature(engine["open_dataset"]).parameters
-        for name, param in parameters.items():
-            if param.kind in (
-                inspect.Parameter.VAR_KEYWORD,
-                inspect.Parameter.VAR_POSITIONAL,
-            ):
-                raise TypeError(
-                    f'All the parameters in {engine["open_dataset"]!r} signature should be explicit. '
-                    "*args and **kwargs is not supported"
-                )
-        engine["signature"] = set(parameters)
+import itertools
+import warnings
+from functools import lru_cache

+import pkg_resources


+class BackendEntrypoint:
+    __slots__ = ("open_dataset", "open_dataset_parameters")

+    def __init__(self, open_dataset, open_dataset_parameters=None):
+        self.open_dataset = open_dataset
+        self.open_dataset_parameters = open_dataset_parameters


+def remove_duplicates(backend_entrypoints):

+    # sort and group entrypoints by name
+    backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name)
+    backend_entrypoints_grouped = itertools.groupby(
+        backend_entrypoints, key=lambda ep: ep.name
+    )
+    # check if there are multiple entrypoints for the same name
+    unique_backend_entrypoints = []
+    for name, matches in backend_entrypoints_grouped:
+        matches = list(matches)
+        unique_backend_entrypoints.append(matches[0])
+        matches_len = len(matches)
+        if matches_len > 1:
+            selected_module_name = matches[0].module_name
+            all_module_names = [e.module_name for e in matches]
+            warnings.warn(
+                f"\nFound {matches_len} entrypoints for the engine name {name}:"
+                f"\n {all_module_names}.\n It will be used: {selected_module_name}.",
+                RuntimeWarning,
+            )
+    return unique_backend_entrypoints


+def detect_parameters(open_dataset):
+    signature = inspect.signature(open_dataset)
+    parameters = signature.parameters
+    for name, param in parameters.items():
+        if param.kind in (
+            inspect.Parameter.VAR_KEYWORD,
+            inspect.Parameter.VAR_POSITIONAL,
+        ):
+            raise TypeError(
+                f"All the parameters in {open_dataset!r} signature should be explicit. "
+                "*args and **kwargs is not supported"
+            )
+    return tuple(parameters)


+def create_engines_dict(backend_entrypoints):
+    engines = {}
+    for backend_ep in backend_entrypoints:
+        name = backend_ep.name
+        backend = backend_ep.load()
+        engines[name] = backend
+    return engines


+def set_missing_parameters(engines):
+    for name, backend in engines.items():
+        if backend.open_dataset_parameters is None:
+            open_dataset = backend.open_dataset
+            backend.open_dataset_parameters = detect_parameters(open_dataset)


+@lru_cache(maxsize=1)
+def list_engines():
+    entrypoints = pkg_resources.iter_entry_points("xarray.backends")
+    backend_entrypoints = remove_duplicates(entrypoints)
+    engines = create_engines_dict(backend_entrypoints)
+    set_missing_parameters(engines)
+    return engines
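
For reference, a minimal usage sketch of the new registry (not part of the diff; it assumes an environment where this branch and the corresponding backend packages are installed):

from xarray.backends import plugins

engines = plugins.list_engines()  # cached by lru_cache, so entry points are scanned only once
print(sorted(engines))            # engine names found in the "xarray.backends" group

backend = engines["zarr"]                # a BackendEntrypoint instance
print(backend.open_dataset_parameters)   # filled in by set_missing_parameters if the plugin omitted it
# apiv2.open_dataset ultimately calls backend.open_dataset(filename_or_obj, ...)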
4 changes: 4 additions & 0 deletions xarray/backends/zarr.py
@@ -9,6 +9,7 @@
from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error
from ..core.variable import Variable
from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name
from .plugins import BackendEntrypoint

# need some special secret attributes to tell us the dimensions
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
@@ -741,3 +742,6 @@ def open_backend_dataset_zarr(
    ds.encoding = encoding

    return ds


zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr)
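
The three registrations in this PR (cfgrib, h5netcdf, zarr) all follow the same pattern; a hypothetical external backend module, such as the one referenced by the setup.py sketch earlier, would do the same thing (names invented for illustration, the raise is a placeholder):

# xarray_foo/backend.py -- hypothetical third-party backend module.
from xarray.backends.plugins import BackendEntrypoint


def open_backend_dataset_foo(filename_or_obj, *, drop_variables=None, decode_times=True):
    # Only explicit named parameters are allowed here: detect_parameters()
    # raises TypeError on *args/**kwargs, and resolve_decoders_kwargs relies on
    # the parameter names to know which decoders the backend understands.
    raise NotImplementedError("sketch only; a real backend returns an xarray.Dataset")


# Object referenced by the "foo = xarray_foo.backend:foo_backend" entry point.
foo_backend = BackendEntrypoint(open_dataset=open_backend_dataset_foo)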
96 changes: 96 additions & 0 deletions xarray/tests/test_plugins.py
@@ -0,0 +1,96 @@
from unittest import mock

import pkg_resources
import pytest

from xarray.backends import plugins


def dummy_open_dataset_args(filename_or_obj, *args):
    pass


def dummy_open_dataset_kwargs(filename_or_obj, **kwargs):
    pass


def dummy_open_dataset(filename_or_obj, *, decoder):
    pass


@pytest.fixture
def dummy_duplicated_entrypoints():
    specs = [
        "engine1 = xarray.tests.test_plugins:backend_1",
        "engine1 = xarray.tests.test_plugins:backend_2",
        "engine2 = xarray.tests.test_plugins:backend_1",
        "engine2 = xarray.tests.test_plugins:backend_2",
    ]
    eps = [
        pkg_resources.EntryPoint.parse(spec)
        for spec in specs
    ]
    return eps


def test_remove_duplicates(dummy_duplicated_entrypoints):
    entrypoints = plugins.remove_duplicates(dummy_duplicated_entrypoints)
    assert len(entrypoints) == 2


def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):

    with pytest.warns(RuntimeWarning) as record:
        _ = plugins.remove_duplicates(dummy_duplicated_entrypoints)

    assert len(record) == 2
    message0 = str(record[0].message)
    message1 = str(record[1].message)
    assert "entrypoints" in message0
    assert "entrypoints" in message1


@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
def test_create_engines_dict():
    specs = [
        "engine1 = xarray.tests.test_plugins:backend_1",
        "engine2 = xarray.tests.test_plugins:backend_2",
    ]
    entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs]
    engines = plugins.create_engines_dict(entrypoints)
    assert len(engines) == 2
    assert engines.keys() == set(("engine1", "engine2"))


def test_set_missing_parameters():
    backend_1 = plugins.BackendEntrypoint(dummy_open_dataset)
    backend_2 = plugins.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",))
    engines = {"engine_1": backend_1, "engine_2": backend_2}
    plugins.set_missing_parameters(engines)

    assert len(engines) == 2
    engine_1 = engines["engine_1"]
    assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder")
    engine_2 = engines["engine_2"]
    assert engine_2.open_dataset_parameters == ("filename_or_obj",)


def test_set_missing_parameters_raise_error():

    backend = plugins.BackendEntrypoint(dummy_open_dataset_args)
    with pytest.raises(TypeError):
        plugins.set_missing_parameters({"engine": backend})

    backend = plugins.BackendEntrypoint(
        dummy_open_dataset_args, ("filename_or_obj", "decoder")
    )
    plugins.set_missing_parameters({"engine": backend})

    backend = plugins.BackendEntrypoint(dummy_open_dataset_kwargs)
    with pytest.raises(TypeError):
        plugins.set_missing_parameters({"engine": backend})

    backend = plugins.BackendEntrypoint(
        dummy_open_dataset_kwargs, ("filename_or_obj", "decoder")
    )
    plugins.set_missing_parameters({"engine": backend})