diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aa499731f4a..cf7d29c1d11 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -156,6 +156,21 @@ Internal Changes rewrite of the indexer key (:issue: `8377`, :pull:`8758`) By `Anderson Banihirwe `_. +- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`) By + `Matt Savoie `_ and `Tom Nicholas + `_. + +- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) By + `Matt Savoie `_ and `Tom Nicholas + `_. + +- Migrates ``datatree`` functionality into ``xarray/core``. (:pull: `8757`) + By `Owen Littlejohns `_. + +- Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary rewrite of the indexer key + (:issue: `8377`, :pull:`8758`) By `Anderson Banihirwe ` + .. _whats-new.2024.01.1: v2024.01.1 (23 Jan, 2024) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d3026a535e2..637eea4d076 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -69,7 +69,7 @@ T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 7d3cc00a52d..f318b4dd42f 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -23,8 +23,8 @@ from netCDF4 import Dataset as ncDataset from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree from xarray.core.types import NestedSequence - from xarray.datatree_.datatree import DataTree # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -137,8 +137,8 @@ def _open_datatree_netcdf( **kwargs, ) -> DataTree: from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree ds = open_dataset(filename_or_obj, **kwargs) tree_root = DataTree.from_dict({"/": ds}) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 6720a67ae2f..ae86c4ce384 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -45,7 +45,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e9465dc0ba0..13b1819f206 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -34,7 +34,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # need some special secret attributes to tell us the dimensions @@ -1048,8 +1048,8 @@ def open_datatree( import zarr from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree zds = zarr.open_group(filename_or_obj, mode="r") ds = open_dataset(filename_or_obj, engine="zarr", **kwargs) diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/core/datatree.py similarity index 93% rename from xarray/datatree_/datatree/datatree.py rename to xarray/core/datatree.py index 10133052185..41341006380 100644 --- a/xarray/datatree_/datatree/datatree.py +++ b/xarray/core/datatree.py @@ -2,25 +2,14 @@ import copy import itertools -from collections import OrderedDict +from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generic, - Hashable, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, NoReturn, - Optional, - Set, - Tuple, - Union, overload, ) @@ -31,6 +20,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.treenode import NamedNode, NodePath, Tree from xarray.core.utils import ( Default, Frozen, @@ -40,17 +30,22 @@ maybe_wrap_array, ) from xarray.core.variable import Variable - -from . import formatting, formatting_html -from .common import TreeAttrAccessMixin -from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree -from .ops import ( +from xarray.datatree_.datatree.common import TreeAttrAccessMixin +from xarray.datatree_.datatree.formatting import datatree_repr +from xarray.datatree_.datatree.formatting_html import ( + datatree_repr as datatree_repr_html, +) +from xarray.datatree_.datatree.mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) +from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, ) -from .render import RenderTree -from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.datatree_.datatree.render import RenderTree try: from xarray.core.variable import calculate_dimensions @@ -60,6 +55,7 @@ if TYPE_CHECKING: import pandas as pd + from xarray.core.merge import CoercibleValue from xarray.core.types import ErrorOptions @@ -77,7 +73,7 @@ # """ -T_Path = Union[str, NodePath] +T_Path = str | NodePath def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: @@ -130,9 +126,9 @@ class DatasetView(Dataset): def __init__( self, - data_vars: Optional[Mapping[Any, Any]] = None, - coords: Optional[Mapping[Any, Any]] = None, - attrs: Optional[Mapping[Any, Any]] = None, + data_vars: Mapping[Any, Any] | None = None, + coords: Mapping[Any, Any] | None = None, + attrs: Mapping[Any, Any] | None = None, ): raise AttributeError("DatasetView objects are not to be initialized directly") @@ -178,8 +174,7 @@ def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] ... @overload - def __getitem__(self, key: Any) -> Dataset: - ... + def __getitem__(self, key: Any) -> Dataset: ... def __getitem__(self, key) -> DataArray: # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes @@ -191,11 +186,11 @@ def _construct_direct( cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, - close: Optional[Callable[[], None]] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + close: Callable[[], None] | None = None, ) -> Dataset: """ Overriding this method (along with ._replace) and modifying it to return a Dataset object @@ -217,11 +212,11 @@ def _construct_direct( def _replace( self, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -259,7 +254,7 @@ def map( Function which can be called in the form `func(x, *args, **kwargs)` to transform each DataArray `x` in this dataset into another DataArray. - keep_attrs : bool or None, optional + keep_attrs : bool | None, optional If True, both the dataset's and variables' attributes (`attrs`) will be copied from the original objects to the new ones. If False, the new dataset and variables will be returned without copying the attributes. @@ -337,17 +332,17 @@ class DataTree( # TODO all groupby classes - _name: Optional[str] - _parent: Optional[DataTree] - _children: OrderedDict[str, DataTree] - _attrs: Optional[Dict[Hashable, Any]] - _cache: Dict[str, Any] - _coord_names: Set[Hashable] - _dims: Dict[Hashable, int] - _encoding: Optional[Dict[Hashable, Any]] - _close: Optional[Callable[[], None]] - _indexes: Dict[Hashable, Index] - _variables: Dict[Hashable, Variable] + _name: str | None + _parent: DataTree | None + _children: dict[str, DataTree] + _attrs: dict[Hashable, Any] | None + _cache: dict[str, Any] + _coord_names: set[Hashable] + _dims: dict[Hashable, int] + _encoding: dict[Hashable, Any] | None + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + _variables: dict[Hashable, Variable] __slots__ = ( "_name", @@ -365,10 +360,10 @@ class DataTree( def __init__( self, - data: Optional[Dataset | DataArray] = None, - parent: Optional[DataTree] = None, - children: Optional[Mapping[str, DataTree]] = None, - name: Optional[str] = None, + data: Dataset | DataArray | None = None, + parent: DataTree | None = None, + children: Mapping[str, DataTree] | None = None, + name: str | None = None, ): """ Create a single node of a DataTree. @@ -446,7 +441,7 @@ def ds(self) -> DatasetView: return DatasetView._from_node(self) @ds.setter - def ds(self, data: Optional[Union[Dataset, DataArray]] = None) -> None: + def ds(self, data: Dataset | DataArray | None = None) -> None: ds = _coerce_to_dataset(data) _check_for_name_collisions(self.children, ds.variables) @@ -515,15 +510,14 @@ def is_hollow(self) -> bool: def variables(self) -> Mapping[Hashable, Variable]: """Low level interface to node contents as dict of Variable objects. - This ordered dictionary is frozen to prevent mutation that could - violate Dataset invariants. It contains all variable objects - constituting this DataTree node, including both data variables and - coordinates. + This dictionary is frozen to prevent mutation that could violate + Dataset invariants. It contains all variable objects constituting this + DataTree node, including both data variables and coordinates. """ return Frozen(self._variables) @property - def attrs(self) -> Dict[Hashable, Any]: + def attrs(self) -> dict[Hashable, Any]: """Dictionary of global attributes on this node object.""" if self._attrs is None: self._attrs = {} @@ -534,7 +528,7 @@ def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property - def encoding(self) -> Dict: + def encoding(self) -> dict: """Dictionary of global encoding attributes on this node object.""" if self._encoding is None: self._encoding = {} @@ -589,7 +583,7 @@ def _item_sources(self) -> Iterable[Mapping[Any, Any]]: # immediate child nodes yield self.children - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: """Provide method for the key-autocompletions in IPython. See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. @@ -637,30 +631,30 @@ def __array__(self, dtype=None): ) def __repr__(self) -> str: - return formatting.datatree_repr(self) + return datatree_repr(self) def __str__(self) -> str: - return formatting.datatree_repr(self) + return datatree_repr(self) def _repr_html_(self): """Make html representation of datatree object""" if XR_OPTS["display_style"] == "text": return f"
{escape(repr(self))}
" - return formatting_html.datatree_repr(self) + return datatree_repr_html(self) @classmethod def _construct_direct( cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, name: str | None = None, parent: DataTree | None = None, - children: Optional[OrderedDict[str, DataTree]] = None, - close: Optional[Callable[[], None]] = None, + children: dict[str, DataTree] | None = None, + close: Callable[[], None] | None = None, ) -> DataTree: """Shortcut around __init__ for internal use when we want to skip costly validation.""" @@ -670,7 +664,7 @@ def _construct_direct( if indexes is None: indexes = {} if children is None: - children = OrderedDict() + children = dict() obj: DataTree = object.__new__(cls) obj._variables = variables @@ -690,15 +684,15 @@ def _construct_direct( def _replace( self: DataTree, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, name: str | None | Default = _default, parent: DataTree | None = _default, - children: Optional[OrderedDict[str, DataTree]] = None, + children: dict[str, DataTree] | None = None, inplace: bool = False, ) -> DataTree: """ @@ -827,8 +821,8 @@ def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree return self._copy_subtree(deep=True, memo=memo) def get( - self: DataTree, key: str, default: Optional[DataTree | DataArray] = None - ) -> Optional[DataTree | DataArray]: + self: DataTree, key: str, default: DataTree | DataArray | None = None + ) -> DataTree | DataArray | None: """ Access child nodes, variables, or coordinates stored in this node. @@ -839,7 +833,7 @@ def get( ---------- key : str Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). - default : DataTree | DataArray, optional + default : DataTree | DataArray | None, optional A value to return if the specified key does not exist. Default return value is None. """ if key in self.children: @@ -863,7 +857,7 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: Returns ------- - Union[DataTree, DataArray] + DataTree | DataArray """ # Either: @@ -949,7 +943,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) # TODO are there any subtleties with preserving order of children like this? - merged_children = OrderedDict({**self.children, **new_children}) + merged_children = dict({**self.children, **new_children}) self._replace( inplace=True, children=merged_children, **vars_merge_result._asdict() ) @@ -1027,7 +1021,7 @@ def drop_nodes( if extra: raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") - children_to_keep = OrderedDict( + children_to_keep = dict( {name: child for name, child in self.children.items() if name not in names} ) return self._replace(children=children_to_keep) @@ -1036,7 +1030,7 @@ def drop_nodes( def from_dict( cls, d: MutableMapping[str, Dataset | DataArray | DataTree | None], - name: Optional[str] = None, + name: str | None = None, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1050,7 +1044,7 @@ def from_dict( tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. - name : Hashable, optional + name : Hashable | None, optional Name for the root node of the tree. Default is None. Returns @@ -1085,13 +1079,13 @@ def from_dict( return obj - def to_dict(self) -> Dict[str, Dataset]: + def to_dict(self) -> dict[str, Dataset]: """ Create a dictionary mapping of absolute node paths to the data contained in those nodes. Returns ------- - Dict[str, Dataset] + dict[str, Dataset] """ return {node.path: node.to_dataset() for node in self.subtree} @@ -1313,7 +1307,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | Tuple[DataTree]: + ) -> DataTree | tuple[DataTree]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. @@ -1336,7 +1330,7 @@ def map_over_subtree( Returns ------- - subtrees : DataTree, Tuple of DataTrees + subtrees : DataTree, tuple of DataTrees One or more subtrees containing results from applying ``func`` to the data at each node. """ # TODO this signature means that func has no way to know which node it is being called upon - change? @@ -1485,7 +1479,7 @@ def to_netcdf( kwargs : Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` """ - from .io import _datatree_to_netcdf + from xarray.datatree_.datatree.io import _datatree_to_netcdf _datatree_to_netcdf( self, @@ -1527,7 +1521,7 @@ def to_zarr( kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` """ - from .io import _datatree_to_zarr + from xarray.datatree_.datatree.io import _datatree_to_zarr _datatree_to_zarr( self, diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index 071dcbecf8c..f2603b64641 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,15 +1,11 @@ # import public API -from .datatree import DataTree -from .extensions import register_datatree_accessor from .mapping import TreeIsomorphismError, map_over_subtree from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( - "DataTree", "TreeIsomorphismError", "InvalidTreeError", "NotFoundInTreeError", "map_over_subtree", - "register_datatree_accessor", ) diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py index f6f4e985a79..bf888fc4484 100644 --- a/xarray/datatree_/datatree/extensions.py +++ b/xarray/datatree_/datatree/extensions.py @@ -1,6 +1,6 @@ from xarray.core.extensions import _register_accessor -from .datatree import DataTree +from xarray.core.datatree import DataTree def register_datatree_accessor(name): diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py index deba57eb09d..9ebee72d4ef 100644 --- a/xarray/datatree_/datatree/formatting.py +++ b/xarray/datatree_/datatree/formatting.py @@ -2,11 +2,11 @@ from xarray.core.formatting import _compat_to_str, diff_dataset_repr -from .mapping import diff_treestructure -from .render import RenderTree +from xarray.datatree_.datatree.mapping import diff_treestructure +from xarray.datatree_.datatree.render import RenderTree if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree def diff_nodewise_summary(a, b, compat): diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py index d3d533ee71e..48335ddca70 100644 --- a/xarray/datatree_/datatree/io.py +++ b/xarray/datatree_/datatree/io.py @@ -1,4 +1,4 @@ -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree def _get_nc_dataset_class(engine): diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/datatree_/datatree/mapping.py index 355149060a9..7742ece9738 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/datatree_/datatree/mapping.py @@ -156,7 +156,7 @@ def map_over_subtree(func: Callable) -> Callable: @functools.wraps(func) def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" - from .datatree import DataTree + from xarray.core.datatree import DataTree all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ a for a in kwargs.values() if isinstance(a, DataTree) diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py index aef327c5c47..e6af9c85ee8 100644 --- a/xarray/datatree_/datatree/render.py +++ b/xarray/datatree_/datatree/render.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree Row = collections.namedtuple("Row", ("pre", "fill", "node")) diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py index 1cbcdf2d4e3..bf54116725a 100644 --- a/xarray/datatree_/datatree/testing.py +++ b/xarray/datatree_/datatree/testing.py @@ -1,6 +1,6 @@ from xarray.testing.assertions import ensure_warnings -from .datatree import DataTree +from xarray.core.datatree import DataTree from .formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py index bd2e7ba3247..53a9a72239d 100644 --- a/xarray/datatree_/datatree/tests/conftest.py +++ b/xarray/datatree_/datatree/tests/conftest.py @@ -1,7 +1,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree @pytest.fixture(scope="module") diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py index c3eb74451a6..4ca532ebba4 100644 --- a/xarray/datatree_/datatree/tests/test_dataset_api.py +++ b/xarray/datatree_/datatree/tests/test_dataset_api.py @@ -1,7 +1,7 @@ import numpy as np import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py index 0241e496abf..fb2e82453ec 100644 --- a/xarray/datatree_/datatree/tests/test_extensions.py +++ b/xarray/datatree_/datatree/tests/test_extensions.py @@ -1,6 +1,7 @@ import pytest -from xarray.datatree_.datatree import DataTree, register_datatree_accessor +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree.extensions import register_datatree_accessor class TestAccessor: diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py index b58c02282e7..77f8346ae72 100644 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ b/xarray/datatree_/datatree/tests/test_formatting.py @@ -2,7 +2,7 @@ from xarray import Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py index 943bbab4154..98cdf02bff4 100644 --- a/xarray/datatree_/datatree/tests/test_formatting_html.py +++ b/xarray/datatree_/datatree/tests/test_formatting_html.py @@ -1,7 +1,8 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree, formatting_html +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree import formatting_html @pytest.fixture(scope="module", params=["some html", "some other html"]) diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/datatree_/datatree/tests/test_mapping.py index 53d6e085440..c6cd04887c0 100644 --- a/xarray/datatree_/datatree/tests/test_mapping.py +++ b/xarray/datatree_/datatree/tests/test_mapping.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 8590c9fb4e7..a32b0e08bea 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -6,7 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.tests import create_test_data, requires_dask diff --git a/xarray/datatree_/datatree/tests/test_datatree.py b/xarray/tests/test_datatree.py similarity index 98% rename from xarray/datatree_/datatree/tests/test_datatree.py rename to xarray/tests/test_datatree.py index cfb57470651..e2fe18cad7d 100644 --- a/xarray/datatree_/datatree/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2,13 +2,14 @@ import numpy as np import pytest + import xarray as xr +import xarray.datatree_.datatree.testing as dtt import xarray.testing as xrt +from xarray.core.datatree import DataTree +from xarray.core.treenode import NotFoundInTreeError from xarray.tests import create_test_data, source_ndarray -import xarray.datatree_.datatree.testing as dtt -from xarray.datatree_.datatree import DataTree, NotFoundInTreeError - class TestTreeCreation: def test_empty(self): @@ -166,8 +167,7 @@ def test_assign_when_already_child_with_variables_name(self): dt.ds = new_ds -class TestGet: - ... +class TestGet: ... class TestGetItem: @@ -212,7 +212,9 @@ def test_getitem_multiple_data_variables(self): results = DataTree(name="results", data=data) xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) - @pytest.mark.xfail(reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)") + @pytest.mark.xfail( + reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" + ) def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results", data=data) @@ -450,8 +452,7 @@ def test_setitem_dataarray_replace_existing_node(self): xrt.assert_identical(results.to_dataset(), expected) -class TestDictionaryInterface: - ... +class TestDictionaryInterface: ... class TestTreeFromDict: