Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sel with categorical index #3670

Merged
merged 12 commits into from
Jan 25, 2020
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Support using an existing, opened h5netcdf ``File`` with
:py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an
:py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened
Expand Down
22 changes: 13 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
default_indexes,
isel_variable_and_index,
propagate_indexes,
remove_unused_levels_categories,
roll_index,
)
from .indexing import is_fancy_indexer
Expand Down Expand Up @@ -3411,7 +3412,7 @@ def ensure_stackable(val):

def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
index = self.get_index(dim)
index = index.remove_unused_levels()
index = remove_unused_levels_categories(index)
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)

# take a shortcut in case the MultiIndex was not modified.
Expand Down Expand Up @@ -4460,17 +4461,19 @@ def to_dataframe(self):
return self._to_dataframe(self.dims)

def _set_sparse_data_from_dataframe(
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
self, dataframe: pd.DataFrame, dims: tuple
) -> None:
from sparse import COO

idx = dataframe.index
if isinstance(idx, pd.MultiIndex):
coords = np.stack([np.asarray(code) for code in idx.codes], axis=0)
is_sorted = idx.is_lexsorted
shape = tuple(lev.size for lev in idx.levels)
else:
coords = np.arange(idx.size).reshape(1, -1)
is_sorted = True
shape = (idx.size,)

for name, series in dataframe.items():
# Cast to a NumPy array first, in case the Series is a pandas
Expand All @@ -4495,14 +4498,16 @@ def _set_sparse_data_from_dataframe(
self[name] = (dims, data)

def _set_numpy_data_from_dataframe(
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
self, dataframe: pd.DataFrame, dims: tuple
) -> None:
idx = dataframe.index
if isinstance(idx, pd.MultiIndex):
# expand the DataFrame to include the product of all levels
full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
dataframe = dataframe.reindex(full_idx)

shape = tuple(lev.size for lev in idx.levels)
else:
shape = (idx.size,)
for name, series in dataframe.items():
data = np.asarray(series).reshape(shape)
self[name] = (dims, data)
Expand Down Expand Up @@ -4543,7 +4548,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas
if not dataframe.columns.is_unique:
raise ValueError("cannot convert DataFrame with non-unique columns")

idx = dataframe.index
idx = remove_unused_levels_categories(dataframe.index)
dataframe = dataframe.set_index(idx)
obj = cls()

if isinstance(idx, pd.MultiIndex):
Expand All @@ -4553,17 +4559,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas
)
for dim, lev in zip(dims, idx.levels):
obj[dim] = (dim, lev)
shape = tuple(lev.size for lev in idx.levels)
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
obj[index_name] = (dims, idx)
shape = (idx.size,)

if sparse:
obj._set_sparse_data_from_dataframe(dataframe, dims, shape)
obj._set_sparse_data_from_dataframe(dataframe, dims)
else:
obj._set_numpy_data_from_dataframe(dataframe, dims, shape)
obj._set_numpy_data_from_dataframe(dataframe, dims)
return obj

def to_dask_dataframe(self, dim_order=None, set_index=False):
Expand Down
20 changes: 20 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
from .variable import Variable


def remove_unused_levels_categories(index):
"""
Remove unused levels from MultiIndex and unused categories from CategoricalIndex
"""
if isinstance(index, pd.MultiIndex):
index = index.remove_unused_levels()
# if it contains CategoricalIndex, we need to remove unused categories
# manually. See https://github.com/pandas-dev/pandas/issues/30846
if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
levels = []
for i, level in enumerate(index.levels):
if isinstance(level, pd.CategoricalIndex):
level = level[index.codes[i]].remove_unused_categories()
levels.append(level)
Comment on lines +21 to +25
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this could be cleaner by using a list comprehension:

levels = [
    level[index.codes[i]].remove_unused_categories()
    if isinstance(level, pd.CategoricalIndex)
    else level
    for i, level in enumerate(index.levels)
]

though that might be just me

index = pd.MultiIndex.from_arrays(levels, names=index.names)
elif isinstance(index, pd.CategoricalIndex):
index = index.remove_unused_categories()
return index


class Indexes(collections.abc.Mapping):
"""Immutable proxy for Dataset or DataArrary indexes."""

Expand Down
10 changes: 10 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No
if label.ndim == 0:
if isinstance(index, pd.MultiIndex):
indexer, new_index = index.get_loc_level(label.item(), level=0)
elif isinstance(index, pd.CategoricalIndex):
if method is not None:
raise ValueError(
"'method' is not a valid kwarg when indexing using a CategoricalIndex."
)
if tolerance is not None:
raise ValueError(
"'tolerance' is not a valid kwarg when indexing using a CategoricalIndex."
)
indexer = index.get_loc(label.item())
else:
indexer = index.get_loc(
label.item(), method=method, tolerance=tolerance
Expand Down
65 changes: 65 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,56 @@ def test_sel_dataarray_mindex(self):
)
)

def test_sel_categorical(self):
ind = pd.Series(["foo", "bar"], dtype="category")
df = pd.DataFrame({"ind": ind, "values": [1, 2]})
ds = df.set_index("ind").to_xarray()
actual = ds.sel(ind="bar")
expected = ds.isel(ind=1)
assert_identical(expected, actual)
dcherian marked this conversation as resolved.
Show resolved Hide resolved

def test_sel_categorical_error(self):
ind = pd.Series(["foo", "bar"], dtype="category")
df = pd.DataFrame({"ind": ind, "values": [1, 2]})
ds = df.set_index("ind").to_xarray()
with pytest.raises(ValueError):
ds.sel(ind="bar", method="nearest")
with pytest.raises(ValueError):
ds.sel(ind="bar", tolerance="nearest")

def test_categorical_index(self):
cat = pd.CategoricalIndex(
["foo", "bar", "foo"],
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
)
ds = xr.Dataset(
{"var": ("cat", np.arange(3))},
coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])},
)
# test slice
actual = ds.sel(cat="foo")
expected = ds.isel(cat=[0, 2])
assert_identical(expected, actual)
# make sure the conversion to the array works
actual = ds.sel(cat="foo")["cat"].values
assert (actual == np.array(["foo", "foo"])).all()

ds = ds.set_index(index=["cat", "c"])
actual = ds.unstack("index")
assert actual["var"].shape == (2, 2)

def test_categorical_reindex(self):
cat = pd.CategoricalIndex(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
)
ds = xr.Dataset(
{"var": ("cat", np.arange(3))},
coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 2])},
)
actual = ds.reindex(cat=["foo"])["cat"].values
assert (actual == np.array(["foo"])).all()

def test_sel_drop(self):
data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]})
expected = Dataset({"foo": 1})
Expand Down Expand Up @@ -3865,6 +3915,21 @@ def test_to_and_from_dataframe(self):
expected = pd.DataFrame([[]], index=idx)
assert expected.equals(actual), (expected, actual)

def test_from_dataframe_categorical(self):
cat = pd.CategoricalDtype(
categories=["foo", "bar", "baz", "qux", "quux", "corge"]
)
i1 = pd.Series(["foo", "bar", "foo"], dtype=cat)
i2 = pd.Series(["bar", "bar", "baz"], dtype=cat)

df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]})
ds = df.set_index("i1").to_xarray()
assert len(ds["i1"]) == 3

ds = df.set_index(["i1", "i2"]).to_xarray()
assert len(ds["i1"]) == 2
assert len(ds["i2"]) == 2

@requires_sparse
def test_from_dataframe_sparse(self):
import sparse
Expand Down