Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAS-2067 - Migrate datatree io.py and common.py #9011

Merged
merged 11 commits into from
Jun 12, 2024
Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ Internal Changes
rather than ``dims`` or ``dimensions``. This is the final change to make xarray methods
consistent with their use of ``dim``. Using the existing kwarg will raise a
warning. By `Maximilian Roos <https://github.com/max-sixty>`_
- Migrates remainder of ``io.py`` to ``xarray/core/datatree_io.py`` and
``TreeAttrAccessMixin`` into ``xarray/core/common.py`` (:pull: `9011`)
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
18 changes: 9 additions & 9 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.types import ZarrWriteModes
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import guess_chunkmanager
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def open_mfdataset(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1139,7 +1139,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1156,7 +1156,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1174,7 +1174,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1192,7 +1192,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1210,7 +1210,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1227,7 +1227,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1242,7 +1242,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
28 changes: 28 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,34 @@ def _ipython_key_completions_(self) -> list[str]:
return list(items)


class TreeAttrAccessMixin(AttrAccessMixin):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here's the biggest question I have... is this okay for now? (Maybe this is directed at @TomNicholas)

The use of __slots__ is a little new to me (but makes sense from reading up on it). I think to adopt __slots__ and remove __dict__ I would need to slightly rework DataTree, TreeNode and NamedNode. I wanted to check that the scope mainly seems to be self.children and self.parent attributes, and understand why those classes were implemented as they currently are before trying to alter things.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to embarrass me, but why did you stop calling init_subclass on the Parent object? everything else with this merge seems reasonable to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I accidentally commented out that last line. Will fix!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, my last comment was true, but... uncommenting out that line means that the same method in the parent class (AttrAccessMixin) gets called by super in addition to this method, right? That would mean we end up calling the validation logic that this child method is trying to avoid for now (checking that __dict__ is not present on the class).

Now if only I'd said that the first time around 😉

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this class necessary? Can't DataTree inherit directly from AttrAccessMixin? And override _attr_sources and so on?

I think the reason this class existed separately was just to do with hacking around limitations of datatree being in a separate repository...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially tried to use AttrAccessMixin. Just dropping it in as a replacement for TreeAttrAccessMixin doesn't work because of the checking done inside the AttrAccessMixin.__init_subclass__ method. The check for if __dict__ exists fails. DataTree, TreeNode and NamedNode do all define __slots__, but I think there are attributes being defined in places like DataTree.__init__, for example DataTree.parent and DataTree.children (versus the DataTree._parent and DataTree._children attributes defined in DataTree.__slots__).

I think we talked about this a little bit in one of our calls - things like DataTree.parent and DataTree.children seemed like tricky things to change over.

"""Mixin class that allows getting keys with attribute access"""

# TODO: Ensure ipython tab completion can include both child datatrees and
# variables from Dataset objects on relevant nodes.

__slots__ = ()

def __init_subclass__(cls, **kwargs):
"""Verify that all subclasses explicitly define ``__slots__``. If they don't,
raise error in the core xarray module and a FutureWarning in third-party
extensions.
"""
if not hasattr(object.__new__(cls), "__dict__"):
pass
# TODO Rework DataTree to avoid __dict__.
# elif cls.__module__.startswith("xarray."):
# raise AttributeError(f"{cls.__name__} must explicitly define __slots__")
# else:
# cls.__setattr__ = cls._setattr_dict
# warnings.warn(
# f"xarray subclass {cls.__name__} should explicitly define __slots__",
# FutureWarning,
# stacklevel=2,
# )
# super().__init_subclass__(**kwargs)
owenlittlejohns marked this conversation as resolved.
Show resolved Hide resolved


def get_squeeze_dims(
xarray_obj,
dim: Hashable | Iterable[Hashable] | None = None,
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
DaCompatible,
NetcdfWriteModes,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
Expand Down Expand Up @@ -3943,7 +3944,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
def to_netcdf(
self,
path: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3958,7 +3959,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3974,7 +3975,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3990,7 +3991,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -4003,7 +4004,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
NetcdfWriteModes,
QuantileMethods,
Self,
T_ChunkDim,
Expand Down Expand Up @@ -2171,7 +2172,7 @@ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None:
def to_netcdf(
self,
path: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2186,7 +2187,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2202,7 +2203,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2218,7 +2219,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2231,7 +2232,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
17 changes: 11 additions & 6 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from xarray.core import utils
from xarray.core.common import TreeAttrAccessMixin
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
Expand Down Expand Up @@ -46,7 +47,6 @@
maybe_wrap_array,
)
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin

try:
from xarray.core.variable import calculate_dimensions
Expand All @@ -58,7 +58,7 @@
import pandas as pd

from xarray.core.merge import CoercibleValue
from xarray.core.types import ErrorOptions
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -1475,7 +1475,12 @@ def groups(self):
return tuple(node.path for node in self.subtree)

def to_netcdf(
self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
self,
filepath,
mode: NetcdfWriteModes = "w",
encoding=None,
unlimited_dims=None,
**kwargs,
):
"""
Write datatree contents to a netCDF file.
Expand All @@ -1502,7 +1507,7 @@ def to_netcdf(
kwargs :
Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
"""
from xarray.datatree_.datatree.io import _datatree_to_netcdf
from xarray.core.datatree_io import _datatree_to_netcdf

_datatree_to_netcdf(
self,
Expand All @@ -1516,7 +1521,7 @@ def to_netcdf(
def to_zarr(
self,
store,
mode: str = "w-",
mode: ZarrWriteModes = "w-",
encoding=None,
consolidated: bool = True,
**kwargs,
Expand Down Expand Up @@ -1544,7 +1549,7 @@ def to_zarr(
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``
"""
from xarray.datatree_.datatree.io import _datatree_to_zarr
from xarray.core.datatree_io import _datatree_to_zarr

_datatree_to_zarr(
self,
Expand Down
45 changes: 34 additions & 11 deletions xarray/datatree_/datatree/io.py → xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from __future__ import annotations

from collections.abc import Mapping, MutableMapping
from os import PathLike
from typing import Any

from xarray.core.datatree import DataTree
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes


def _get_nc_dataset_class(engine):
def _get_nc_dataset_class(engine: str | None):
owenlittlejohns marked this conversation as resolved.
Show resolved Hide resolved
if engine == "netcdf4":
from netCDF4 import Dataset
elif engine == "h5netcdf":
Expand All @@ -16,7 +23,9 @@ def _get_nc_dataset_class(engine):
return Dataset


def _create_empty_netcdf_group(filename, group, mode, engine):
def _create_empty_netcdf_group(
filename: str | PathLike, group: str, mode: NetcdfWriteModes, engine: str | None
):
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
Expand All @@ -25,12 +34,18 @@ def _create_empty_netcdf_group(filename, group, mode, engine):

def _datatree_to_netcdf(
dt: DataTree,
filepath,
mode: str = "w",
encoding=None,
unlimited_dims=None,
filepath: str | PathLike,
mode: NetcdfWriteModes = "w",
encoding: Mapping[str, Any] | None = None,
unlimited_dims: Mapping | None = None,
**kwargs,
):
"""This function creates an appropriate datastore for writing a datatree to
disk as a netCDF file.

See `DataTree.to_netcdf` for full API docs.
"""

if kwargs.get("format", None) not in [None, "NETCDF4"]:
raise ValueError("to_netcdf only supports the NETCDF4 format")

Expand Down Expand Up @@ -74,10 +89,12 @@ def _datatree_to_netcdf(
unlimited_dims=unlimited_dims.get(node.path),
**kwargs,
)
mode = "r+"
mode = "a"
owenlittlejohns marked this conversation as resolved.
Show resolved Hide resolved


def _create_empty_zarr_group(store, group, mode):
def _create_empty_zarr_group(
store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
):
import zarr

root = zarr.open_group(store, mode=mode)
Expand All @@ -86,12 +103,18 @@ def _create_empty_zarr_group(store, group, mode):

def _datatree_to_zarr(
dt: DataTree,
store,
mode: str = "w-",
encoding=None,
store: MutableMapping | str | PathLike[str],
mode: ZarrWriteModes = "w-",
encoding: Mapping[str, Any] | None = None,
consolidated: bool = True,
**kwargs,
):
"""This function creates an appropriate datastore for writing a datatree
to a zarr store.

See `DataTree.to_zarr` for full API docs.
"""

from zarr.convenience import consolidate_metadata

if kwargs.get("group", None) is not None:
Expand Down
1 change: 1 addition & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,5 @@ def copy(
]


NetcdfWriteModes = Literal["w", "a"]
owenlittlejohns marked this conversation as resolved.
Show resolved Hide resolved
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
Loading
Loading