diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 65949a24369..42cc8130810 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast @@ -7,7 +8,7 @@ import pandas as pd from . import formatting -from .indexes import Index, Indexes, assert_no_index_corrupted +from .indexes import Index, Indexes, PandasMultiIndex, assert_no_index_corrupted from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject from .variable import Variable, calculate_dimensions @@ -57,6 +58,9 @@ def variables(self): def _update_coords(self, coords, indexes): raise NotImplementedError() + def _maybe_drop_multiindex_coords(self, coords): + raise NotImplementedError() + def __iter__(self) -> Iterator[Hashable]: # needs to be in the same order as the dataset variables for k in self.variables: @@ -154,6 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: def update(self, other: Mapping[Any, Any]) -> None: other_vars = getattr(other, "variables", other) + self._maybe_drop_multiindex_coords(set(other_vars)) coords, indexes = merge_coords( [self.variables, other_vars], priority_arg=1, indexes=self.xindexes ) @@ -304,6 +309,15 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes + def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: + """Drops variables in coords, and any associated variables as well.""" + assert self._data.xindexes is not None + variables, indexes = drop_coords( + coords, self._data._variables, self._data.xindexes + ) + self._data._variables = variables + self._data._indexes = indexes + def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] @@ -372,6 +386,14 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes + def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: + """Drops variables in coords, and any associated variables as well.""" + variables, indexes = drop_coords( + coords, self._data._coords, self._data.xindexes + ) + self._data._coords = variables + self._data._indexes = indexes + @property def variables(self): return Frozen(self._data._coords) @@ -397,6 +419,37 @@ def _ipython_key_completions_(self): return self._data._ipython_key_completions_() +def drop_coords( + coords_to_drop: set[Hashable], variables, indexes: Indexes +) -> tuple[dict, dict]: + """Drop index variables associated with variables in coords_to_drop.""" + # Only warn when we're dropping the dimension with the multi-indexed coordinate + # If asked to drop a subset of the levels in a multi-index, we raise an error + # later but skip the warning here. + new_variables = dict(variables.copy()) + new_indexes = dict(indexes.copy()) + for key in coords_to_drop & set(indexes): + maybe_midx = indexes[key] + idx_coord_names = set(indexes.get_all_coords(key)) + if ( + isinstance(maybe_midx, PandasMultiIndex) + and key == maybe_midx.dim + and (idx_coord_names - coords_to_drop) + ): + warnings.warn( + f"Updating MultiIndexed coordinate {key!r} would corrupt indices for " + f"other variables: {list(maybe_midx.index.names)!r}. " + f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before " + "assigning new coordinate values.", + DeprecationWarning, + stacklevel=4, + ) + for k in idx_coord_names: + del new_variables[k] + del new_indexes[k] + return new_variables, new_indexes + + def assert_coordinate_consistent( obj: DataArray | Dataset, coords: Mapping[Any, Variable] ) -> None: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4849738f453..c677ee13c3d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5764,6 +5764,7 @@ def assign( data = self.copy() # do all calculations first... results: CoercibleMapping = data._calc_assign_results(variables) + data.coords._maybe_drop_multiindex_coords(set(results.keys())) # ... and then assign data.update(results) return data diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index d7133683d83..8ff0d40ff07 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1085,6 +1085,9 @@ def dims(self) -> Mapping[Hashable, int]: return Frozen(self._dims) + def copy(self): + return type(self)(dict(self._indexes), dict(self._variables)) + def get_unique(self) -> list[T_PandasOrXarrayIndex]: """Return a list of unique indexes, preserving order.""" diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index db3c9824ba3..298840f3f66 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1499,6 +1499,13 @@ def test_assign_coords(self) -> None: with pytest.raises(ValueError): da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray + def test_assign_coords_existing_multiindex(self) -> None: + data = self.mda + with pytest.warns( + DeprecationWarning, match=r"Updating MultiIndexed coordinate" + ): + data.assign_coords(x=range(4)) + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 459acfd87fa..9ea47163d05 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3967,6 +3967,18 @@ def test_assign_multiindex_level(self) -> None: data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) + def test_assign_coords_existing_multiindex(self) -> None: + data = create_test_multiindex() + with pytest.warns( + DeprecationWarning, match=r"Updating MultiIndexed coordinate" + ): + data.assign_coords(x=range(4)) + + with pytest.warns( + DeprecationWarning, match=r"Updating MultiIndexed coordinate" + ): + data.assign(x=range(4)) + def test_assign_all_multiindex_coords(self) -> None: data = create_test_multiindex() actual = data.assign(x=range(4), level_1=range(4), level_2=range(4))