diff --git a/doc/api.rst b/doc/api.rst
index 11ae5de8531..c3488389d4c 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -107,6 +107,7 @@ Dataset contents
    Dataset.swap_dims
    Dataset.expand_dims
    Dataset.drop_vars
+   Dataset.drop_indexes
    Dataset.drop_duplicates
    Dataset.drop_dims
    Dataset.set_coords
@@ -146,6 +147,7 @@ Indexing
    Dataset.reindex_like
    Dataset.set_index
    Dataset.reset_index
+   Dataset.set_xindex
    Dataset.reorder_levels
    Dataset.query
 
@@ -298,6 +300,7 @@ DataArray contents
    DataArray.swap_dims
    DataArray.expand_dims
    DataArray.drop_vars
+   DataArray.drop_indexes
    DataArray.drop_duplicates
    DataArray.reset_coords
    DataArray.copy
@@ -330,6 +333,7 @@ Indexing
    DataArray.reindex_like
    DataArray.set_index
    DataArray.reset_index
+   DataArray.set_xindex
    DataArray.reorder_levels
    DataArray.query
 
@@ -1080,6 +1084,7 @@ Advanced API
    Variable
    IndexVariable
    as_variable
+   indexes.Index
    Context
    register_dataset_accessor
    register_dataarray_accessor
@@ -1087,6 +1092,11 @@ Advanced API
    backends.BackendArray
    backends.BackendEntrypoint
 
+Default, pandas-backed indexes built into Xarray:
+
+   indexes.PandasIndex
+   indexes.PandasMultiIndex
+
 These backends provide a low-level interface for lazily loading data from
 external file-formats or protocols, and can be manually invoked to create
 arguments for the ``load_store`` and ``dump_to_store`` Dataset methods:
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 47086a687b8..a060fe8a459 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,6 +21,11 @@ v2022.07.0 (unreleased)
 
 New Features
 ~~~~~~~~~~~~
+
+- Add :py:meth:`Dataset.set_xindex` and :py:meth:`Dataset.drop_indexes` and
+  their DataArray counterparts for setting and dropping pandas or custom indexes
+  given a set of arbitrary coordinates. (:pull:`6971`)
+  By `Benoît Bovy `_ and `Justus Magin `_.
 - Enable taking the mean of dask-backed :py:class:`cftime.datetime` arrays
   (:pull:`6556`, :pull:`6940`).
   By `Deepak Cherian `_ and `Spencer Clark
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 6c09d8c15b4..f98879b689c 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -2201,6 +2201,11 @@ def set_index(
         """Set DataArray (multi-)indexes using one or more existing
         coordinates.
 
+        This legacy method is limited to pandas (multi-)indexes and
+        1-dimensional "dimension" coordinates. See
+        :py:meth:`~DataArray.set_xindex` for setting a pandas or a custom
+        Xarray-compatible index from one or more arbitrary coordinates.
+
         Parameters
         ----------
         indexes : {dim: index, ...}
@@ -2245,6 +2250,7 @@
         See Also
         --------
         DataArray.reset_index
+        DataArray.set_xindex
         """
         ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs)
         return self._from_temp_dataset(ds)
@@ -2258,6 +2264,12 @@
     ) -> DataArray:
         """Reset the specified index(es) or multi-index level(s).
 
+        This legacy method is specific to pandas (multi-)indexes and
+        1-dimensional "dimension" coordinates. See the more generic
+        :py:meth:`~DataArray.drop_indexes` and :py:meth:`~DataArray.set_xindex`
+        methods for dropping and setting pandas or custom indexes for
+        arbitrary coordinates, respectively.
+
         Parameters
         ----------
         dims_or_levels : Hashable or sequence of Hashable
@@ -2276,10 +2288,41 @@
         See Also
         --------
         DataArray.set_index
+        DataArray.set_xindex
+        DataArray.drop_indexes
         """
         ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop)
         return self._from_temp_dataset(ds)
 
+    def set_xindex(
+        self: T_DataArray,
+        coord_names: str | Sequence[Hashable],
+        index_cls: type[Index] | None = None,
+        **options,
+    ) -> T_DataArray:
+        """Set a new, Xarray-compatible index from one or more existing
+        coordinate(s).
+
+        Parameters
+        ----------
+        coord_names : str or list
+            Name(s) of the coordinate(s) used to build the index.
+            If several names are given, their order matters.
+        index_cls : subclass of :class:`~xarray.indexes.Index`, optional
+            The type of index to create. By default, try setting
+            a pandas (multi-)index from the supplied coordinates.
+        **options
+            Options passed to the index constructor.
+
+        Returns
+        -------
+        obj : DataArray
+            Another dataarray, with this dataarray's data and with a new index.
+
+        """
+        ds = self._to_temp_dataset().set_xindex(coord_names, index_cls, **options)
+        return self._from_temp_dataset(ds)
+
     def reorder_levels(
         self: T_DataArray,
         dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
@@ -2590,6 +2633,31 @@
         ds = self._to_temp_dataset().drop_vars(names, errors=errors)
         return self._from_temp_dataset(ds)
 
+    def drop_indexes(
+        self: T_DataArray,
+        coord_names: Hashable | Iterable[Hashable],
+        *,
+        errors: ErrorOptions = "raise",
+    ) -> T_DataArray:
+        """Drop the indexes assigned to the given coordinates.
+
+        Parameters
+        ----------
+        coord_names : hashable or iterable of hashable
+            Name(s) of the coordinate(s) for which to drop the index.
+        errors : {"raise", "ignore"}, default: "raise"
+            If 'raise', raises a ValueError if any of the coordinates
+            passed have no index or are not in the dataarray.
+            If 'ignore', no error is raised.
+
+        Returns
+        -------
+        dropped : DataArray
+            A new dataarray with dropped indexes.
+        """
+        ds = self._to_temp_dataset().drop_indexes(coord_names, errors=errors)
+        return self._from_temp_dataset(ds)
+
     def drop(
         self: T_DataArray,
         labels: Mapping[Any, Any] | None = None,
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index c500b537de3..7a73979cef9 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3974,6 +3974,11 @@ def set_index(
         """Set Dataset (multi-)indexes using one or more existing coordinates
         or variables.
 
+        This legacy method is limited to pandas (multi-)indexes and
+        1-dimensional "dimension" coordinates. See
+        :py:meth:`~Dataset.set_xindex` for setting a pandas or a custom
+        Xarray-compatible index from one or more arbitrary coordinates.
+
         Parameters
         ----------
         indexes : {dim: index, ...}
@@ -4021,6 +4026,7 @@
         See Also
         --------
         Dataset.reset_index
+        Dataset.set_xindex
         Dataset.swap_dims
         """
         dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
@@ -4067,7 +4073,7 @@
                     f"dimension mismatch: try setting an index for dimension {dim!r} with "
                     f"variable {var_name!r} that has dimensions {var.dims}"
                 )
-                idx = PandasIndex.from_variables({dim: var})
+                idx = PandasIndex.from_variables({dim: var}, options={})
                 idx_vars = idx.create_variables({var_name: var})
 
                 # trick to preserve coordinate order in this case
@@ -4129,6 +4135,12 @@ def reset_index(
     ) -> T_Dataset:
         """Reset the specified index(es) or multi-index level(s).
 
+        This legacy method is specific to pandas (multi-)indexes and
+        1-dimensional "dimension" coordinates. See the more generic
+        :py:meth:`~Dataset.drop_indexes` and :py:meth:`~Dataset.set_xindex`
+        methods for dropping and setting pandas or custom indexes for
+        arbitrary coordinates, respectively.
+
         Parameters
         ----------
         dims_or_levels : Hashable or Sequence of Hashable
@@ -4146,6 +4158,8 @@
         See Also
         --------
         Dataset.set_index
+        Dataset.set_xindex
+        Dataset.drop_indexes
         """
         if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence):
             dims_or_levels = [dims_or_levels]
@@ -4225,6 +4239,118 @@ def drop_or_convert(var_names):
             variables, coord_names=coord_names, indexes=indexes
         )
 
+    def set_xindex(
+        self: T_Dataset,
+        coord_names: str | Sequence[Hashable],
+        index_cls: type[Index] | None = None,
+        **options,
+    ) -> T_Dataset:
+        """Set a new, Xarray-compatible index from one or more existing
+        coordinate(s).
+
+        Parameters
+        ----------
+        coord_names : str or list
+            Name(s) of the coordinate(s) used to build the index.
+            If several names are given, their order matters.
+        index_cls : subclass of :class:`~xarray.indexes.Index`, optional
+            The type of index to create. By default, try setting
+            a ``PandasIndex`` if ``len(coord_names) == 1``,
+            otherwise a ``PandasMultiIndex``.
+        **options
+            Options passed to the index constructor.
+
+        Returns
+        -------
+        obj : Dataset
+            Another dataset, with this dataset's data and with a new index.
+
+        """
+        # the Sequence check is required for mypy
+        if is_scalar(coord_names) or not isinstance(coord_names, Sequence):
+            coord_names = [coord_names]
+
+        if index_cls is None:
+            if len(coord_names) == 1:
+                index_cls = PandasIndex
+            else:
+                index_cls = PandasMultiIndex
+        else:
+            if not issubclass(index_cls, Index):
+                raise TypeError(f"{index_cls} is not a subclass of xarray.Index")
+
+        invalid_coords = set(coord_names) - self._coord_names
+
+        if invalid_coords:
+            msg = ["invalid coordinate(s)"]
+            no_vars = invalid_coords - set(self._variables)
+            data_vars = invalid_coords - no_vars
+            if no_vars:
+                msg.append(f"those variables don't exist: {no_vars}")
+            if data_vars:
+                msg.append(
+                    f"those variables are data variables: {data_vars}, use `set_coords` first"
+                )
+            raise ValueError("\n".join(msg))
+
+        # we could be more clever here (e.g., drop-in index replacement if index
+        # coordinates do not conflict), but let's not allow this for now
+        indexed_coords = set(coord_names) & set(self._indexes)
+
+        if indexed_coords:
+            raise ValueError(
+                f"those coordinates already have an index: {indexed_coords}"
+            )
+
+        coord_vars = {name: self._variables[name] for name in coord_names}
+
+        index = index_cls.from_variables(coord_vars, options=options)
+
+        new_coord_vars = index.create_variables(coord_vars)
+
+        # special case for setting a pandas multi-index from level coordinates
+        # TODO: remove it once we deprecate the pandas multi-index dimension (tuple
+        # elements) coordinate
+        if isinstance(index, PandasMultiIndex):
+            coord_names = [index.dim] + list(coord_names)
+
+        variables: dict[Hashable, Variable]
+        indexes: dict[Hashable, Index]
+
+        if len(coord_names) == 1:
+            variables = self._variables.copy()
+            indexes = self._indexes.copy()
+
+            name = list(coord_names).pop()
+            if name in new_coord_vars:
+                variables[name] = new_coord_vars[name]
+            indexes[name] = index
+        else:
+            # reorder variables and indexes so that coordinates having the same
+            # index are next to each other
+            variables = {}
+            for name, var in self._variables.items():
+                if name not in coord_names:
+                    variables[name] = var
+
+            indexes = {}
+            for name, idx in self._indexes.items():
+                if name not in coord_names:
+                    indexes[name] = idx
+
+            for name in coord_names:
+                try:
+                    variables[name] = new_coord_vars[name]
+                except KeyError:
+                    variables[name] = self._variables[name]
+                indexes[name] = index
+
+        return self._replace(
+            variables=variables,
+            coord_names=self._coord_names | set(coord_names),
+            indexes=indexes,
+        )
+
     def reorder_levels(
         self: T_Dataset,
         dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
@@ -4951,6 +5077,59 @@ def drop_vars(
             variables, coord_names=coord_names, indexes=indexes
         )
 
+    def drop_indexes(
+        self: T_Dataset,
+        coord_names: Hashable | Iterable[Hashable],
+        *,
+        errors: ErrorOptions = "raise",
+    ) -> T_Dataset:
+        """Drop the indexes assigned to the given coordinates.
+
+        Parameters
+        ----------
+        coord_names : hashable or iterable of hashable
+            Name(s) of the coordinate(s) for which to drop the index.
+        errors : {"raise", "ignore"}, default: "raise"
+            If 'raise', raises a ValueError if any of the coordinates
+            passed have no index or are not in the dataset.
+            If 'ignore', no error is raised.
+
+        Returns
+        -------
+        dropped : Dataset
+            A new dataset with dropped indexes.
+
+        """
+        # the Iterable check is required for mypy
+        if is_scalar(coord_names) or not isinstance(coord_names, Iterable):
+            coord_names = {coord_names}
+        else:
+            coord_names = set(coord_names)
+
+        if errors == "raise":
+            invalid_coords = coord_names - self._coord_names
+            if invalid_coords:
+                raise ValueError(f"those coordinates don't exist: {invalid_coords}")
+
+            unindexed_coords = set(coord_names) - set(self._indexes)
+            if unindexed_coords:
+                raise ValueError(
+                    f"those coordinates do not have an index: {unindexed_coords}"
+                )
+
+        assert_no_index_corrupted(self.xindexes, coord_names, action="remove index(es)")
+
+        variables = {}
+        for name, var in self._variables.items():
+            if name in coord_names:
+                variables[name] = var.to_base_variable()
+            else:
+                variables[name] = var
+
+        indexes = {k: v for k, v in self._indexes.items() if k not in coord_names}
+
+        return self._replace(variables=variables, indexes=indexes)
+
     def drop(
         self: T_Dataset,
         labels=None,
@@ -7874,7 +8053,7 @@ def pad(
             # reset default index of dimension coordinates
             if (name,) == var.dims:
                 dim_var = {name: variables[name]}
-                index = PandasIndex.from_variables(dim_var)
+                index = PandasIndex.from_variables(dim_var, options={})
                 index_vars = index.create_variables(dim_var)
                 indexes[name] = index
                 variables[name] = index_vars[name]
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 1150cb6b2f5..cc92e20d91a 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -32,10 +32,19 @@
 
 
 class Index:
-    """Base class inherited by all xarray-compatible indexes."""
+    """Base class inherited by all xarray-compatible indexes.
+
+    Do not use this class directly for creating index objects.
+
+    """
 
     @classmethod
-    def from_variables(cls, variables: Mapping[Any, Variable]) -> Index:
+    def from_variables(
+        cls,
+        variables: Mapping[Any, Variable],
+        *,
+        options: Mapping[str, Any],
+    ) -> Index:
         raise NotImplementedError()
 
     @classmethod
@@ -247,7 +256,12 @@ def _replace(self, index, dim=None, coord_dtype=None):
         return type(self)(index, dim, coord_dtype)
 
     @classmethod
-    def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasIndex:
+    def from_variables(
+        cls,
+        variables: Mapping[Any, Variable],
+        *,
+        options: Mapping[str, Any],
+    ) -> PandasIndex:
         if len(variables) != 1:
             raise ValueError(
                 f"PandasIndex only accepts one variable, found {len(variables)} variables"
             )
@@ -570,7 +584,12 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex
         return type(self)(index, dim, level_coords_dtype)
 
     @classmethod
-    def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasMultiIndex:
+    def from_variables(
+        cls,
+        variables: Mapping[Any, Variable],
+        *,
+        options: Mapping[str, Any],
+    ) -> PandasMultiIndex:
         _check_dim_compat(variables)
         dim = next(iter(variables.values())).dims[0]
 
@@ -998,7 +1017,7 @@ def create_default_index_implicit(
         )
     else:
         dim_var = {name: dim_variable}
-        index = PandasIndex.from_variables(dim_var)
+        index = PandasIndex.from_variables(dim_var, options={})
         index_vars = index.create_variables(dim_var)
 
     return index, index_vars
@@ -1410,8 +1429,9 @@ def filter_indexes_from_coords(
 def assert_no_index_corrupted(
     indexes: Indexes[Index],
     coord_names: set[Hashable],
+    action: str = "remove coordinate(s)",
 ) -> None:
-    """Assert removing coordinates will not corrupt indexes."""
+    """Assert removing coordinates or indexes will not corrupt indexes."""
 
     # An index may be corrupted when the set of its corresponding coordinate name(s)
     # partially overlaps the set of coordinate names to remove
@@ -1421,7 +1441,7 @@ def assert_no_index_corrupted(
             common_names_str = ", ".join(f"{k!r}" for k in common_names)
             index_names_str = ", ".join(f"{k!r}" for k in index_coords)
             raise ValueError(
-                f"cannot remove coordinate(s) {common_names_str}, which would corrupt "
+                f"cannot {action} {common_names_str}, which would corrupt "
                 f"the following index built from coordinates {index_names_str}:\n"
                 f"{index}"
             )
diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py
new file mode 100644
index 00000000000..41321c9a0ff
--- /dev/null
+++ b/xarray/indexes/__init__.py
@@ -0,0 +1,7 @@
+"""Xarray index objects for label-based selection and alignment of Dataset /
+DataArray objects.
+
+"""
+from ..core.indexes import Index, PandasIndex, PandasMultiIndex
+
+__all__ = ["Index", "PandasIndex", "PandasMultiIndex"]
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 3b69f8e80fb..3602b87102d 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -2067,6 +2067,23 @@ def test_reorder_levels(self) -> None:
         with pytest.raises(ValueError, match=r"has no MultiIndex"):
             array.reorder_levels(x=["level_1", "level_2"])
 
+    def test_set_xindex(self) -> None:
+        da = DataArray(
+            [1, 2, 3, 4], coords={"foo": ("x", ["a", "a", "b", "b"])}, dims="x"
+        )
+
+        class IndexWithOptions(Index):
+            def __init__(self, opt):
+                self.opt = opt
+
+            @classmethod
+            def from_variables(cls, variables, options):
+                return cls(options["opt"])
+
+        indexed = da.set_xindex("foo", IndexWithOptions, opt=1)
+        assert "foo" in indexed.xindexes
+        assert getattr(indexed.xindexes["foo"], "opt") == 1
+
     def test_dataset_getitem(self) -> None:
         dv = self.ds["foo"]
         assert_identical(dv, self.dv)
@@ -2526,6 +2543,14 @@ def test_drop_index_positions(self) -> None:
         expected = arr[:, 2:]
         assert_identical(actual, expected)
 
+    def test_drop_indexes(self) -> None:
+        arr = DataArray([1, 2, 3], coords={"x": ("x", [1, 2, 3])}, dims="x")
+        actual = arr.drop_indexes("x")
+        assert "x" not in actual.xindexes
+
+        actual = arr.drop_indexes("not_a_coord", errors="ignore")
+        assert_identical(actual, arr)
+
     def test_dropna(self) -> None:
         x = np.random.randn(4, 4)
         x[::2, 0] = np.nan
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 8e6d6aab9d0..bc6410a6d4a 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -30,7 +30,7 @@
 from xarray.core import dtypes, indexing, utils
 from xarray.core.common import duck_array_ops, full_like
 from xarray.core.coordinates import DatasetCoordinates
-from xarray.core.indexes import Index
+from xarray.core.indexes import Index, PandasIndex
 from xarray.core.pycompat import integer_types, sparse_array_type
 from xarray.core.utils import is_scalar
 
@@ -2648,6 +2648,41 @@ def test_drop_labels_by_position(self) -> None:
         with pytest.raises(KeyError):
             data.drop_isel(z=1)
 
+    def test_drop_indexes(self) -> None:
+        ds = Dataset(
+            coords={
+                "x": ("x", [0, 1, 2]),
+                "y": ("y", [3, 4, 5]),
+                "foo": ("x", ["a", "a", "b"]),
+            }
+        )
+
+        actual = ds.drop_indexes("x")
+        assert "x" not in actual.xindexes
+        assert type(actual.x.variable) is Variable
+
+        actual = ds.drop_indexes(["x", "y"])
+        assert "x" not in actual.xindexes
+        assert "y" not in actual.xindexes
+        assert type(actual.x.variable) is Variable
+        assert type(actual.y.variable) is Variable
+
+        with pytest.raises(ValueError, match="those coordinates don't exist"):
+            ds.drop_indexes("not_a_coord")
+
+        with pytest.raises(ValueError, match="those coordinates do not have an index"):
+            ds.drop_indexes("foo")
+
+        actual = ds.drop_indexes(["foo", "not_a_coord"], errors="ignore")
+        assert_identical(actual, ds)
+
+        # test index corrupted
+        mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"])
+        ds = Dataset(coords={"x": mindex})
+
+        with pytest.raises(ValueError, match=".*would corrupt the following index.*"):
+            ds.drop_indexes("a")
+
     def test_drop_dims(self) -> None:
         data = xr.Dataset(
             {
@@ -3332,6 +3367,52 @@ def test_reorder_levels(self) -> None:
         with pytest.raises(ValueError, match=r"has no MultiIndex"):
             ds.reorder_levels(x=["level_1", "level_2"])
 
+    def test_set_xindex(self) -> None:
+        ds = Dataset(
+            coords={"foo": ("x", ["a", "a", "b", "b"]), "bar": ("x", [0, 1, 2, 3])}
+        )
+
+        actual = ds.set_xindex("foo")
+        expected = ds.set_index(x="foo").rename_vars(x="foo")
+        assert_identical(actual, expected, check_default_indexes=False)
+
+        actual_mindex = ds.set_xindex(["foo", "bar"])
+        expected_mindex = ds.set_index(x=["foo", "bar"])
+        assert_identical(actual_mindex, expected_mindex)
+
+        class NotAnIndex:
+            ...
+
+        with pytest.raises(TypeError, match=".*not a subclass of xarray.Index"):
+            ds.set_xindex("foo", NotAnIndex)  # type: ignore
+
+        with pytest.raises(ValueError, match="those variables don't exist"):
+            ds.set_xindex("not_a_coordinate", PandasIndex)
+
+        ds["data_var"] = ("x", [1, 2, 3, 4])
+
+        with pytest.raises(ValueError, match="those variables are data variables"):
+            ds.set_xindex("data_var", PandasIndex)
+
+        ds2 = Dataset(coords={"x": ("x", [0, 1, 2, 3])})
+
+        with pytest.raises(ValueError, match="those coordinates already have an index"):
+            ds2.set_xindex("x", PandasIndex)
+
+    def test_set_xindex_options(self) -> None:
+        ds = Dataset(coords={"foo": ("x", ["a", "a", "b", "b"])})
+
+        class IndexWithOptions(Index):
+            def __init__(self, opt):
+                self.opt = opt
+
+            @classmethod
+            def from_variables(cls, variables, options):
+                return cls(options["opt"])
+
+        indexed = ds.set_xindex("foo", IndexWithOptions, opt=1)
+        assert getattr(indexed.xindexes["foo"], "opt") == 1
+
     def test_stack(self) -> None:
         ds = Dataset(
             data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])},
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index e61a9859652..56267e7fd89 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -45,7 +45,7 @@ def index(self) -> CustomIndex:
 
     def test_from_variables(self) -> None:
         with pytest.raises(NotImplementedError):
-            Index.from_variables({})
+            Index.from_variables({}, options={})
 
     def test_concat(self) -> None:
         with pytest.raises(NotImplementedError):
@@ -133,19 +133,19 @@ def test_from_variables(self) -> None:
             "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64}
         )
 
-        index = PandasIndex.from_variables({"x": var})
+        index = PandasIndex.from_variables({"x": var}, options={})
         assert index.dim == "x"
         assert index.index.equals(pd.Index(data))
        assert index.coord_dtype == data.dtype
 
         var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]])
         with pytest.raises(ValueError, match=r".*only accepts one variable.*"):
-            PandasIndex.from_variables({"x": var, "foo": var2})
+            PandasIndex.from_variables({"x": var, "foo": var2}, options={})
 
         with pytest.raises(
             ValueError, match=r".*only accepts a 1-dimensional variable.*"
         ):
-            PandasIndex.from_variables({"foo": var2})
+            PandasIndex.from_variables({"foo": var2}, options={})
 
     def test_from_variables_index_adapter(self) -> None:
         # test index type is preserved when variable wraps a pd.Index
@@ -153,7 +153,7 @@ def test_from_variables_index_adapter(self) -> None:
         pd_idx = pd.Index(data)
         var = xr.Variable("x", pd_idx)
 
-        index = PandasIndex.from_variables({"x": var})
+        index = PandasIndex.from_variables({"x": var}, options={})
         assert isinstance(index.index, pd.CategoricalIndex)
 
     def test_concat_periods(self):
@@ -356,7 +356,7 @@ def test_from_variables(self) -> None:
         )
 
         index = PandasMultiIndex.from_variables(
-            {"level1": v_level1, "level2": v_level2}
+            {"level1": v_level1, "level2": v_level2}, options={}
        )
 
         expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data])
@@ -369,13 +369,15 @@ def test_from_variables(self) -> None:
         with pytest.raises(
             ValueError, match=r".*only accepts 1-dimensional variables.*"
         ):
-            PandasMultiIndex.from_variables({"var": var})
+            PandasMultiIndex.from_variables({"var": var}, options={})
 
         v_level3 = xr.Variable("y", [4, 5, 6])
         with pytest.raises(
             ValueError, match=r"unmatched dimensions for multi-index variables.*"
         ):
-            PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3})
+            PandasMultiIndex.from_variables(
+                {"level1": v_level1, "level3": v_level3}, options={}
+            )
 
     def test_concat(self) -> None:
         pd_midx = pd.MultiIndex.from_product(
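
Usage sketch (illustrative only, not part of the patch): the snippet below mirrors the tests added above to show how the new public API fits together; the IndexWithOptions class and the coordinate names are made up for the example.

    import xarray as xr
    from xarray.indexes import Index

    ds = xr.Dataset(
        coords={"foo": ("x", ["a", "a", "b", "b"]), "bar": ("x", [0, 1, 2, 3])}
    )

    # With no index_cls, one coordinate gets a PandasIndex and several
    # coordinates get a PandasMultiIndex.
    single = ds.set_xindex("foo")
    multi = ds.set_xindex(["foo", "bar"])

    # A custom index only needs from_variables(); extra keyword arguments
    # passed to set_xindex are forwarded through `options`.
    class IndexWithOptions(Index):
        def __init__(self, opt):
            self.opt = opt

        @classmethod
        def from_variables(cls, variables, *, options):
            return cls(options["opt"])

    custom = ds.set_xindex("foo", IndexWithOptions, opt=1)
    assert custom.xindexes["foo"].opt == 1

    # drop_indexes removes the index but keeps the coordinate variable.
    no_index = single.drop_indexes("foo")
    assert "foo" not in no_index.xindexes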