diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9fbeaa95055..0e48dfcfc78 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,8 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`) + By `Keisuke Fujii `_. - Support using an existing, opened h5netcdf ``File`` with :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5ac79999795..6f1ea76be4c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,6 +64,7 @@ default_indexes, isel_variable_and_index, propagate_indexes, + remove_unused_levels_categories, roll_index, ) from .indexing import is_fancy_indexer @@ -3411,7 +3412,7 @@ def ensure_stackable(val): def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset": index = self.get_index(dim) - index = index.remove_unused_levels() + index = remove_unused_levels_categories(index) full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) # take a shortcut in case the MultiIndex was not modified. @@ -4460,7 +4461,7 @@ def to_dataframe(self): return self._to_dataframe(self.dims) def _set_sparse_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] 
+ self, dataframe: pd.DataFrame, dims: tuple ) -> None: from sparse import COO @@ -4468,9 +4469,11 @@ def _set_sparse_data_from_dataframe( if isinstance(idx, pd.MultiIndex): coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) is_sorted = idx.is_lexsorted + shape = tuple(lev.size for lev in idx.levels) else: coords = np.arange(idx.size).reshape(1, -1) is_sorted = True + shape = (idx.size,) for name, series in dataframe.items(): # Cast to a NumPy array first, in case the Series is a pandas @@ -4495,14 +4498,16 @@ def _set_sparse_data_from_dataframe( self[name] = (dims, data) def _set_numpy_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] + self, dataframe: pd.DataFrame, dims: tuple ) -> None: idx = dataframe.index if isinstance(idx, pd.MultiIndex): # expand the DataFrame to include the product of all levels full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) dataframe = dataframe.reindex(full_idx) - + shape = tuple(lev.size for lev in idx.levels) + else: + shape = (idx.size,) for name, series in dataframe.items(): data = np.asarray(series).reshape(shape) self[name] = (dims, data) @@ -4543,7 +4548,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas if not dataframe.columns.is_unique: raise ValueError("cannot convert DataFrame with non-unique columns") - idx = dataframe.index + idx = remove_unused_levels_categories(dataframe.index) + dataframe = dataframe.set_index(idx) obj = cls() if isinstance(idx, pd.MultiIndex): @@ -4553,17 +4559,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas ) for dim, lev in zip(dims, idx.levels): obj[dim] = (dim, lev) - shape = tuple(lev.size for lev in idx.levels) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) obj[index_name] = (dims, idx) - shape = (idx.size,) if sparse: - obj._set_sparse_data_from_dataframe(dataframe, dims, shape) + 
def remove_unused_levels_categories(index):
    """
    Remove unused levels from MultiIndex and unused categories from CategoricalIndex

    Parameters
    ----------
    index : pandas.Index
        Index to clean up. Only ``pandas.MultiIndex`` and
        ``pandas.CategoricalIndex`` instances are modified.

    Returns
    -------
    pandas.Index
        For a ``MultiIndex``: the result of ``remove_unused_levels()``; if any
        level is a ``CategoricalIndex`` the index is additionally rebuilt with
        ``MultiIndex.from_arrays`` so that unused categories are dropped.
        For a ``CategoricalIndex``: the result of ``remove_unused_categories()``.
        Any other index is returned unchanged (same object).
    """
    if isinstance(index, pd.MultiIndex):
        index = index.remove_unused_levels()
        # if it contains CategoricalIndex, we need to remove unused categories
        # manually. See https://github.com/pandas-dev/pandas/issues/30846
        if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
            arrays = []
            for level, codes in zip(index.levels, index.codes):
                # ``from_arrays`` needs one value per row for *every* level,
                # so each level (categorical or not) is expanded through its
                # codes. Passing a non-categorical level unexpanded would hand
                # ``from_arrays`` arrays of different lengths.
                values = level[codes]
                if isinstance(level, pd.CategoricalIndex):
                    values = values.remove_unused_categories()
                arrays.append(values)
            index = pd.MultiIndex.from_arrays(arrays, names=index.names)
    elif isinstance(index, pd.CategoricalIndex):
        index = index.remove_unused_categories()
    return index
+ ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + ) + indexer = index.get_loc(label.item()) else: indexer = index.get_loc( label.item(), method=method, tolerance=tolerance diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f9eb37dbf2f..4e51e229b29 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1408,6 +1408,56 @@ def test_sel_dataarray_mindex(self): ) ) + def test_sel_categorical(self): + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + actual = ds.sel(ind="bar") + expected = ds.isel(ind=1) + assert_identical(expected, actual) + + def test_sel_categorical_error(self): + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + with pytest.raises(ValueError): + ds.sel(ind="bar", method="nearest") + with pytest.raises(ValueError): + ds.sel(ind="bar", tolerance="nearest") + + def test_categorical_index(self): + cat = pd.CategoricalIndex( + ["foo", "bar", "foo"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])}, + ) + # test slice + actual = ds.sel(cat="foo") + expected = ds.isel(cat=[0, 2]) + assert_identical(expected, actual) + # make sure the conversion to the array works + actual = ds.sel(cat="foo")["cat"].values + assert (actual == np.array(["foo", "foo"])).all() + + ds = ds.set_index(index=["cat", "c"]) + actual = ds.unstack("index") + assert actual["var"].shape == (2, 2) + + def test_categorical_reindex(self): + cat = pd.CategoricalIndex( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 
1, 2])}, + ) + actual = ds.reindex(cat=["foo"])["cat"].values + assert (actual == np.array(["foo"])).all() + def test_sel_drop(self): data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) expected = Dataset({"foo": 1}) @@ -3865,6 +3915,21 @@ def test_to_and_from_dataframe(self): expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_categorical(self): + cat = pd.CategoricalDtype( + categories=["foo", "bar", "baz", "qux", "quux", "corge"] + ) + i1 = pd.Series(["foo", "bar", "foo"], dtype=cat) + i2 = pd.Series(["bar", "bar", "baz"], dtype=cat) + + df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]}) + ds = df.set_index("i1").to_xarray() + assert len(ds["i1"]) == 3 + + ds = df.set_index(["i1", "i2"]).to_xarray() + assert len(ds["i1"]) == 2 + assert len(ds["i2"]) == 2 + @requires_sparse def test_from_dataframe_sparse(self): import sparse