Rename DataTree's "ds" and "data" to "dataset" #9476

Merged (3 commits) on Sep 11, 2024
2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
@@ -1262,7 +1262,7 @@ def open_datatree(
ds = open_dataset(
filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs
)
new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
new_node = DataTree(name=NodePath(path_group).name, dataset=ds)
tree_root._set_item(
path_group,
new_node,
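For orientation, a minimal usage sketch of the renamed constructor keyword (the group name and dataset contents here are made up, and the import path reflects the in-progress migration of datatree into xarray):

```python
import xarray as xr
from xarray.core.datatree import DataTree  # import path during the migration

ds = xr.Dataset({"temperature": ("time", [280.0, 281.5, 279.8])})

# previously: DataTree(name="group1", data=ds)
node = DataTree(name="group1", dataset=ds)
print(node.dataset)  # node data is now exposed via the renamed property
```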
66 changes: 35 additions & 31 deletions xarray/core/datatree.py
@@ -417,7 +417,7 @@ class DataTree(

def __init__(
self,
data: Dataset | None = None,
dataset: Dataset | None = None,
children: Mapping[str, DataTree] | None = None,
name: str | None = None,
):
@@ -430,12 +430,12 @@ def __init__(

Parameters
----------
data : Dataset, optional
Data to store under the .ds attribute of this node.
dataset : Dataset, optional
Data to store directly at this node.
children : Mapping[str, DataTree], optional
Any child nodes of this node. Default is None.
Any child nodes of this node.
name : str, optional
Name for this node of the tree. Default is None.
Name for this node of the tree.

Returns
-------
@@ -449,24 +449,24 @@ def __init__(
children = {}

super().__init__(name=name)
self._set_node_data(_to_new_dataset(data))
self._set_node_data(_to_new_dataset(dataset))

# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
self.children = {name: child.copy() for name, child in children.items()}

def _set_node_data(self, ds: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(ds)
def _set_node_data(self, dataset: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
self._data_variables = data_vars
self._node_coord_variables = coord_vars
self._node_dims = ds._dims
self._node_indexes = ds._indexes
self._encoding = ds._encoding
self._attrs = ds._attrs
self._close = ds._close
self._node_dims = dataset._dims
self._node_indexes = dataset._indexes
self._encoding = dataset._encoding
self._attrs = dataset._attrs
self._close = dataset._close

def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
super()._pre_attach(parent, name)
if name in parent.ds.variables:
if name in parent.dataset.variables:
raise KeyError(
f"parent {parent.name} already contains a variable named {name}"
)
@@ -534,7 +534,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView:
)

@property
def ds(self) -> DatasetView:
def dataset(self) -> DatasetView:
"""
An immutable Dataset-like view onto the data in this node.

@@ -549,11 +549,15 @@ def ds(self) -> DatasetView:
"""
return self._to_dataset_view(rebuild_dims=True, inherited=True)

@ds.setter
def ds(self, data: Dataset | None = None) -> None:
@dataset.setter
def dataset(self, data: Dataset | None = None) -> None:
ds = _to_new_dataset(data)
self._replace_node(ds)

# soft-deprecated alias, to facilitate the transition from
# xarray-contrib/datatree
ds = dataset

def to_dataset(self, inherited: bool = True) -> Dataset:
"""
Return the data in this node as a new xarray.Dataset object.
@@ -566,7 +570,7 @@ def to_dataset(self, inherited: bool = True) -> Dataset:

See Also
--------
DataTree.ds
DataTree.dataset
"""
coord_vars = self._coord_variables if inherited else self._node_coord_variables
variables = dict(self._data_variables)
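A hedged sketch of the renamed property, its setter, and the soft-deprecated alias (variable names are illustrative):

```python
import xarray as xr
from xarray.core.datatree import DataTree

node = DataTree(dataset=xr.Dataset({"a": ("x", [1, 2, 3])}))

view = node.dataset           # immutable DatasetView onto this node's data
snapshot = node.to_dataset()  # independent xarray.Dataset copy, including inherited coords

# the setter replaces the data stored directly at this node
node.dataset = xr.Dataset({"b": ("y", [4.0, 5.0])})

# the old name remains available as a soft-deprecated alias during the transition
assert list(node.ds.data_vars) == ["b"]
```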
@@ -845,8 +849,8 @@ def get( # type: ignore[override]
"""
if key in self.children:
return self.children[key]
elif key in self.ds:
return self.ds[key]
elif key in self.dataset:
return self.dataset[key]
else:
return default

@@ -1114,7 +1118,7 @@ def from_dict(
if isinstance(root_data, DataTree):
obj = root_data.copy()
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, data=root_data, children=None)
obj = cls(name=name, dataset=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
@@ -1133,7 +1137,7 @@ def depth(item) -> int:
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, data=data)
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
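For reference, a small hedged example of from_dict, which now passes node data through the renamed dataset keyword internally (paths and values are illustrative):

```python
import xarray as xr
from xarray.core.datatree import DataTree

dt = DataTree.from_dict(
    {
        "/": xr.Dataset({"a": ("x", [1, 2])}),
        "/child": xr.Dataset({"b": ("y", [3.0, 4.0])}),
        "/child/empty": None,  # nodes without data are allowed
    },
    name="root",
)
print(dt)
```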
@@ -1264,7 +1268,7 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool:

return all(
[
node.ds.equals(other_node.ds)
node.dataset.equals(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
]
)
@@ -1294,7 +1298,7 @@ def identical(self, other: DataTree, from_root=True) -> bool:
return False

return all(
node.ds.identical(other_node.ds)
node.dataset.identical(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
)
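A brief sketch of the node-wise comparison, now routed through .dataset (tree contents are illustrative):

```python
import xarray as xr
from xarray.core.datatree import DataTree

dt1 = DataTree.from_dict({"/": xr.Dataset({"a": ("x", [1, 2])})})
dt2 = dt1.copy()

assert dt1.equals(dt2)     # compares node.dataset pairs across both subtrees
assert dt1.identical(dt2)  # additionally requires matching names and attributes
```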

@@ -1321,7 +1325,7 @@ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
map_over_subtree
"""
filtered_nodes = {
node.path: node.ds for node in self.subtree if filterfunc(node)
node.path: node.dataset for node in self.subtree if filterfunc(node)
}
return DataTree.from_dict(filtered_nodes, name=self.root.name)

@@ -1365,7 +1369,7 @@ def match(self, pattern: str) -> DataTree:
└── Group: /b/B
"""
matching_nodes = {
node.path: node.ds
node.path: node.dataset
for node in self.subtree
if NodePath(node.path).match(pattern)
}
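A hedged usage sketch of filter and match, both of which now collect node.dataset for each selected node (the tree layout is made up):

```python
import xarray as xr
from xarray.core.datatree import DataTree

dt = DataTree.from_dict(
    {
        "/a/A": xr.Dataset({"v": ("x", [1, 2, 3])}),
        "/a/B": xr.Dataset({"v": ("x", [4, 5, 6])}),
        "/b/B": xr.Dataset({"v": ("x", [7, 8, 9])}),
    }
)

with_data = dt.filter(lambda node: "v" in node.dataset)  # keep only nodes carrying the variable
b_groups = dt.match("*/B")  # unix-like glob matched against each node's path
```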
@@ -1389,7 +1393,7 @@ def map_over_subtree(
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
`func(node.dataset, *args, **kwargs) -> Dataset`.

Function will not be applied to any nodes without datasets.
*args : tuple, optional
@@ -1420,7 +1424,7 @@ def map_over_subtree_inplace(
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
`func(node.dataset, *args, **kwargs) -> Dataset`.

Function will not be applied to any nodes without datasets.
*args : tuple, optional
@@ -1433,7 +1437,7 @@ def map_over_subtree_inplace(

for node in self.subtree:
if node.has_data:
node.ds = func(node.ds, *args, **kwargs)
node.dataset = func(node.dataset, *args, **kwargs)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
@@ -1499,7 +1503,7 @@ def render(self):
"""Print tree structure, including any data stored at each node."""
for pre, fill, node in RenderDataTree(self):
print(f"{pre}DataTree('{self.name}')")
for ds_line in repr(node.ds)[1:]:
for ds_line in repr(node.dataset)[1:]:
print(f"{fill}{ds_line}")

def merge(self, datatree: DataTree) -> DataTree:
@@ -1513,7 +1517,7 @@ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
# TODO some kind of .collapse() or .flatten() method to merge a subtree

def to_dataarray(self) -> DataArray:
return self.ds.to_dataarray()
return self.dataset.to_dataarray()

@property
def groups(self):
11 changes: 6 additions & 5 deletions xarray/core/datatree_mapping.py
@@ -99,10 +99,10 @@ def map_over_subtree(func: Callable) -> Callable:
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`. If DataTrees, any data-containing nodes will be converted to Datasets
via `.ds`.
via `.dataset`.
**kwargs : Any
Keyword arguments passed on to `func`. If DataTrees, any data-containing nodes will be converted to Datasets
via `.ds`.
via `.dataset`.

Returns
-------
@@ -160,13 +160,14 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
strict=False,
):
node_args_as_datasetviews = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
a.dataset if isinstance(a, DataTree) else a
for a in all_node_args[:n_args]
]
node_kwargs_as_datasetviews = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.ds if isinstance(v, DataTree) else v
v.dataset if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
strict=True,
@@ -183,7 +184,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
)
elif node_of_first_tree.has_attrs:
# propagate attrs
results = node_of_first_tree.ds
results = node_of_first_tree.dataset
else:
# nothing to propagate so use fastpath to create empty node in new tree
results = None
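As a hedged sketch, the decorator form maps a Dataset-level function over every data-bearing node, handing each node's contents to the function via .dataset (the function and tree below are illustrative):

```python
import xarray as xr
from xarray.core.datatree import DataTree
from xarray.core.datatree_mapping import map_over_subtree

@map_over_subtree
def anomaly(ds: xr.Dataset) -> xr.Dataset:
    # receives each node's data as a Dataset-like object and must return a Dataset
    return ds - ds.mean()

dt = DataTree.from_dict({"/obs": xr.Dataset({"t": ("time", [1.0, 2.0, 3.0])})})
result = anomaly(dt)  # a new DataTree with the function applied node by node
```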
2 changes: 1 addition & 1 deletion xarray/core/formatting.py
@@ -1038,7 +1038,7 @@ def diff_nodewise_summary(a: DataTree, b: DataTree, compat):

summary = []
for node_a, node_b in zip(a.subtree, b.subtree, strict=True):
a_ds, b_ds = node_a.ds, node_b.ds
a_ds, b_ds = node_a.dataset, node_b.dataset

if not a_ds._all_compat(b_ds, compat):
dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str)
4 changes: 2 additions & 2 deletions xarray/tests/test_backends_datatree.py
@@ -56,7 +56,7 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):

# add compression
comp = dict(zlib=True, complevel=9)
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}

original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
roundtrip_dt = open_datatree(filepath, engine=self.engine)
@@ -246,7 +246,7 @@ def test_zarr_encoding(self, tmpdir, simple_datatree):
original_dt = simple_datatree

comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
original_dt.to_zarr(filepath, encoding=enc)
roundtrip_dt = open_datatree(filepath, engine="zarr")

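For completeness, a hedged sketch of the encoding pattern exercised by these tests, with made-up variable names and a local file path (open_datatree lived under xarray.backends.api during this migration):

```python
import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree

dt = DataTree.from_dict({"/set2": xr.Dataset({"a": ("x", [1.0, 2.0, 3.0])})})

# per-group, per-variable encoding keyed by node path
comp = dict(zlib=True, complevel=9)
enc = {"/set2": {var: comp for var in dt["/set2"].dataset.data_vars}}

dt.to_netcdf("example.nc", encoding=enc, engine="netcdf4")
roundtrip = open_datatree("example.nc", engine="netcdf4")
```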