Skip to content

Commit

Permalink
Fix two bugs in DataTree.update() (#9214)
Browse files Browse the repository at this point in the history
* Fix two bugs in DataTree.update()

1. Fix handling of coordinates on a Dataset argument (previously these
   were silently dropped).
2. Do not copy inherited coordinates down to lower level nodes.

* add mypy annotation
  • Loading branch information
shoyer authored Jul 8, 2024
1 parent bac01c0 commit 179c670
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 26 deletions.
40 changes: 23 additions & 17 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
import pandas as pd

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleValue
from xarray.core.merge import CoercibleMapping, CoercibleValue
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes

# """
Expand Down Expand Up @@ -954,23 +954,29 @@ def update(
Just like `dict.update` this is an in-place operation.
"""
# TODO separate by type
new_children: dict[str, DataTree] = {}
new_variables = {}
for k, v in other.items():
if isinstance(v, DataTree):
# avoid named node being stored under inconsistent key
new_child: DataTree = v.copy()
# Datatree's name is always a string until we fix that (#8836)
new_child.name = str(k)
new_children[str(k)] = new_child
elif isinstance(v, (DataArray, Variable)):
# TODO this should also accommodate other types that can be coerced into Variables
new_variables[k] = v
else:
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")

vars_merge_result = dataset_update_method(self.to_dataset(), new_variables)
new_variables: CoercibleMapping

if isinstance(other, Dataset):
new_variables = other
else:
new_variables = {}
for k, v in other.items():
if isinstance(v, DataTree):
# avoid named node being stored under inconsistent key
new_child: DataTree = v.copy()
# Datatree's name is always a string until we fix that (#8836)
new_child.name = str(k)
new_children[str(k)] = new_child
elif isinstance(v, (DataArray, Variable)):
# TODO this should also accommodate other types that can be coerced into Variables
new_variables[k] = v
else:
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")

vars_merge_result = dataset_update_method(
self.to_dataset(inherited=False), new_variables
)
data = Dataset._construct_direct(**vars_merge_result._asdict())

# TODO are there any subtleties with preserving order of children like this?
Expand Down
37 changes: 28 additions & 9 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ def test_update(self):
dt: DataTree = DataTree()
dt.update({"foo": xr.DataArray(0), "a": DataTree()})
expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None})
print(dt)
print(dt.children)
print(dt._children)
print(dt["a"])
print(expected)
assert_equal(dt, expected)

def test_update_new_named_dataarray(self):
Expand All @@ -268,14 +263,38 @@ def test_update_doesnt_alter_child_name(self):
def test_update_overwrite(self):
actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))})
actual.update({"a": DataTree(xr.Dataset({"x": 2}))})

expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))})
assert_equal(actual, expected)

print(actual)
print(expected)

def test_update_coordinates(self):
expected = DataTree.from_dict({"/": xr.Dataset(coords={"a": 1})})
actual = DataTree.from_dict({"/": xr.Dataset()})
actual.update(xr.Dataset(coords={"a": 1}))
assert_equal(actual, expected)

def test_update_inherited_coords(self):
expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"a": 1}),
"/b": xr.Dataset(coords={"c": 1}),
}
)
actual = DataTree.from_dict(
{
"/": xr.Dataset(coords={"a": 1}),
"/b": xr.Dataset(),
}
)
actual["/b"].update(xr.Dataset(coords={"c": 1}))
assert_identical(actual, expected)

# DataTree.identical() currently does not require that non-inherited
# coordinates are defined identically, so we need to check this
# explicitly
actual_node = actual.children["b"].to_dataset(inherited=False)
expected_node = expected.children["b"].to_dataset(inherited=False)
assert_identical(actual_node, expected_node)


class TestCopy:
def test_copy(self, create_test_datatree):
Expand Down

0 comments on commit 179c670

Please sign in to comment.