Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop multi-indexes when assigning to a multi-indexed variable #6798

Merged
merged 9 commits into from
Jul 21, 2022
55 changes: 54 additions & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast

import numpy as np
import pandas as pd

from . import formatting
from .indexes import Index, Indexes, assert_no_index_corrupted
from .indexes import Index, Indexes, PandasMultiIndex, assert_no_index_corrupted
from .merge import merge_coordinates_without_align, merge_coords
from .utils import Frozen, ReprObject
from .variable import Variable, calculate_dimensions
Expand Down Expand Up @@ -57,6 +58,9 @@ def variables(self):
def _update_coords(self, coords, indexes):
raise NotImplementedError()

def _maybe_drop_multiindex_coords(self, coords):
raise NotImplementedError()

def __iter__(self) -> Iterator[Hashable]:
# needs to be in the same order as the dataset variables
for k in self.variables:
Expand Down Expand Up @@ -154,6 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:

def update(self, other: Mapping[Any, Any]) -> None:
other_vars = getattr(other, "variables", other)
self._maybe_drop_multiindex_coords(set(other_vars))
coords, indexes = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.xindexes
)
Expand Down Expand Up @@ -304,6 +309,15 @@ def _update_coords(
original_indexes.update(indexes)
self._data._indexes = original_indexes

def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
"""Drops variables in coords, and any associated variables as well."""
assert self._data.xindexes is not None
variables, indexes = drop_coords(
coords, self._data._variables, self._data.xindexes
)
self._data._variables = variables
self._data._indexes = indexes

def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key]
Expand Down Expand Up @@ -372,6 +386,14 @@ def _update_coords(
original_indexes.update(indexes)
self._data._indexes = original_indexes

def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
"""Drops variables in coords, and any associated variables as well."""
variables, indexes = drop_coords(
coords, self._data._coords, self._data.xindexes
)
self._data._coords = variables
self._data._indexes = indexes

@property
def variables(self):
return Frozen(self._data._coords)
Expand All @@ -397,6 +419,37 @@ def _ipython_key_completions_(self):
return self._data._ipython_key_completions_()


def drop_coords(
coords_to_drop: set[Hashable], variables, indexes: Indexes
) -> tuple[dict, dict]:
"""Drop index variables associated with variables in coords_to_drop."""
# Only warn when we're dropping the dimension with the multi-indexed coordinate
# If asked to drop a subset of the levels in a multi-index, we raise an error
# later but skip the warning here.
new_variables = dict(variables.copy())
new_indexes = dict(indexes.copy())
for key in coords_to_drop & set(indexes):
maybe_midx = indexes[key]
idx_coord_names = set(indexes.get_all_coords(key))
if (
isinstance(maybe_midx, PandasMultiIndex)
and key == maybe_midx.dim
and (idx_coord_names - coords_to_drop)
):
warnings.warn(
f"Updating MultiIndexed coordinate {key!r} would corrupt indices for "
f"other variables: {list(maybe_midx.index.names)!r}. "
f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before "
"assigning new coordinate values.",
DeprecationWarning,
stacklevel=4,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be 4 for assign_coords and 5 for Dataset.assign but perhaps this is OK.

)
for k in idx_coord_names:
del new_variables[k]
del new_indexes[k]
return new_variables, new_indexes


def assert_coordinate_consistent(
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
) -> None:
Expand Down
1 change: 1 addition & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5764,6 +5764,7 @@ def assign(
data = self.copy()
# do all calculations first...
results: CoercibleMapping = data._calc_assign_results(variables)
data.coords._maybe_drop_multiindex_coords(set(results.keys()))
# ... and then assign
data.update(results)
return data
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,9 @@ def dims(self) -> Mapping[Hashable, int]:

return Frozen(self._dims)

def copy(self):
return type(self)(dict(self._indexes), dict(self._variables))

def get_unique(self) -> list[T_PandasOrXarrayIndex]:
"""Return a list of unique indexes, preserving order."""

Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,13 @@ def test_assign_coords(self) -> None:
with pytest.raises(ValueError):
da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray

def test_assign_coords_existing_multiindex(self) -> None:
data = self.mda
with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign_coords(x=range(4))

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3967,6 +3967,18 @@ def test_assign_multiindex_level(self) -> None:
data.assign(level_1=range(4))
data.assign_coords(level_1=range(4))

def test_assign_coords_existing_multiindex(self) -> None:
data = create_test_multiindex()
with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign_coords(x=range(4))

with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign(x=range(4))

def test_assign_all_multiindex_coords(self) -> None:
data = create_test_multiindex()
actual = data.assign(x=range(4), level_1=range(4), level_2=range(4))
Expand Down