diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index d81ba30f12f..f514b4ecbef 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -251,7 +251,7 @@ Finally, if a dataset does not have any coordinates it enumerates all data point .. ipython:: python :okwarning: - air1d_multi = air1d_multi.drop("date") + air1d_multi = air1d_multi.drop(["date", "time", "decimal_day"]) air1d_multi.plot() The same applies to 2D plots below. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b22c6e4d858..05bdfcf78bb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,10 +22,18 @@ v2022.03.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and + :py:meth:`DataArray.stack` so that the creation of multi-indexes is optional + (:pull:`5692`). By `Benoît Bovy `_. +- Multi-index levels are now accessible through their own, regular coordinates + instead of virtual coordinates (:pull:`5692`). + By `Benoît Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ +- The Dataset and DataArray ``rename*`` methods do not implicitly add or drop + indexes. (:pull:`5692`). By `Benoît Bovy `_. Deprecations ~~~~~~~~~~~~ @@ -37,6 +45,9 @@ Bug fixes - Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and ensure it skips missing values for float dtypes (consistent with other methods). This should not change the behavior (:pull:`6303`). By `Mathias Hauser `_. +- Many bugs fixed by the explicit indexes refactor, mainly related to multi-index (virtual) + coordinates. See the corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. Documentation ~~~~~~~~~~~~~ @@ -45,6 +56,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Many internal changes due to the explicit indexes refactor. See the + corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. .. _whats-new.2022.02.0: .. _whats-new.2022.03.0: diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index f9342e2a82a..d201e3a613f 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator from collections import defaultdict @@ -5,84 +7,562 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, + Generic, Hashable, Iterable, Mapping, - Optional, Tuple, + Type, TypeVar, - Union, ) import numpy as np import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex, get_indexer_nd -from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import IndexVariable, Variable +from .common import DataWithCoords +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_all_equal +from .utils import is_dict_like, is_full_slice, safe_cast_to_index +from .variable import Variable, as_compatible_data, calculate_dimensions if TYPE_CHECKING: - from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset - DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) - - -def _get_joiner(join, index_cls): - if join == "outer": - return functools.partial(functools.reduce, index_cls.union) - elif join == "inner": - return functools.partial(functools.reduce, index_cls.intersection) - elif join == "left": - return operator.itemgetter(0) - elif join == "right": - return operator.itemgetter(-1) - elif join == "exact": - # We cannot return a function to "align" in this case, because it needs - # access to the dimension name to give a good error message. - return None - elif join == "override": - # We rewrite all indexes and then use join='left' - return operator.itemgetter(0) - else: - raise ValueError(f"invalid value for join: {join}") +DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) + + +def reindex_variables( + variables: Mapping[Any, Variable], + dim_pos_indexers: Mapping[Any, Any], + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, +) -> dict[Hashable, Variable]: + """Conform a dictionary of variables onto a new set of variables reindexed + with dimension positional indexers and possibly filled with missing values. + + Not public API. + + """ + new_variables = {} + dim_sizes = calculate_dimensions(variables) + + masked_dims = set() + unchanged_dims = set() + for dim, indxr in dim_pos_indexers.items(): + # Negative values in dim_pos_indexers mean values missing in the new index + # See ``Index.reindex_like``. + if (indxr < 0).any(): + masked_dims.add(dim) + elif np.array_equal(indxr, np.arange(dim_sizes.get(dim, 0))): + unchanged_dims.add(dim) + + for name, var in variables.items(): + if isinstance(fill_value, dict): + fill_value_ = fill_value.get(name, dtypes.NA) + else: + fill_value_ = fill_value + + if sparse: + var = var._as_sparse(fill_value=fill_value_) + indxr = tuple( + slice(None) if d in unchanged_dims else dim_pos_indexers.get(d, slice(None)) + for d in var.dims + ) + needs_masking = any(d in masked_dims for d in var.dims) + + if needs_masking: + new_var = var._getitem_with_mask(indxr, fill_value=fill_value_) + elif all(is_full_slice(k) for k in indxr): + # no reindexing necessary + # here we need to manually deal with copying data, since + # we neither created a new ndarray nor used fancy indexing + new_var = var.copy(deep=copy) + else: + new_var = var[indxr] + + new_variables[name] = new_var + + return new_variables + + +CoordNamesAndDims = Tuple[Tuple[Hashable, Tuple[Hashable, ...]], ...] +MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] +NormalizedIndexes = Dict[MatchingIndexKey, Index] +NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] + + +class Aligner(Generic[DataAlignable]): + """Implements all the complex logic for the re-indexing and alignment of Xarray + objects. + + For internal use only, not public API. + Usage: + + aligner = Aligner(*objects, **kwargs) + aligner.align() + aligned_objects = aligner.results + + """ + + objects: tuple[DataAlignable, ...] + results: tuple[DataAlignable, ...] + objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] + join: str + exclude_dims: frozenset[Hashable] + exclude_vars: frozenset[Hashable] + copy: bool + fill_value: Any + sparse: bool + indexes: dict[MatchingIndexKey, Index] + index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + aligned_indexes: dict[MatchingIndexKey, Index] + aligned_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + reindex: dict[MatchingIndexKey, bool] + reindex_kwargs: dict[str, Any] + unindexed_dim_sizes: dict[Hashable, set] + new_indexes: Indexes[Index] + + def __init__( + self, + objects: Iterable[DataAlignable], + join: str = "inner", + indexes: Mapping[Any, Any] = None, + exclude_dims: Iterable = frozenset(), + exclude_vars: Iterable[Hashable] = frozenset(), + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, + ): + self.objects = tuple(objects) + self.objects_matching_indexes = () + + if join not in ["inner", "outer", "override", "exact", "left", "right"]: + raise ValueError(f"invalid value for join: {join}") + self.join = join + + self.copy = copy + self.fill_value = fill_value + self.sparse = sparse + + if method is None and tolerance is None: + self.reindex_kwargs = {} + else: + self.reindex_kwargs = {"method": method, "tolerance": tolerance} + + if isinstance(exclude_dims, str): + exclude_dims = [exclude_dims] + self.exclude_dims = frozenset(exclude_dims) + self.exclude_vars = frozenset(exclude_vars) + + if indexes is None: + indexes = {} + self.indexes, self.index_vars = self._normalize_indexes(indexes) + + self.all_indexes = {} + self.all_index_vars = {} + self.unindexed_dim_sizes = {} + + self.aligned_indexes = {} + self.aligned_index_vars = {} + self.reindex = {} + + self.results = tuple() + def _normalize_indexes( + self, + indexes: Mapping[Any, Any], + ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: + """Normalize the indexes/indexers used for re-indexing or alignment. -def _override_indexes(objects, all_indexes, exclude): - for dim, dim_indexes in all_indexes.items(): - if dim not in exclude: - lengths = { - getattr(index, "size", index.to_pandas_index().size) - for index in dim_indexes - } - if len(lengths) != 1: + Return dictionaries of xarray Index objects and coordinate variables + such that we can group matching indexes based on the dictionary keys. + + """ + if isinstance(indexes, Indexes): + xr_variables = dict(indexes.variables) + else: + xr_variables = {} + + xr_indexes: dict[Hashable, Index] = {} + for k, idx in indexes.items(): + if not isinstance(idx, Index): + if getattr(idx, "dims", (k,)) != (k,): + raise ValueError( + f"Indexer has dimensions {idx.dims} that are different " + f"from that to be indexed along '{k}'" + ) + data = as_compatible_data(idx) + pd_idx = safe_cast_to_index(data) + pd_idx.name = k + if isinstance(pd_idx, pd.MultiIndex): + idx = PandasMultiIndex(pd_idx, k) + else: + idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype) + xr_variables.update(idx.create_variables()) + xr_indexes[k] = idx + + normalized_indexes = {} + normalized_index_vars = {} + for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + coord_names_and_dims = [] + all_dims = set() + + for name, var in index_vars.items(): + dims = var.dims + coord_names_and_dims.append((name, dims)) + all_dims.update(dims) + + exclude_dims = all_dims & self.exclude_dims + if exclude_dims == all_dims: + continue + elif exclude_dims: + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) raise ValueError( - f"Indexes along dimension {dim!r} don't have the same length." - " Cannot use join='override'." + f"cannot exclude dimension(s) {excl_dims_str} from alignment because " + "these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" ) - objects = list(objects) - for idx, obj in enumerate(objects[1:]): - new_indexes = { - dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude - } + key = (tuple(coord_names_and_dims), type(idx)) + normalized_indexes[key] = idx + normalized_index_vars[key] = index_vars + + return normalized_indexes, normalized_index_vars + + def find_matching_indexes(self) -> None: + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + all_indexes_dim_sizes: dict[MatchingIndexKey, dict[Hashable, set]] + objects_matching_indexes: list[dict[MatchingIndexKey, Index]] + + all_indexes = defaultdict(list) + all_index_vars = defaultdict(list) + all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set)) + objects_matching_indexes = [] + + for obj in self.objects: + obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes) + objects_matching_indexes.append(obj_indexes) + for key, idx in obj_indexes.items(): + all_indexes[key].append(idx) + for key, index_vars in obj_index_vars.items(): + all_index_vars[key].append(index_vars) + for dim, size in calculate_dimensions(index_vars).items(): + all_indexes_dim_sizes[key][dim].add(size) + + self.objects_matching_indexes = tuple(objects_matching_indexes) + self.all_indexes = all_indexes + self.all_index_vars = all_index_vars + + if self.join == "override": + for dim_sizes in all_indexes_dim_sizes.values(): + for dim, sizes in dim_sizes.items(): + if len(sizes) > 1: + raise ValueError( + "cannot align objects with join='override' with matching indexes " + f"along dimension {dim!r} that don't have the same size" + ) + + def find_matching_unindexed_dims(self) -> None: + unindexed_dim_sizes = defaultdict(set) + + for obj in self.objects: + for dim in obj.dims: + if dim not in self.exclude_dims and dim not in obj.xindexes.dims: + unindexed_dim_sizes[dim].add(obj.sizes[dim]) + + self.unindexed_dim_sizes = unindexed_dim_sizes + + def assert_no_index_conflict(self) -> None: + """Check for uniqueness of both coordinate and dimension names accross all sets + of matching indexes. + + We need to make sure that all indexes used for re-indexing or alignment + are fully compatible and do not conflict each other. + + Note: perhaps we could choose less restrictive constraints and instead + check for conflicts among the dimension (position) indexers returned by + `Index.reindex_like()` for each matching pair of object index / aligned + index? + (ref: https://github.com/pydata/xarray/issues/1603#issuecomment-442965602) + + """ + matching_keys = set(self.all_indexes) | set(self.indexes) + + coord_count: dict[Hashable, int] = defaultdict(int) + dim_count: dict[Hashable, int] = defaultdict(int) + for coord_names_dims, _ in matching_keys: + dims_set: set[Hashable] = set() + for name, dims in coord_names_dims: + coord_count[name] += 1 + dims_set.update(dims) + for dim in dims_set: + dim_count[dim] += 1 + + for count, msg in [(coord_count, "coordinates"), (dim_count, "dimensions")]: + dup = {k: v for k, v in count.items() if v > 1} + if dup: + items_msg = ", ".join( + f"{k!r} ({v} conflicting indexes)" for k, v in dup.items() + ) + raise ValueError( + "cannot re-index or align objects with conflicting indexes found for " + f"the following {msg}: {items_msg}\n" + "Conflicting indexes may occur when\n" + "- they relate to different sets of coordinate and/or dimension names\n" + "- they don't have the same type\n" + "- they may be used to reindex data along common dimensions" + ) - objects[idx + 1] = obj._overwrite_indexes(new_indexes) + def _need_reindex(self, dims, cmp_indexes) -> bool: + """Whether or not we need to reindex variables for a set of + matching indexes. + + We don't reindex when all matching indexes are equal for two reasons: + - It's faster for the usual case (already aligned objects). + - It ensures it's possible to do operations that don't require alignment + on indexes with duplicate values (which cannot be reindexed with + pandas). This is useful, e.g., for overwriting such duplicate indexes. + + """ + has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) + return not (indexes_all_equal(cmp_indexes)) or has_unindexed_dims + + def _get_index_joiner(self, index_cls) -> Callable: + if self.join in ["outer", "inner"]: + return functools.partial( + functools.reduce, + functools.partial(index_cls.join, how=self.join), + ) + elif self.join == "left": + return operator.itemgetter(0) + elif self.join == "right": + return operator.itemgetter(-1) + elif self.join == "override": + # We rewrite all indexes and then use join='left' + return operator.itemgetter(0) + else: + # join='exact' return dummy lambda (error is raised) + return lambda _: None + + def align_indexes(self) -> None: + """Compute all aligned indexes and their corresponding coordinate variables.""" + + aligned_indexes = {} + aligned_index_vars = {} + reindex = {} + new_indexes = {} + new_index_vars = {} + + for key, matching_indexes in self.all_indexes.items(): + matching_index_vars = self.all_index_vars[key] + dims = {d for coord in matching_index_vars[0].values() for d in coord.dims} + index_cls = key[1] + + if self.join == "override": + joined_index = matching_indexes[0] + joined_index_vars = matching_index_vars[0] + need_reindex = False + elif key in self.indexes: + joined_index = self.indexes[key] + joined_index_vars = self.index_vars[key] + cmp_indexes = list( + zip( + [joined_index] + matching_indexes, + [joined_index_vars] + matching_index_vars, + ) + ) + need_reindex = self._need_reindex(dims, cmp_indexes) + else: + if len(matching_indexes) > 1: + need_reindex = self._need_reindex( + dims, + list(zip(matching_indexes, matching_index_vars)), + ) + else: + need_reindex = False + if need_reindex: + if self.join == "exact": + raise ValueError( + "cannot align objects with join='exact' where " + "index/labels/sizes are not equal along " + "these coordinates (dimensions): " + + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0]) + ) + joiner = self._get_index_joiner(index_cls) + joined_index = joiner(matching_indexes) + if self.join == "left": + joined_index_vars = matching_index_vars[0] + elif self.join == "right": + joined_index_vars = matching_index_vars[-1] + else: + joined_index_vars = joined_index.create_variables() + else: + joined_index = matching_indexes[0] + joined_index_vars = matching_index_vars[0] + + reindex[key] = need_reindex + aligned_indexes[key] = joined_index + aligned_index_vars[key] = joined_index_vars + + for name, var in joined_index_vars.items(): + new_indexes[name] = joined_index + new_index_vars[name] = var + + # Explicitly provided indexes that are not found in objects to align + # may relate to unindexed dimensions so we add them too + for key, idx in self.indexes.items(): + if key not in aligned_indexes: + index_vars = self.index_vars[key] + reindex[key] = False + aligned_indexes[key] = idx + aligned_index_vars[key] = index_vars + for name, var in index_vars.items(): + new_indexes[name] = idx + new_index_vars[name] = var + + self.aligned_indexes = aligned_indexes + self.aligned_index_vars = aligned_index_vars + self.reindex = reindex + self.new_indexes = Indexes(new_indexes, new_index_vars) + + def assert_unindexed_dim_sizes_equal(self) -> None: + for dim, sizes in self.unindexed_dim_sizes.items(): + index_size = self.new_indexes.dims.get(dim) + if index_size is not None: + sizes.add(index_size) + add_err_msg = ( + f" (note: an index is found along that dimension " + f"with size={index_size!r})" + ) + else: + add_err_msg = "" + if len(sizes) > 1: + raise ValueError( + f"cannot reindex or align along dimension {dim!r} " + f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg + ) - return objects + def override_indexes(self) -> None: + objects = list(self.objects) + + for i, obj in enumerate(objects[1:]): + new_indexes = {} + new_variables = {} + matching_indexes = self.objects_matching_indexes[i + 1] + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + for name, var in self.aligned_index_vars[key].items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + + objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables) + + self.results = tuple(objects) + + def _get_dim_pos_indexers( + self, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> dict[Hashable, Any]: + dim_pos_indexers = {} + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + if self.reindex[key]: + indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) # type: ignore[call-arg] + dim_pos_indexers.update(indexers) + + return dim_pos_indexers + + def _get_indexes_and_vars( + self, + obj: DataAlignable, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes = {} + new_variables = {} + + for key, aligned_idx in self.aligned_indexes.items(): + index_vars = self.aligned_index_vars[key] + obj_idx = matching_indexes.get(key) + if obj_idx is None: + # add the index if it relates to unindexed dimensions in obj + index_vars_dims = {d for var in index_vars.values() for d in var.dims} + if index_vars_dims <= set(obj.dims): + obj_idx = aligned_idx + if obj_idx is not None: + for name, var in index_vars.items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + + return new_indexes, new_variables + + def _reindex_one( + self, + obj: DataAlignable, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> DataAlignable: + new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) + dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) + + new_obj = obj._reindex_callback( + self, + dim_pos_indexers, + new_variables, + new_indexes, + self.fill_value, + self.exclude_dims, + self.exclude_vars, + ) + new_obj.encoding = obj.encoding + return new_obj + + def reindex_all(self) -> None: + self.results = tuple( + self._reindex_one(obj, matching_indexes) + for obj, matching_indexes in zip( + self.objects, self.objects_matching_indexes + ) + ) + + def align(self) -> None: + if not self.indexes and len(self.objects) == 1: + # fast path for the trivial case + (obj,) = self.objects + self.results = (obj.copy(deep=self.copy),) + + self.find_matching_indexes() + self.find_matching_unindexed_dims() + self.assert_no_index_conflict() + self.align_indexes() + self.assert_unindexed_dim_sizes_equal() + + if self.join == "override": + self.override_indexes() + else: + self.reindex_all() def align( - *objects: "DataAlignable", + *objects: DataAlignable, join="inner", copy=True, indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -) -> Tuple["DataAlignable", ...]: +) -> tuple[DataAlignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -251,8 +731,7 @@ def align( >>> a, b = xr.align(x, y, join="exact") Traceback (most recent call last): ... - "indexes along dimension {!r} are not equal".format(dim) - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' ... >>> a, b = xr.align(x, y, join="override") >>> a @@ -271,107 +750,16 @@ def align( * lon (lon) float64 100.0 120.0 """ - if indexes is None: - indexes = {} - - if not indexes and len(objects) == 1: - # fast path for the trivial case - (obj,) = objects - return (obj.copy(deep=copy),) - - all_indexes = defaultdict(list) - all_coords = defaultdict(list) - unlabeled_dim_sizes = defaultdict(set) - for obj in objects: - for dim in obj.dims: - if dim not in exclude: - all_coords[dim].append(obj.coords[dim]) - try: - index = obj.xindexes[dim] - except KeyError: - unlabeled_dim_sizes[dim].add(obj.sizes[dim]) - else: - all_indexes[dim].append(index) - - if join == "override": - objects = _override_indexes(objects, all_indexes, exclude) - - # We don't reindex over dimensions with all equal indexes for two reasons: - # - It's faster for the usual case (already aligned objects). - # - It ensures it's possible to do operations that don't require alignment - # on indexes with duplicate values (which cannot be reindexed with - # pandas). This is useful, e.g., for overwriting such duplicate indexes. - joined_indexes = {} - for dim, matching_indexes in all_indexes.items(): - if dim in indexes: - index, _ = PandasIndex.from_pandas_index( - safe_cast_to_index(indexes[dim]), dim - ) - if ( - any(not index.equals(other) for other in matching_indexes) - or dim in unlabeled_dim_sizes - ): - joined_indexes[dim] = indexes[dim] - else: - if ( - any( - not matching_indexes[0].equals(other) - for other in matching_indexes[1:] - ) - or dim in unlabeled_dim_sizes - ): - if join == "exact": - raise ValueError(f"indexes along dimension {dim!r} are not equal") - joiner = _get_joiner(join, type(matching_indexes[0])) - index = joiner(matching_indexes) - # make sure str coords are not cast to object - index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim]) - joined_indexes[dim] = index - else: - index = all_coords[dim][0] - - if dim in unlabeled_dim_sizes: - unlabeled_sizes = unlabeled_dim_sizes[dim] - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - if isinstance(index, PandasIndex): - labeled_size = index.to_pandas_index().size - else: - labeled_size = index.size - if len(unlabeled_sizes | {labeled_size}) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension size(s) {unlabeled_sizes!r} " - f"than the size of the aligned dimension labels: {labeled_size!r}" - ) - - for dim, sizes in unlabeled_dim_sizes.items(): - if dim not in all_indexes and len(sizes) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension sizes: {sizes!r}" - ) - - result = [] - for obj in objects: - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - valid_indexers = {} - for k, index in joined_indexes.items(): - if k in obj.dims: - if isinstance(index, Index): - valid_indexers[k] = index.to_pandas_index() - else: - valid_indexers[k] = index - if not valid_indexers: - # fast path for no reindexing necessary - new_obj = obj.copy(deep=copy) - else: - new_obj = obj.reindex( - copy=copy, fill_value=fill_value, indexers=valid_indexers - ) - new_obj.encoding = obj.encoding - result.append(new_obj) - - return tuple(result) + aligner = Aligner( + objects, + join=join, + copy=copy, + indexes=indexes, + exclude_dims=exclude, + fill_value=fill_value, + ) + aligner.align() + return aligner.results def deep_align( @@ -457,197 +845,78 @@ def is_alignable(obj): return out -def reindex_like_indexers( - target: "Union[DataArray, Dataset]", other: "Union[DataArray, Dataset]" -) -> Dict[Hashable, pd.Index]: - """Extract indexers to align target with other. - - Not public API. - - Parameters - ---------- - target : Dataset or DataArray - Object to be aligned. - other : Dataset or DataArray - Object to be aligned with. - - Returns - ------- - Dict[Hashable, pandas.Index] providing indexes for reindex keyword - arguments. - - Raises - ------ - ValueError - If any dimensions without labels have different sizes. - """ - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - # this doesn't support yet indexes other than pd.Index - indexers = { - k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims - } - - for dim in other.dims: - if dim not in indexers and dim in target.dims: - other_size = other.sizes[dim] - target_size = target.sizes[dim] - if other_size != target_size: - raise ValueError( - "different size for unlabeled " - f"dimension on argument {dim!r}: {other_size!r} vs {target_size!r}" - ) - return indexers - - -def reindex_variables( - variables: Mapping[Any, Variable], - sizes: Mapping[Any, int], - indexes: Mapping[Any, Index], - indexers: Mapping, - method: Optional[str] = None, - tolerance: Union[Union[int, float], Iterable[Union[int, float]]] = None, +def reindex( + obj: DataAlignable, + indexers: Mapping[Any, Any], + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, - fill_value: Optional[Any] = dtypes.NA, + fill_value: Any = dtypes.NA, sparse: bool = False, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: - """Conform a dictionary of aligned variables onto a new set of variables, - filling in missing values with NaN. + exclude_vars: Iterable[Hashable] = frozenset(), +) -> DataAlignable: + """Re-index either a Dataset or a DataArray. Not public API. - Parameters - ---------- - variables : dict-like - Dictionary of xarray.Variable objects. - sizes : dict-like - Dictionary from dimension names to integer sizes. - indexes : dict-like - Dictionary of indexes associated with variables. - indexers : dict - Dictionary with keys given by dimension names and values given by - arrays of coordinates tick labels. Any mis-matched coordinate values - will be filled in with NaN, and any mis-matched dimension names will - simply be ignored. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional - Method to use for filling index values in ``indexers`` not found in - this dataset: - * None (default): don't fill gaps - * pad / ffill: propagate last valid index value forward - * backfill / bfill: propagate next valid index value backward - * nearest: use nearest valid index value - tolerance : optional - Maximum distance between original and new labels for inexact matches. - The values of the index at the matching locations must satisfy the - equation ``abs(index[indexer] - target) <= tolerance``. - Tolerance may be a scalar value, which applies the same tolerance - to all values, or list-like, which applies variable tolerance per - element. List-like must be the same size as the index and its dtype - must exactly match the index’s type. - copy : bool, optional - If ``copy=True``, data in the return values is always copied. If - ``copy=False`` and reindexing is unnecessary, or can be performed - with only slice operations, then the output may share memory with - the input. In either case, new xarray objects are always returned. - fill_value : scalar, optional - Value to use for newly missing values - sparse : bool, optional - Use an sparse-array - - Returns - ------- - reindexed : dict - Dict of reindexed variables. - new_indexes : dict - Dict of indexes associated with the reindexed variables. """ - from .dataarray import DataArray - - # create variables for the new dataset - reindexed: Dict[Hashable, Variable] = {} - - # build up indexers for assignment along each dimension - int_indexers = {} - new_indexes = dict(indexes) - masked_dims = set() - unchanged_dims = set() - - for dim, indexer in indexers.items(): - if isinstance(indexer, DataArray) and indexer.dims != (dim,): - raise ValueError( - "Indexer has dimensions {:s} that are different " - "from that to be indexed along {:s}".format(str(indexer.dims), dim) - ) - - target = safe_cast_to_index(indexers[dim]) - new_indexes[dim] = PandasIndex(target, dim) - - if dim in indexes: - # TODO (benbovy - flexible indexes): support other indexes than pd.Index? - index = indexes[dim].to_pandas_index() - - if not index.is_unique: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} because the " - "index has duplicate values" - ) - - int_indexer = get_indexer_nd(index, target, method, tolerance) - - # We uses negative values from get_indexer_nd to signify - # values that are missing in the index. - if (int_indexer < 0).any(): - masked_dims.add(dim) - elif np.array_equal(int_indexer, np.arange(len(index))): - unchanged_dims.add(dim) - int_indexers[dim] = int_indexer - - if dim in variables: - var = variables[dim] - args: tuple = (var.attrs, var.encoding) - else: - args = () - reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) - - for dim in sizes: - if dim not in indexes and dim in indexers: - existing_size = sizes[dim] - new_size = indexers[dim].size - if existing_size != new_size: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} without an " - f"index because its size {existing_size!r} is different from the size of " - f"the new index {new_size!r}" - ) + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # bad_keys = [k for k in indexers if k not in obj._indexes and k not in obj.dims] + # if bad_keys: + # raise ValueError( + # f"indexer keys {bad_keys} do not correspond to any indexed coordinate " + # "or unindexed dimension in the object to reindex" + # ) + + aligner = Aligner( + (obj,), + indexes=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + sparse=sparse, + exclude_vars=exclude_vars, + ) + aligner.align() + return aligner.results[0] - for name, var in variables.items(): - if name not in indexers: - if isinstance(fill_value, dict): - fill_value_ = fill_value.get(name, dtypes.NA) - else: - fill_value_ = fill_value - if sparse: - var = var._as_sparse(fill_value=fill_value_) - key = tuple( - slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) - for d in var.dims - ) - needs_masking = any(d in masked_dims for d in var.dims) - - if needs_masking: - new_var = var._getitem_with_mask(key, fill_value=fill_value_) - elif all(is_full_slice(k) for k in key): - # no reindexing necessary - # here we need to manually deal with copying data, since - # we neither created a new ndarray nor used fancy indexing - new_var = var.copy(deep=copy) - else: - new_var = var[key] +def reindex_like( + obj: DataAlignable, + other: Dataset | DataArray, + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = dtypes.NA, +) -> DataAlignable: + """Re-index either a Dataset or a DataArray like another Dataset/DataArray. - reindexed[name] = new_var + Not public API. - return reindexed, new_indexes + """ + if not other._indexes: + # This check is not performed in Aligner. + for dim in other.dims: + if dim in obj.dims: + other_size = other.sizes[dim] + obj_size = obj.sizes[dim] + if other_size != obj_size: + raise ValueError( + "different size for unlabeled " + f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}" + ) + + return reindex( + obj, + indexers=other.xindexes, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) def _get_broadcast_dims_map_common_coords(args, exclude): diff --git a/xarray/core/combine.py b/xarray/core/combine.py index d23a58522e6..78f016fdccd 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -86,7 +86,7 @@ def _infer_concat_order_from_coords(datasets): if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds.xindexes.get(dim) for ds in datasets] + indexes = [ds._indexes.get(dim) for ds in datasets] if any(index is None for index in indexes): raise ValueError( "Every dimension needs a coordinate for " diff --git a/xarray/core/common.py b/xarray/core/common.py index bee59c6cc7d..60a01951eac 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -404,7 +404,7 @@ def get_index(self, key: Hashable) -> pd.Index: raise KeyError(key) try: - return self.xindexes[key].to_pandas_index() + return self._indexes[key].to_pandas_index() except KeyError: return pd.Index(range(self.sizes[key]), name=key) @@ -1151,8 +1151,7 @@ def resample( category=FutureWarning, ) - # TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass - if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex): + if isinstance(self._indexes[dim_name].to_pandas_index(), CFTimeIndex): from .resample_cftime import CFTimeGrouper grouper = CFTimeGrouper(freq, closed, label, base, loffset) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ce37251576a..7676d8e558c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -23,6 +23,7 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align +from .indexes import Index, filter_indexes_from_coords from .merge import merge_attrs, merge_coordinates_without_align from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array @@ -204,13 +205,13 @@ def _get_coords_list(args) -> list[Coordinates]: return coords_list -def build_output_coords( +def build_output_coords_and_indexes( args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), combine_attrs: str = "override", -) -> list[dict[Any, Variable]]: - """Build output coordinates for an operation. +) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: + """Build output coordinates and indexes for an operation. Parameters ---------- @@ -225,7 +226,7 @@ def build_output_coords( Returns ------- - Dictionary of Variable objects with merged coordinates. + Dictionaries of Variable and Index objects with merged coordinates. """ coords_list = _get_coords_list(args) @@ -233,24 +234,30 @@ def build_output_coords( # we can skip the expensive merge (unpacked_coords,) = coords_list merged_vars = dict(unpacked_coords.variables) + merged_indexes = dict(unpacked_coords.xindexes) else: - # TODO: save these merged indexes, instead of re-computing them later - merged_vars, unused_indexes = merge_coordinates_without_align( + merged_vars, merged_indexes = merge_coordinates_without_align( coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] + output_indexes = [] for output_dims in signature.output_core_dims: dropped_dims = signature.all_input_core_dims - set(output_dims) if dropped_dims: - filtered = { + filtered_coords = { k: v for k, v in merged_vars.items() if dropped_dims.isdisjoint(v.dims) } + filtered_indexes = filter_indexes_from_coords( + merged_indexes, set(filtered_coords) + ) else: - filtered = merged_vars - output_coords.append(filtered) + filtered_coords = merged_vars + filtered_indexes = merged_indexes + output_coords.append(filtered_coords) + output_indexes.append(filtered_indexes) - return output_coords + return output_coords, output_indexes def apply_dataarray_vfunc( @@ -278,7 +285,7 @@ def apply_dataarray_vfunc( else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - result_coords = build_output_coords( + result_coords, result_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) @@ -287,12 +294,19 @@ def apply_dataarray_vfunc( if signature.num_outputs > 1: out = tuple( - DataArray(variable, coords, name=name, fastpath=True) - for variable, coords in zip(result_var, result_coords) + DataArray( + variable, coords=coords, indexes=indexes, name=name, fastpath=True + ) + for variable, coords, indexes in zip( + result_var, result_coords, result_indexes + ) ) else: (coords,) = result_coords - out = DataArray(result_var, coords, name=name, fastpath=True) + (indexes,) = result_indexes + out = DataArray( + result_var, coords=coords, indexes=indexes, name=name, fastpath=True + ) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): @@ -391,7 +405,9 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] + variables: dict[Hashable, Variable], + coord_variables: Mapping[Hashable, Variable], + indexes: dict[Hashable, Index], ) -> Dataset: """Create a dataset as quickly as possible. @@ -401,7 +417,7 @@ def _fast_dataset( variables.update(coord_variables) coord_names = set(coord_variables) - return Dataset._construct_direct(variables, coord_names) + return Dataset._construct_direct(variables, coord_names, indexes=indexes) def apply_dataset_vfunc( @@ -433,7 +449,7 @@ def apply_dataset_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords( + list_of_coords, list_of_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) args = [getattr(arg, "data_vars", arg) for arg in args] @@ -443,10 +459,14 @@ def apply_dataset_vfunc( ) if signature.num_outputs > 1: - out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords)) + out = tuple( + _fast_dataset(*args) + for args in zip(result_vars, list_of_coords, list_of_indexes) + ) else: (coord_vars,) = list_of_coords - out = _fast_dataset(result_vars, coord_vars) + (indexes,) = list_of_indexes + out = _fast_dataset(result_vars, coord_vars, indexes=indexes) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 1e6e246322e..8ee4672c49a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,14 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Hashable, Iterable, Literal, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, overload import pandas as pd from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv -from .merge import _VALID_COMPAT, merge_attrs, unique_variable -from .variable import IndexVariable, Variable, as_variable +from .indexes import Index, PandasIndex +from .merge import ( + _VALID_COMPAT, + collect_variables_and_indexes, + merge_attrs, + merge_collected, +) +from .variable import Variable from .variable import concat as concat_vars if TYPE_CHECKING: @@ -28,7 +34,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -43,7 +49,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -240,30 +246,31 @@ def concat( ) -def _calc_concat_dim_coord(dim): - """ - Infer the dimension name and 1d coordinate variable (if appropriate) +def _calc_concat_dim_index( + dim_or_data: Hashable | Any, +) -> tuple[Hashable, PandasIndex | None]: + """Infer the dimension name and 1d index / coordinate variable (if appropriate) for concatenating along the new dimension. + """ from .dataarray import DataArray - if isinstance(dim, str): - coord = None - elif not isinstance(dim, (DataArray, Variable)): - dim_name = getattr(dim, "name", None) - if dim_name is None: - dim_name = "concat_dim" - coord = IndexVariable(dim_name, dim) - dim = dim_name - elif not isinstance(dim, DataArray): - coord = as_variable(dim).to_index_variable() - (dim,) = coord.dims + dim: Hashable | None + + if isinstance(dim_or_data, str): + dim = dim_or_data + index = None else: - coord = dim - if coord.name is None: - coord.name = dim.dims[0] - (dim,) = coord.dims - return dim, coord + if not isinstance(dim_or_data, (DataArray, Variable)): + dim = getattr(dim_or_data, "name", None) + if dim is None: + dim = "concat_dim" + else: + (dim,) = dim_or_data.dims + coord_dtype = getattr(dim_or_data, "dtype", None) + index = PandasIndex(dim_or_data, dim, coord_dtype=coord_dtype) + + return dim, index def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat): @@ -414,7 +421,7 @@ def _dataset_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -431,7 +438,8 @@ def _dataset_concat( "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" ) - dim, coord = _calc_concat_dim_coord(dim) + dim, index = _calc_concat_dim_index(dim) + # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] datasets = list( @@ -464,22 +472,20 @@ def _dataset_concat( variables_to_merge = (coord_names | data_names) - concat_over - dim_names result_vars = {} + result_indexes = {} + if variables_to_merge: - to_merge: dict[Hashable, list[Variable]] = { - var: [] for var in variables_to_merge + grouped = { + k: v + for k, v in collect_variables_and_indexes(list(datasets)).items() + if k in variables_to_merge } + merged_vars, merged_indexes = merge_collected( + grouped, compat=compat, equals=equals + ) + result_vars.update(merged_vars) + result_indexes.update(merged_indexes) - for ds in datasets: - for var in variables_to_merge: - if var in ds: - to_merge[var].append(ds.variables[var]) - - for var in variables_to_merge: - result_vars[var] = unique_variable( - var, to_merge[var], compat=compat, equals=equals.get(var, None) - ) - else: - result_vars = {} result_vars.update(dim_coords) # assign attrs and encoding from first dataset @@ -506,22 +512,64 @@ def ensure_common_dims(vars): var = var.set_dims(common_dims, common_shape) yield var - # stack up each variable to fill-out the dataset (in order) + # get the indexes to concatenate together, create a PandasIndex + # for any scalar coordinate variable found with ``name`` matching ``dim``. + # TODO: depreciate concat a mix of scalar and dimensional indexed coodinates? + # TODO: (benbovy - explicit indexes): check index types and/or coordinates + # of all datasets? + def get_indexes(name): + for ds in datasets: + if name in ds._indexes: + yield ds._indexes[name] + elif name == dim: + var = ds._variables[name] + if not var.dims: + yield PandasIndex([var.values], dim) + + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for k in datasets[0].variables: - if k in concat_over: + for name in datasets[0].variables: + if name in concat_over and name not in result_indexes: try: - vars = ensure_common_dims([ds[k].variable for ds in datasets]) + vars = ensure_common_dims([ds[name].variable for ds in datasets]) except KeyError: - raise ValueError(f"{k!r} is not present in all datasets.") - combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) - assert isinstance(combined, Variable) - result_vars[k] = combined - elif k in result_vars: + raise ValueError(f"{name!r} is not present in all datasets.") + + # Try concatenate the indexes, concatenate the variables when no index + # is found on all datasets. + indexes: list[Index] = list(get_indexes(name)) + if indexes: + if len(indexes) < len(datasets): + raise ValueError( + f"{name!r} must have either an index or no index in all datasets, " + f"found {len(indexes)}/{len(datasets)} datasets with an index." + ) + combined_idx = indexes[0].concat(indexes, dim, positions) + if name in datasets[0]._indexes: + idx_vars = datasets[0].xindexes.get_all_coords(name) + else: + # index created from a scalar coordinate + idx_vars = {name: datasets[0][name].variable} + result_indexes.update({k: combined_idx for k in idx_vars}) + combined_idx_vars = combined_idx.create_variables(idx_vars) + for k, v in combined_idx_vars.items(): + v.attrs = merge_attrs( + [ds.variables[k].attrs for ds in datasets], + combine_attrs=combine_attrs, + ) + result_vars[k] = v + else: + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + result_vars[name] = combined_var + + elif name in result_vars: # preserves original variable order - result_vars[k] = result_vars.pop(k) + result_vars[name] = result_vars.pop(name) result = Dataset(result_vars, attrs=result_attrs) + absent_coord_names = coord_names - set(result.variables) if absent_coord_names: raise ValueError( @@ -532,9 +580,13 @@ def ensure_common_dims(vars): result = result.drop_vars(unlabeled_dims, errors="ignore") - if coord is not None: - # add concat dimension last to ensure that its in the final Dataset - result[coord.name] = coord + if index is not None: + # add concat index / coordinate last to ensure that its in the final Dataset + result[dim] = index.create_variables()[dim] + result_indexes[dim] = index + + # TODO: add indexes at Dataset creation (when it is supported) + result = result._overwrite_indexes(result_indexes) return result @@ -545,7 +597,7 @@ def _dataarray_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9dfd64e9c99..458be214f81 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -16,11 +16,11 @@ import numpy as np import pandas as pd -from . import formatting, indexing -from .indexes import Index, Indexes +from . import formatting +from .indexes import Index, Indexes, assert_no_index_corrupted from .merge import merge_coordinates_without_align, merge_coords -from .utils import Frozen, ReprObject, either_dict_or_kwargs -from .variable import Variable +from .utils import Frozen, ReprObject +from .variable import Variable, calculate_dimensions if TYPE_CHECKING: from .dataarray import DataArray @@ -49,11 +49,11 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: raise NotImplementedError() @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: return self._data.indexes # type: ignore[attr-defined] @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: return self._data.xindexes # type: ignore[attr-defined] @property @@ -272,8 +272,6 @@ def to_dataset(self) -> "Dataset": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - variables = self._data._variables.copy() variables.update(coords) @@ -335,8 +333,6 @@ def __getitem__(self, key: Hashable) -> "DataArray": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable dims = calculate_dimensions(coords_plus_data) @@ -360,11 +356,13 @@ def to_dataset(self) -> "Dataset": from .dataset import Dataset coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()} - return Dataset._construct_direct(coords, set(coords)) + indexes = dict(self._data.xindexes) + return Dataset._construct_direct(coords, set(coords), indexes=indexes) def __delitem__(self, key: Hashable) -> None: if key not in self: raise KeyError(f"{key!r} is not a coordinate variable.") + assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] if self._data._indexes is not None and key in self._data._indexes: @@ -390,44 +388,3 @@ def assert_coordinate_consistent( f"dimension coordinate {k!r} conflicts between " f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" ) - - -def remap_label_indexers( - obj: Union["DataArray", "Dataset"], - indexers: Mapping[Any, Any] = None, - method: str = None, - tolerance=None, - **indexers_kwargs: Any, -) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing - """Remap indexers from obj.coords. - If indexer is an instance of DataArray and it has coordinate, then this coordinate - will be attached to pos_indexers. - - Returns - ------- - pos_indexers: Same type of indexers. - np.ndarray or Variable or DataArray - new_indexes: mapping of new dimensional-coordinate. - """ - from .dataarray import DataArray - - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers") - - v_indexers = { - k: v.variable.data if isinstance(v, DataArray) else v - for k, v in indexers.items() - } - - pos_indexers, new_indexes = indexing.remap_label_indexers( - obj, v_indexers, method=method, tolerance=tolerance - ) - # attach indexer's coordinate to pos_indexers - for k, v in indexers.items(): - if isinstance(v, Variable): - pos_indexers[k] = Variable(v.dims, pos_indexers[k]) - elif isinstance(v, DataArray): - # drop coordinates found in indexers since .sel() already - # ensures alignments - coords = {k: var for k, var in v._coords.items() if k not in indexers} - pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims) - return pos_indexers, new_indexes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e04e5cb9c51..46b02019a72 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -22,6 +22,7 @@ from ..plot.plot import _PlotMethods from ..plot.utils import _get_units_from_attrs from . import ( + alignment, computation, dtypes, groupby, @@ -34,25 +35,22 @@ ) from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor -from .alignment import ( - _broadcast_helper, - _get_broadcast_dims_map_common_coords, - align, - reindex_like_indexers, -) +from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .arithmetic import DataArrayArithmetic from .common import AbstractArray, DataWithCoords, get_chunksizes from .computation import unify_chunks -from .coordinates import ( - DataArrayCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) -from .dataset import Dataset, split_indexes +from .coordinates import DataArrayCoordinates, assert_coordinate_consistent +from .dataset import Dataset from .formatting import format_item -from .indexes import Index, Indexes, default_indexes, propagate_indexes -from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords +from .indexes import ( + Index, + Indexes, + PandasMultiIndex, + filter_indexes_from_coords, + isel_indexes, +) +from .indexing import is_fancy_indexer, map_index_queries +from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords from .npcompat import QUANTILE_METHODS, ArrayLike from .options import OPTIONS, _get_keep_attrs from .utils import ( @@ -62,13 +60,7 @@ _default, either_dict_or_kwargs, ) -from .variable import ( - IndexVariable, - Variable, - as_compatible_data, - as_variable, - assert_unique_multiindex_level_names, -) +from .variable import IndexVariable, Variable, as_compatible_data, as_variable if TYPE_CHECKING: try: @@ -162,8 +154,6 @@ def _infer_coords_and_dims( "matching the dimension size" ) - assert_unique_multiindex_level_names(new_coords) - return new_coords, dims @@ -204,8 +194,8 @@ def __setitem__(self, key, value) -> None: labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - pos_indexers, _ = remap_label_indexers(self.data_array, key) - self.data_array[pos_indexers] = value + dim_indexers = map_index_queries(self.data_array, key).dim_indexers + self.data_array[dim_indexers] = value # Used as the key corresponding to a DataArray's variable when converting @@ -340,7 +330,7 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic): _cache: dict[str, Any] _coords: dict[Any, Variable] _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] | None + _indexes: dict[Hashable, Index] _name: Hashable | None _variable: Variable @@ -370,14 +360,20 @@ def __init__( name: Hashable = None, attrs: Mapping = None, # internal parameters - indexes: dict[Hashable, pd.Index] = None, + indexes: dict[Hashable, Index] = None, fastpath: bool = False, ): if fastpath: variable = data assert dims is None assert attrs is None + assert indexes is not None else: + # TODO: (benbovy - explicit indexes) remove + # once it becomes part of the public interface + if indexes is not None: + raise ValueError("Providing explicit indexes is not supported yet") + # try to fill in arguments from data if they weren't supplied if coords is None: @@ -401,9 +397,7 @@ def __init__( data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) variable = Variable(dims, data, attrs, fastpath=True) - indexes = dict( - _extract_indexes_from_coords(coords) - ) # needed for to_dataset + indexes, coords = _create_indexes_from_coords(coords) # These fully describe a DataArray self._variable = variable @@ -413,10 +407,29 @@ def __init__( # TODO(shoyer): document this argument, once it becomes part of the # public interface. - self._indexes = indexes + self._indexes = indexes # type: ignore[assignment] self._close = None + @classmethod + def _construct_direct( + cls, + variable: Variable, + coords: dict[Any, Variable], + name: Hashable, + indexes: dict[Hashable, Index], + ) -> DataArray: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + obj = object.__new__(cls) + obj._variable = variable + obj._coords = coords + obj._name = name + obj._indexes = indexes + obj._close = None + return obj + def _replace( self: T_DataArray, variable: Variable = None, @@ -428,9 +441,11 @@ def _replace( variable = self.variable if coords is None: coords = self._coords + if indexes is None: + indexes = self._indexes if name is _default: name = self.name - return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) + return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( self, variable: Variable, name: Hashable | None | Default = _default @@ -446,37 +461,49 @@ def _replace_maybe_drop_dims( for k, v in self._coords.items() if v.shape == tuple(new_sizes[d] for d in v.dims) } - changed_dims = [ - k for k in variable.dims if variable.sizes[k] != self.sizes[k] - ] - indexes = propagate_indexes(self._indexes, exclude=changed_dims) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) else: allowed_dims = set(variable.dims) coords = { k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims } - indexes = propagate_indexes( - self._indexes, exclude=(set(self.dims) - allowed_dims) - ) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return self._replace(variable, coords, name, indexes=indexes) - def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> DataArray: - if not len(indexes): + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + coords: Mapping[Any, Variable] = None, + drop_coords: list[Hashable] = None, + rename_dims: Mapping[Any, Any] = None, + ) -> DataArray: + """Maybe replace indexes and their corresponding coordinates.""" + if not indexes: return self - coords = self._coords.copy() - for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx.to_pandas_index()) - obj = self._replace(coords=coords) - - # switch from dimension to level names, if necessary - dim_names: dict[Any, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + + if coords is None: + coords = {} + if drop_coords is None: + drop_coords = [] + + new_variable = self.variable.copy() + new_coords = self._coords.copy() + new_indexes = dict(self._indexes) + + for name in indexes: + new_coords[name] = coords[name] + new_indexes[name] = indexes[name] + + for name in drop_coords: + new_coords.pop(name) + new_indexes.pop(name) + + if rename_dims: + new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims] + + return self._replace( + variable=new_variable, coords=new_coords, indexes=new_indexes + ) def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) @@ -499,8 +526,8 @@ def subset(dim, label): variables = {label: subset(dim, label) for label in self.get_index(dim)} variables.update({k: v for k, v in self._coords.items() if k != dim}) - indexes = propagate_indexes(self._indexes, exclude=dim) coord_names = set(self._coords) - {dim} + indexes = filter_indexes_from_coords(self._indexes, coord_names) dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs ) @@ -705,21 +732,6 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - @property - def _level_coords(self) -> dict[Hashable, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. - """ - level_coords: dict[Hashable, Hashable] = {} - - for cname, var in self._coords.items(): - if var.ndim == 1 and isinstance(var, IndexVariable): - level_names = var.level_names - if level_names is not None: - (dim,) = var.dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _getitem_coord(self, key): from .dataset import _get_virtual_variable @@ -727,9 +739,7 @@ def _getitem_coord(self, key): var = self._coords[key] except KeyError: dim_sizes = dict(zip(self.dims, self.shape)) - _, key, var = _get_virtual_variable( - self._coords, key, self._level_coords, dim_sizes - ) + _, key, var = _get_virtual_variable(self._coords, key, dim_sizes) return self._replace_maybe_drop_dims(var, name=key) @@ -774,7 +784,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates # uses empty dict -- everything here can already be found in self.coords. yield HybridMappingProxy(keys=self.dims, mapping={}) - yield HybridMappingProxy(keys=self._level_coords, mapping={}) def __contains__(self, key: Any) -> bool: return key in self.data @@ -817,14 +826,12 @@ def indexes(self) -> Indexes: DataArray.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._coords, self.dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property def coords(self) -> DataArrayCoordinates: @@ -852,7 +859,7 @@ def reset_coords( Dataset, or DataArray if ``drop == True`` """ if names is None: - names = set(self.coords) - set(self.dims) + names = set(self.coords) - set(self._indexes) dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) @@ -898,7 +905,8 @@ def _dask_finalize(results, name, func, *args, **kwargs): ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables - return DataArray(variable, coords, name=name, fastpath=True) + indexes = ds._indexes + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) def load(self, **kwargs) -> DataArray: """Manually trigger loading of this array's data from disk or a @@ -1034,11 +1042,15 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: pandas.DataFrame.copy """ variable = self.variable.copy(deep=deep, data=data) - coords = {k: v.copy(deep=deep) for k, v in self._coords.items()} - if self._indexes is None: - indexes = self._indexes - else: - indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()} + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + coords = {} + for k, v in self._coords.items(): + if k in index_vars: + coords[k] = index_vars[k] + else: + coords[k] = v.copy(deep=deep) + return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> DataArray: @@ -1203,19 +1215,23 @@ def isel( # lists, or zero or one-dimensional np.ndarray's variable = self._variable.isel(indexers, missing_dims=missing_dims) + indexes, index_variables = isel_indexes(self.xindexes, indexers) coords = {} for coord_name, coord_value in self._coords.items(): - coord_indexers = { - k: v for k, v in indexers.items() if k in coord_value.dims - } - if coord_indexers: - coord_value = coord_value.isel(coord_indexers) - if drop and coord_value.ndim == 0: - continue + if coord_name in index_variables: + coord_value = index_variables[coord_name] + else: + coord_indexers = { + k: v for k, v in indexers.items() if k in coord_value.dims + } + if coord_indexers: + coord_value = coord_value.isel(coord_indexers) + if drop and coord_value.ndim == 0: + continue coords[coord_name] = coord_value - return self._replace(variable=variable, coords=coords) + return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( self, @@ -1460,6 +1476,37 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> DataArray: + """Callback called from ``Aligner`` to create a new reindexed DataArray.""" + + if isinstance(fill_value, dict): + fill_value = fill_value.copy() + sentinel = object() + value = fill_value.pop(self.name, sentinel) + if value is not sentinel: + fill_value[_THIS_ARRAY] = value + + ds = self._to_temp_dataset() + reindexed = ds._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + return self._from_temp_dataset(reindexed) + def reindex_like( self, other: DataArray | Dataset, @@ -1517,9 +1564,9 @@ def reindex_like( DataArray.reindex align """ - indexers = reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex_like( + self, + other=other, method=method, tolerance=tolerance, copy=copy, @@ -1606,22 +1653,15 @@ def reindex( DataArray.reindex_like align """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") - if isinstance(fill_value, dict): - fill_value = fill_value.copy() - sentinel = object() - value = fill_value.pop(self.name, sentinel) - if value is not sentinel: - fill_value[_THIS_ARRAY] = value - - ds = self._to_temp_dataset().reindex( + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, indexers=indexers, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) - return self._from_temp_dataset(ds) def interp( self, @@ -2040,10 +2080,8 @@ def reset_index( -------- DataArray.set_index """ - coords, _ = split_indexes( - dims_or_levels, self._coords, set(), self._level_coords, drop=drop - ) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop) + return self._from_temp_dataset(ds) def reorder_levels( self, @@ -2068,21 +2106,14 @@ def reorder_levels( Another dataarray, with this dataarray's data but replaced coordinates. """ - dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") - replace_coords = {} - for dim, order in dim_order.items(): - coord = self._coords[dim] - index = coord.to_index() - if not isinstance(index, pd.MultiIndex): - raise ValueError(f"coordinate {dim!r} has no MultiIndex") - replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) - coords = self._coords.copy() - coords.update(replace_coords) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs) + return self._from_temp_dataset(ds) def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: bool = True, + index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> DataArray: """ @@ -2099,6 +2130,14 @@ def stack( replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type. Must be an Xarray index that + implements `.stack()`. By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -2129,13 +2168,18 @@ def stack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') See Also -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) + ds = self._to_temp_dataset().stack( + dimensions, + create_index=create_index, + index_cls=index_cls, + **dimensions_kwargs, + ) return self._from_temp_dataset(ds) def unstack( @@ -2189,7 +2233,7 @@ def unstack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') >>> roundtripped = stacked.unstack() >>> arr.identical(roundtripped) True @@ -2241,7 +2285,7 @@ def to_unstacked_dataset(self, dim, level=0): ('a', 1.0), ('a', 2.0), ('b', nan)], - names=['variable', 'y']) + name='z') >>> roundtripped = stacked.to_unstacked_dataset(dim="z") >>> data.identical(roundtripped) True @@ -2250,10 +2294,7 @@ def to_unstacked_dataset(self, dim, level=0): -------- Dataset.to_stacked_array """ - - # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = self.xindexes[dim].to_pandas_index() + idx = self._indexes[dim].to_pandas_index() if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") @@ -4068,6 +4109,9 @@ def pad( For ``mode="constant"`` and ``constant_values=None``, integer types will be promoted to ``float`` and padded with ``np.nan``. + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b3112bdc7ab..e37718213aa 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -15,7 +15,6 @@ Any, Callable, Collection, - DefaultDict, Hashable, Iterable, Iterator, @@ -52,24 +51,21 @@ from .arithmetic import DatasetArithmetic from .common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes from .computation import unify_chunks -from .coordinates import ( - DatasetCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) +from .coordinates import DatasetCoordinates, assert_coordinate_consistent from .duck_array_ops import datetime_to_numeric from .indexes import ( Index, Indexes, PandasIndex, PandasMultiIndex, - default_indexes, - isel_variable_and_index, - propagate_indexes, + assert_no_index_corrupted, + create_default_index_implicit, + filter_indexes_from_coords, + isel_indexes, remove_unused_levels_categories, - roll_index, + roll_indexes, ) -from .indexing import is_fancy_indexer +from .indexing import is_fancy_indexer, map_index_queries from .merge import ( dataset_merge_method, dataset_update_method, @@ -99,8 +95,8 @@ IndexVariable, Variable, as_variable, - assert_unique_multiindex_level_names, broadcast_variables, + calculate_dimensions, ) if TYPE_CHECKING: @@ -135,13 +131,12 @@ def _get_virtual_variable( - variables, key: Hashable, level_vars: Mapping = None, dim_sizes: Mapping = None + variables, key: Hashable, dim_sizes: Mapping = None ) -> tuple[Hashable, Hashable, Variable]: - """Get a virtual variable (e.g., 'time.year' or a MultiIndex level) - from a dict of xarray.Variable objects (if possible) + """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable + objects (if possible) + """ - if level_vars is None: - level_vars = {} if dim_sizes is None: dim_sizes = {} @@ -154,202 +149,22 @@ def _get_virtual_variable( raise KeyError(key) split_key = key.split(".", 1) - var_name: str | None - if len(split_key) == 2: - ref_name, var_name = split_key - elif len(split_key) == 1: - ref_name, var_name = key, None - else: + if len(split_key) != 2: raise KeyError(key) - if ref_name in level_vars: - dim_var = variables[level_vars[ref_name]] - ref_var = dim_var.to_index_variable().get_level_variable(ref_name) - else: - ref_var = variables[ref_name] + ref_name, var_name = split_key + ref_var = variables[ref_name] - if var_name is None: - virtual_var = ref_var - var_name = key + if _contains_datetime_like_objects(ref_var): + ref_var = xr.DataArray(ref_var) + data = getattr(ref_var.dt, var_name).data else: - if _contains_datetime_like_objects(ref_var): - ref_var = xr.DataArray(ref_var) - data = getattr(ref_var.dt, var_name).data - else: - data = getattr(ref_var, var_name).data - virtual_var = Variable(ref_var.dims, data) + data = getattr(ref_var, var_name).data + virtual_var = Variable(ref_var.dims, data) return ref_name, var_name, virtual_var -def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: - """Calculate the dimensions corresponding to a set of variables. - - Returns dictionary mapping from dimension names to sizes. Raises ValueError - if any of the dimension sizes conflict. - """ - dims: dict[Hashable, int] = {} - last_used = {} - scalar_vars = {k for k, v in variables.items() if not v.dims} - for k, var in variables.items(): - for dim, size in zip(var.dims, var.shape): - if dim in scalar_vars: - raise ValueError( - f"dimension {dim!r} already exists as a scalar variable" - ) - if dim not in dims: - dims[dim] = size - last_used[dim] = k - elif dims[dim] != size: - raise ValueError( - f"conflicting sizes for dimension {dim!r}: " - f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" - ) - return dims - - -def merge_indexes( - indexes: Mapping[Any, Hashable | Sequence[Hashable]], - variables: Mapping[Any, Variable], - coord_names: set[Hashable], - append: bool = False, -) -> tuple[dict[Hashable, Variable], set[Hashable]]: - """Merge variables into multi-indexes. - - Not public API. Used in Dataset and DataArray set_index - methods. - """ - vars_to_replace: dict[Hashable, Variable] = {} - vars_to_remove: list[Hashable] = [] - dims_to_replace: dict[Hashable, Hashable] = {} - error_msg = "{} is not the name of an existing variable." - - for dim, var_names in indexes.items(): - if isinstance(var_names, str) or not isinstance(var_names, Sequence): - var_names = [var_names] - - names: list[Hashable] = [] - codes: list[list[int]] = [] - levels: list[list[int]] = [] - current_index_variable = variables.get(dim) - - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - if ( - current_index_variable is not None - and var.dims != current_index_variable.dims - ): - raise ValueError( - f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}" - ) - - if current_index_variable is not None and append: - current_index = current_index_variable.to_index() - if isinstance(current_index, pd.MultiIndex): - names.extend(current_index.names) - codes.extend(current_index.codes) - levels.extend(current_index.levels) - else: - names.append(f"{dim}_level_0") - cat = pd.Categorical(current_index.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - if not len(names) and len(var_names) == 1: - idx = pd.Index(variables[var_names[0]].values) - - else: # MultiIndex - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - names.append(n) - cat = pd.Categorical(var.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - idx = pd.MultiIndex(levels, codes, names=names) - for n in names: - dims_to_replace[n] = dim - - vars_to_replace[dim] = IndexVariable(dim, idx) - vars_to_remove.extend(var_names) - - new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} - new_variables.update(vars_to_replace) - - # update dimensions if necessary, GH: 3512 - for k, v in new_variables.items(): - if any(d in dims_to_replace for d in v.dims): - new_dims = [dims_to_replace.get(d, d) for d in v.dims] - new_variables[k] = v._replace(dims=new_dims) - new_coord_names = coord_names | set(vars_to_replace) - new_coord_names -= set(vars_to_remove) - return new_variables, new_coord_names - - -def split_indexes( - dims_or_levels: Hashable | Sequence[Hashable], - variables: Mapping[Any, Variable], - coord_names: set[Hashable], - level_coords: Mapping[Any, Hashable], - drop: bool = False, -) -> tuple[dict[Hashable, Variable], set[Hashable]]: - """Extract (multi-)indexes (levels) as variables. - - Not public API. Used in Dataset and DataArray reset_index - methods. - """ - if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): - dims_or_levels = [dims_or_levels] - - dim_levels: DefaultDict[Any, list[Hashable]] = defaultdict(list) - dims = [] - for k in dims_or_levels: - if k in level_coords: - dim_levels[level_coords[k]].append(k) - else: - dims.append(k) - - vars_to_replace = {} - vars_to_create: dict[Hashable, Variable] = {} - vars_to_remove = [] - - for d in dims: - index = variables[d].to_index() - if isinstance(index, pd.MultiIndex): - dim_levels[d] = index.names - else: - vars_to_remove.append(d) - if not drop: - vars_to_create[str(d) + "_"] = Variable(d, index, variables[d].attrs) - - for d, levs in dim_levels.items(): - index = variables[d].to_index() - if len(levs) == index.nlevels: - vars_to_remove.append(d) - else: - vars_to_replace[d] = IndexVariable(d, index.droplevel(levs)) - - if not drop: - for lev in levs: - idx = index.get_level_values(lev) - vars_to_create[idx.name] = Variable(d, idx, variables[d].attrs) - - new_variables = dict(variables) - for v in set(vars_to_remove): - del new_variables[v] - new_variables.update(vars_to_replace) - new_variables.update(vars_to_create) - new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) - - return new_variables, new_coord_names - - def _assert_empty(args: tuple, msg: str = "%s") -> None: if args: raise ValueError(msg % args) @@ -569,8 +384,8 @@ def __setitem__(self, key, value) -> None: ) # set new values - pos_indexers, _ = remap_label_indexers(self.dataset, key) - self.dataset[pos_indexers] = value + dim_indexers = map_index_queries(self.dataset, key).dim_indexers + self.dataset[dim_indexers] = value class Dataset(DataWithCoords, DatasetArithmetic, Mapping): @@ -702,7 +517,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): _dims: dict[Hashable, int] _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] | None + _indexes: dict[Hashable, Index] _variables: dict[Hashable, Variable] __slots__ = ( @@ -1080,6 +895,8 @@ def _construct_direct( """ if dims is None: dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -1096,7 +913,7 @@ def _replace( coord_names: set[Hashable] = None, dims: dict[Any, int] = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -1117,7 +934,7 @@ def _replace( self._dims = dims if attrs is not _default: self._attrs = attrs - if indexes is not _default: + if indexes is not None: self._indexes = indexes if encoding is not _default: self._encoding = encoding @@ -1131,8 +948,8 @@ def _replace( dims = self._dims.copy() if attrs is _default: attrs = copy.copy(self._attrs) - if indexes is _default: - indexes = copy.copy(self._indexes) + if indexes is None: + indexes = self._indexes.copy() if encoding is _default: encoding = copy.copy(self._encoding) obj = self._construct_direct( @@ -1145,7 +962,7 @@ def _replace_with_new_dims( variables: dict[Hashable, Variable], coord_names: set = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, inplace: bool = False, ) -> Dataset: """Replace variables with recalculated dimensions.""" @@ -1173,26 +990,79 @@ def _replace_vars_and_dims( variables, coord_names, dims, attrs, indexes=None, inplace=inplace ) - def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> Dataset: + def _overwrite_indexes( + self, + indexes: Mapping[Hashable, Index], + variables: Mapping[Hashable, Variable] = None, + drop_variables: list[Hashable] = None, + drop_indexes: list[Hashable] = None, + rename_dims: Mapping[Hashable, Hashable] = None, + ) -> Dataset: + """Maybe replace indexes. + + This function may do a lot more depending on index query + results. + + """ if not indexes: return self - variables = self._variables.copy() - new_indexes = dict(self.xindexes) - for name, idx in indexes.items(): - variables[name] = IndexVariable(name, idx.to_pandas_index()) - new_indexes[name] = idx - obj = self._replace(variables, indexes=new_indexes) - - # switch from dimension to level names, if necessary - dim_names: dict[Hashable, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = pd_idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + if variables is None: + variables = {} + if drop_variables is None: + drop_variables = [] + if drop_indexes is None: + drop_indexes = [] + + new_variables = self._variables.copy() + new_coord_names = self._coord_names.copy() + new_indexes = dict(self._indexes) + + index_variables = {} + no_index_variables = {} + for name, var in variables.items(): + old_var = self._variables.get(name) + if old_var is not None: + var.attrs.update(old_var.attrs) + var.encoding.update(old_var.encoding) + if name in indexes: + index_variables[name] = var + else: + no_index_variables[name] = var + + for name in indexes: + new_indexes[name] = indexes[name] + + for name, var in index_variables.items(): + new_coord_names.add(name) + new_variables[name] = var + + # append no-index variables at the end + for k in no_index_variables: + new_variables.pop(k) + new_variables.update(no_index_variables) + + for name in drop_indexes: + new_indexes.pop(name) + + for name in drop_variables: + new_variables.pop(name) + new_indexes.pop(name, None) + new_coord_names.remove(name) + + replaced = self._replace( + variables=new_variables, coord_names=new_coord_names, indexes=new_indexes + ) + + if rename_dims: + # skip rename indexes: they should already have the right name(s) + dims = replaced._rename_dims(rename_dims) + new_variables, new_coord_names = replaced._rename_vars({}, rename_dims) + return replaced._replace( + variables=new_variables, coord_names=new_coord_names, dims=dims + ) + else: + return replaced def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: """Returns a copy of this dataset. @@ -1292,10 +1162,11 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: pandas.DataFrame.copy """ if data is None: - variables = {k: v.copy(deep=deep) for k, v in self._variables.items()} + data = {} elif not utils.is_dict_like(data): raise ValueError("Data must be dict-like") - else: + + if data: var_keys = set(self.data_vars.keys()) data_keys = set(data.keys()) keys_not_in_vars = data_keys - var_keys @@ -1310,14 +1181,19 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: "Data must contain all variables in original " "dataset. Data is missing {}".format(keys_missing_from_data) ) - variables = { - k: v.copy(deep=deep, data=data.get(k)) - for k, v in self._variables.items() - } + + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + variables = {} + for k, v in self._variables.items(): + if k in index_vars: + variables[k] = index_vars[k] + else: + variables[k] = v.copy(deep=deep, data=data.get(k)) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) - return self._replace(variables, attrs=attrs) + return self._replace(variables, indexes=indexes, attrs=attrs) def as_numpy(self: Dataset) -> Dataset: """ @@ -1331,21 +1207,6 @@ def as_numpy(self: Dataset) -> Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - @property - def _level_coords(self) -> dict[str, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. - """ - level_coords: dict[str, Hashable] = {} - for name, index in self.xindexes.items(): - # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. - pd_index = index.to_pandas_index() - if isinstance(pd_index, pd.MultiIndex): - level_names = pd_index.names - (dim,) = self.variables[name].dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. @@ -1359,13 +1220,16 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims + self._variables, name, self.dims ) variables[var_name] = var if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) if (var_name,) == var.dims: - indexes[var_name] = var._to_xindex() + index, index_vars = create_default_index_implicit(var, names) + indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) needed_dims: OrderedSet[Hashable] = OrderedSet() for v in variables.values(): @@ -1381,8 +1245,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: if set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) - if k in self.xindexes: - indexes[k] = self.xindexes[k] + + indexes.update(filter_indexes_from_coords(self._indexes, coord_names)) return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1393,9 +1257,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: try: variable = self._variables[name] except KeyError: - _, name, variable = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims - ) + _, name, variable = _get_virtual_variable(self._variables, name, self.dims) needed_dims = set(variable.dims) @@ -1405,10 +1267,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: coords[k] = self.variables[k] - if self._indexes is None: - indexes = None - else: - indexes = {k: v for k, v in self._indexes.items() if k in coords} + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) @@ -1435,9 +1294,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) - # uses empty dict -- everything here can already be found in self.coords. - yield HybridMappingProxy(keys=self._level_coords, mapping={}) - def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether 'key' is an array in the dataset or not. @@ -1628,11 +1484,12 @@ def _setitem_check(self, key, value): def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" + assert_no_index_corrupted(self.xindexes, {key}) + + if key in self._indexes: + del self._indexes[key] del self._variables[key] self._coord_names.discard(key) - if key in self.xindexes: - assert self._indexes is not None - del self._indexes[key] self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable @@ -1706,7 +1563,7 @@ def identical(self, other: Dataset) -> bool: return False @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: """Mapping of pandas.Index objects used for label based indexing. Raises an error if this Dataset has indexes that cannot be coerced @@ -1717,14 +1574,12 @@ def indexes(self) -> Indexes: Dataset.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._variables, self._dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property def coords(self) -> DatasetCoordinates: @@ -1788,14 +1643,14 @@ def reset_coords( Dataset """ if names is None: - names = self._coord_names - set(self.dims) + names = self._coord_names - set(self._indexes) else: if isinstance(names, str) or not isinstance(names, Iterable): names = [names] else: names = list(names) self._assert_all_in_dataset(names) - bad_coords = set(names) & set(self.dims) + bad_coords = set(names) & set(self._indexes) if bad_coords: raise ValueError( f"cannot remove index coordinates with reset_coords: {bad_coords}" @@ -2220,9 +2075,7 @@ def _validate_indexers( v = np.asarray(v) if v.dtype.kind in "US": - # TODO: benbovy - flexible indexes - # update when CFTimeIndex has its own xarray index class - index = self.xindexes[k].to_pandas_index() + index = self._indexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): @@ -2358,24 +2211,22 @@ def isel( variables = {} dims: dict[Hashable, int] = {} coord_names = self._coord_names.copy() - indexes = self._indexes.copy() if self._indexes is not None else None - - for var_name, var_value in self._variables.items(): - var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims} - if var_indexers: - var_value = var_value.isel(var_indexers) - if drop and var_value.ndim == 0 and var_name in coord_names: - coord_names.remove(var_name) - if indexes: - indexes.pop(var_name, None) - continue - if indexes and var_name in indexes: - if var_value.ndim == 1: - indexes[var_name] = var_value._to_xindex() - else: - del indexes[var_name] - variables[var_name] = var_value - dims.update(zip(var_value.dims, var_value.shape)) + + indexes, index_variables = isel_indexes(self.xindexes, indexers) + + for name, var in self._variables.items(): + # preserve variable order + if name in index_variables: + var = index_variables[name] + else: + var_indexers = {k: v for k, v in indexers.items() if k in var.dims} + if var_indexers: + var = var.isel(var_indexers) + if drop and var.ndim == 0 and name in coord_names: + coord_names.remove(name) + continue + variables[name] = var + dims.update(zip(var.dims, var.shape)) return self._construct_direct( variables=variables, @@ -2394,29 +2245,22 @@ def _isel_fancy( drop: bool, missing_dims: str = "raise", ) -> Dataset: - # Note: we need to preserve the original indexers variable in order to merge the - # coords below - indexers_list = list(self._validate_indexers(indexers, missing_dims)) + valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} - indexes: dict[Hashable, Index] = {} + indexes, index_variables = isel_indexes(self.xindexes, valid_indexers) for name, var in self.variables.items(): - var_indexers = {k: v for k, v in indexers_list if k in var.dims} - if drop and name in var_indexers: - continue # drop this variable - - if name in self.xindexes: - new_var, new_index = isel_variable_and_index( - name, var, self.xindexes[name], var_indexers - ) - if new_index is not None: - indexes[name] = new_index - elif var_indexers: - new_var = var.isel(indexers=var_indexers) + if name in index_variables: + new_var = index_variables[name] else: - new_var = var.copy(deep=False) - + var_indexers = { + k: v for k, v in valid_indexers.items() if k in var.dims + } + if var_indexers: + new_var = var.isel(indexers=var_indexers) + else: + new_var = var.copy(deep=False) variables[name] = new_var coord_names = self._coord_names & variables.keys() @@ -2433,7 +2277,7 @@ def sel( self, indexers: Mapping[Any, Any] = None, method: str = None, - tolerance: Number = None, + tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, ) -> Dataset: @@ -2498,15 +2342,22 @@ def sel( DataArray.sel """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - pos_indexers, new_indexes = remap_label_indexers( + query_results = map_index_queries( self, indexers=indexers, method=method, tolerance=tolerance ) - # TODO: benbovy - flexible indexes: also use variables returned by Index.query - # (temporary dirty fix). - new_indexes = {k: v[0] for k, v in new_indexes.items()} - result = self.isel(indexers=pos_indexers, drop=drop) - return result._overwrite_indexes(new_indexes) + if drop: + no_scalar_variables = {} + for k, v in query_results.variables.items(): + if v.dims: + no_scalar_variables[k] = v + else: + if k in self._coord_names: + query_results.drop_coords.append(k) + query_results.variables = no_scalar_variables + + result = self.isel(indexers=query_results.dim_indexers, drop=drop) + return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( self, @@ -2676,6 +2527,58 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Dataset: + """Callback called from ``Aligner`` to create a new reindexed Dataset.""" + + new_variables = variables.copy() + new_indexes = indexes.copy() + + # pass through indexes from excluded dimensions + # no extra check needed for multi-coordinate indexes, potential conflicts + # should already have been detected when aligning the indexes + for name, idx in self._indexes.items(): + var = self._variables[name] + if set(var.dims) <= exclude_dims: + new_indexes[name] = idx + new_variables[name] = var + + if not dim_pos_indexers: + # fast path for no reindexing necessary + if set(new_indexes) - set(self._indexes): + # this only adds new indexes and their coordinate variables + reindexed = self._overwrite_indexes(new_indexes, new_variables) + else: + reindexed = self.copy(deep=aligner.copy) + else: + to_reindex = { + k: v + for k, v in self.variables.items() + if k not in variables and k not in exclude_vars + } + reindexed_vars = alignment.reindex_variables( + to_reindex, + dim_pos_indexers, + copy=aligner.copy, + fill_value=fill_value, + sparse=aligner.sparse, + ) + new_variables.update(reindexed_vars) + new_coord_names = self._coord_names | set(new_indexes) + reindexed = self._replace_with_new_dims( + new_variables, new_coord_names, indexes=new_indexes + ) + + return reindexed + def reindex_like( self, other: Dataset | DataArray, @@ -2732,13 +2635,13 @@ def reindex_like( Dataset.reindex align """ - indexers = alignment.reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex_like( + self, + other=other, method=method, + tolerance=tolerance, copy=copy, fill_value=fill_value, - tolerance=tolerance, ) def reindex( @@ -2822,6 +2725,7 @@ def reindex( temperature (station) float64 10.98 14.3 12.06 10.9 pressure (station) float64 211.8 322.9 218.8 445.9 >>> x.indexes + Indexes: station: Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') Create a new index and reindex the dataset. By default values in the new index that @@ -2945,14 +2849,14 @@ def reindex( original dataset, use the :py:meth:`~Dataset.fillna()` method. """ - return self._reindex( - indexers, - method, - tolerance, - copy, - fill_value, - sparse=False, - **indexers_kwargs, + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, ) def _reindex( @@ -2966,28 +2870,18 @@ def _reindex( **indexers_kwargs: Any, ) -> Dataset: """ - same to _reindex but support sparse option + Same as reindex but supports sparse option. """ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") - - bad_dims = [d for d in indexers if d not in self.dims] - if bad_dims: - raise ValueError(f"invalid reindex dimensions: {bad_dims}") - - variables, indexes = alignment.reindex_variables( - self.variables, - self.sizes, - self.xindexes, - indexers, - method, - tolerance, + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, copy=copy, fill_value=fill_value, sparse=sparse, ) - coord_names = set(self._coord_names) - coord_names.update(indexers) - return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def interp( self, @@ -3180,7 +3074,7 @@ def _validate_interp_indexer(x, new_x): } variables: dict[Hashable, Variable] = {} - to_reindex: dict[Hashable, Variable] = {} + reindex: bool = False for name, var in obj._variables.items(): if name in indexers: continue @@ -3198,42 +3092,49 @@ def _validate_interp_indexer(x, new_x): elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): # For types that we do not understand do stepwise # interpolation to avoid modifying the elements. - # Use reindex_variables instead because it supports + # reindex the variable instead because it supports # booleans and objects and retains the dtype but inside # this loop there might be some duplicate code that slows it # down, therefore collect these signals and run it later: - to_reindex[name] = var + reindex = True elif all(d not in indexers for d in var.dims): # For anything else we can only keep variables if they # are not dependent on any coords that are being # interpolated along: variables[name] = var - if to_reindex: - # Reindex variables: - variables_reindex = alignment.reindex_variables( - variables=to_reindex, - sizes=obj.sizes, - indexes=obj.xindexes, - indexers={k: v[-1] for k, v in validated_indexers.items()}, + if reindex: + reindex_indexers = { + k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,) + } + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, method=method_non_numeric, - )[0] - variables.update(variables_reindex) + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed._indexes) + variables.update(reindexed.variables) + else: + # Get the indexes that are not being interpolated along + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() - # Get the indexes that are not being interpolated along: - indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) # Attach indexer as coordinate - variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) if v.dims == (k,): - indexes[k] = v._to_xindex() + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables({k: v}) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) @@ -3294,7 +3195,14 @@ def interp_like( """ if kwargs is None: kwargs = {} - coords = alignment.reindex_like_indexers(self, other) + + # pick only dimension coordinates with a single index + coords = {} + other_indexes = other.xindexes + for dim in self.dims: + other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") + if len(other_dim_coords) == 1: + coords[dim] = other_dim_coords[dim] numeric_coords: dict[Hashable, pd.Index] = {} object_coords: dict[Hashable, pd.Index] = {} @@ -3335,28 +3243,34 @@ def _rename_vars(self, name_dict, dims_dict): def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} - def _rename_indexes(self, name_dict, dims_set): - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645 - if self._indexes is None: - return None + def _rename_indexes(self, name_dict, dims_dict): + if not self._indexes: + return {}, {} + indexes = {} - for k, v in self.indexes.items(): - new_name = name_dict.get(k, k) - if new_name not in dims_set: - continue - if isinstance(v, pd.MultiIndex): - new_names = [name_dict.get(k, k) for k in v.names] - indexes[new_name] = PandasMultiIndex( - v.rename(names=new_names), new_name - ) - else: - indexes[new_name] = PandasIndex(v.rename(new_name), new_name) - return indexes + variables = {} + + for index, coord_names in self.xindexes.group_by_index(): + new_index = index.rename(name_dict, dims_dict) + new_coord_names = [name_dict.get(k, k) for k in coord_names] + indexes.update({k: new_index for k in new_coord_names}) + new_index_vars = new_index.create_variables( + { + new: self._variables[old] + for old, new in zip(coord_names, new_coord_names) + } + ) + variables.update(new_index_vars) + + return indexes, variables def _rename_all(self, name_dict, dims_dict): variables, coord_names = self._rename_vars(name_dict, dims_dict) dims = self._rename_dims(dims_dict) - indexes = self._rename_indexes(name_dict, dims.keys()) + + indexes, index_vars = self._rename_indexes(name_dict, dims_dict) + variables = {k: index_vars.get(k, v) for k, v in variables.items()} + return variables, coord_names, dims, indexes def rename( @@ -3398,7 +3312,6 @@ def rename( variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict=name_dict ) - assert_unique_multiindex_level_names(variables) return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( @@ -3572,21 +3485,19 @@ def swap_dims( dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) if k in result_dims: var = v.to_index_variable() - if k in self.xindexes: - indexes[k] = self.xindexes[k] + var.dims = dims + if k in self._indexes: + indexes[k] = self._indexes[k] + variables[k] = var else: - new_index = var.to_index() - if new_index.nlevels == 1: - # make sure index name matches dimension name - new_index = new_index.rename(k) - if isinstance(new_index, pd.MultiIndex): - indexes[k] = PandasMultiIndex(new_index, k) - else: - indexes[k] = PandasIndex(new_index, k) + index, index_vars = create_default_index_implicit(var) + indexes.update({name: index for name in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) else: var = v.to_base_variable() - var.dims = dims - variables[k] = var + var.dims = dims + variables[k] = var return self._replace_with_new_dims(variables, coord_names, indexes=indexes) @@ -3666,6 +3577,7 @@ def expand_dims( ) variables: dict[Hashable, Variable] = {} + indexes: dict[Hashable, Index] = dict(self._indexes) coord_names = self._coord_names.copy() # If dim is a dict, then ensure that the values are either integers # or iterables. @@ -3675,7 +3587,9 @@ def expand_dims( # save the coordinates to the variables dict, and set the # value within the dim dict to the length of the iterable # for later use. - variables[k] = xr.IndexVariable((k,), v) + index = PandasIndex(v, k) + indexes[k] = index + variables.update(index.create_variables()) coord_names.add(k) dim[k] = variables[k].size elif isinstance(v, int): @@ -3711,15 +3625,15 @@ def expand_dims( all_dims.insert(d, c) variables[k] = v.set_dims(dict(all_dims)) else: - # If dims includes a label of a non-dimension coordinate, - # it will be promoted to a 1D coordinate with a single value. - variables[k] = v.set_dims(k).to_index_variable() - - new_dims = self._dims.copy() - new_dims.update(dim) + if k not in variables: + # If dims includes a label of a non-dimension coordinate, + # it will be promoted to a 1D coordinate with a single value. + index, index_vars = create_default_index_implicit(v.set_dims(k)) + indexes[k] = index + variables.update(index_vars) - return self._replace_vars_and_dims( - variables, dims=new_dims, coord_names=coord_names + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes ) def set_index( @@ -3780,11 +3694,87 @@ def set_index( Dataset.reset_index Dataset.swap_dims """ - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") - variables, coord_names = merge_indexes( - indexes, self._variables, self._coord_names, append=append + dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") + + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + maybe_drop_indexes: list[Hashable] = [] + drop_variables: list[Hashable] = [] + replace_dims: dict[Hashable, Hashable] = {} + + for dim, _var_names in dim_coords.items(): + if isinstance(_var_names, str) or not isinstance(_var_names, Sequence): + var_names = [_var_names] + else: + var_names = list(_var_names) + + invalid_vars = set(var_names) - set(self._variables) + if invalid_vars: + raise ValueError( + ", ".join([str(v) for v in invalid_vars]) + + " variable(s) do not exist" + ) + + current_coord_names = self.xindexes.get_all_coords(dim, errors="ignore") + + # drop any pre-existing index involved + maybe_drop_indexes += list(current_coord_names) + var_names + for k in var_names: + maybe_drop_indexes += list( + self.xindexes.get_all_coords(k, errors="ignore") + ) + + drop_variables += var_names + + if len(var_names) == 1 and (not append or dim not in self._indexes): + var_name = var_names[0] + var = self._variables[var_name] + if var.dims != (dim,): + raise ValueError( + f"dimension mismatch: try setting an index for dimension {dim!r} with " + f"variable {var_name!r} that has dimensions {var.dims}" + ) + idx = PandasIndex.from_variables({dim: var}) + idx_vars = idx.create_variables({var_name: var}) + else: + if append: + current_variables = { + k: self._variables[k] for k in current_coord_names + } + else: + current_variables = {} + idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( + dim, + current_variables, + {k: self._variables[k] for k in var_names}, + ) + for n in idx.index.names: + replace_dims[n] = dim + + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes_: dict[Any, Index] = { + k: v for k, v in self._indexes.items() if k not in maybe_drop_indexes + } + indexes_.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + # update dimensions if necessary, GH: 3512 + for k, v in variables.items(): + if any(d in replace_dims for d in v.dims): + new_dims = [replace_dims.get(d, d) for d in v.dims] + variables[k] = v._replace(dims=new_dims) + + coord_names = self._coord_names - set(drop_variables) | set(new_variables) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes_ ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) def reset_index( self, @@ -3811,14 +3801,56 @@ def reset_index( -------- Dataset.set_index """ - variables, coord_names = split_indexes( - dims_or_levels, - self._variables, - self._coord_names, - cast(Mapping[Hashable, Hashable], self._level_coords), - drop=drop, - ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) + if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): + dims_or_levels = [dims_or_levels] + + invalid_coords = set(dims_or_levels) - set(self._indexes) + if invalid_coords: + raise ValueError( + f"{tuple(invalid_coords)} are not coordinates with an index" + ) + + drop_indexes: list[Hashable] = [] + drop_variables: list[Hashable] = [] + replaced_indexes: list[PandasMultiIndex] = [] + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + + for name in dims_or_levels: + index = self._indexes[name] + drop_indexes += list(self.xindexes.get_all_coords(name)) + + if isinstance(index, PandasMultiIndex) and name not in self.dims: + # special case for pd.MultiIndex (name is an index level): + # replace by a new index with dropped level(s) instead of just drop the index + if index not in replaced_indexes: + level_names = index.index.names + level_vars = { + k: self._variables[k] + for k in level_names + if k not in dims_or_levels + } + if level_vars: + idx = index.keep_levels(level_vars) + idx_vars = idx.create_variables(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + replaced_indexes.append(index) + + if drop: + drop_variables.append(name) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + coord_names = set(new_variables) | self._coord_names + + return self._replace(variables, coord_names=coord_names, indexes=indexes) def reorder_levels( self, @@ -3845,61 +3877,149 @@ def reorder_levels( """ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() - indexes = dict(self.xindexes) + indexes = dict(self._indexes) + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + for dim, order in dim_order.items(): - coord = self._variables[dim] - # TODO: benbovy - flexible indexes: update when MultiIndex - # has its own class inherited from xarray.Index - index = self.xindexes[dim].to_pandas_index() - if not isinstance(index, pd.MultiIndex): + index = self._indexes[dim] + + if not isinstance(index, PandasMultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") - new_index = index.reorder_levels(order) - variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = PandasMultiIndex(new_index, dim) + + level_vars = {k: self._variables[k] for k in order} + idx = index.reorder_levels(level_vars) + idx_vars = idx.create_variables(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in new_indexes} + indexes.update(new_indexes) + + variables = {k: v for k, v in self._variables.items() if k not in new_variables} + variables.update(new_variables) return self._replace(variables, indexes=indexes) - def _stack_once(self, dims, new_dim): + def _get_stack_index( + self, + dim, + multi=False, + create_index=False, + ) -> tuple[Index | None, dict[Hashable, Variable]]: + """Used by stack and unstack to get one pandas (multi-)index among + the indexed coordinates along dimension `dim`. + + If exactly one index is found, return it with its corresponding + coordinate variables(s), otherwise return None and an empty dict. + + If `create_index=True`, create a new index if none is found or raise + an error if multiple indexes are found. + + """ + stack_index: Index | None = None + stack_coords: dict[Hashable, Variable] = {} + + for name, index in self._indexes.items(): + var = self._variables[name] + if ( + var.ndim == 1 + and var.dims[0] == dim + and ( + # stack: must be a single coordinate index + not multi + and not self.xindexes.is_multi(name) + # unstack: must be an index that implements .unstack + or multi + and type(index).unstack is not Index.unstack + ) + ): + if stack_index is not None and index is not stack_index: + # more than one index found, stop + if create_index: + raise ValueError( + f"cannot stack dimension {dim!r} with `create_index=True` " + "and with more than one index found along that dimension" + ) + return None, {} + stack_index = index + stack_coords[name] = var + + if create_index and stack_index is None: + if dim in self._variables: + var = self._variables[dim] + else: + _, _, var = _get_virtual_variable(self._variables, dim, self.dims) + # dummy index (only `stack_coords` will be used to construct the multi-index) + stack_index = PandasIndex([0], dim) + stack_coords = {dim: var} + + return stack_index, stack_coords + + def _stack_once(self, dims, new_dim, index_cls, create_index=True): if dims == ...: raise ValueError("Please use [...] for dims, rather than just ...") if ... in dims: dims = list(infix_dims(dims, self.dims)) - variables = {} - for name, var in self.variables.items(): - if name not in dims: - if any(d in var.dims for d in dims): - add_dims = [d for d in dims if d not in var.dims] - vdims = list(var.dims) + add_dims - shape = [self.dims[d] for d in vdims] - exp_var = var.set_dims(vdims, shape) - stacked_var = exp_var.stack(**{new_dim: dims}) - variables[name] = stacked_var - else: - variables[name] = var.copy(deep=False) - # consider dropping levels that are unused? - levels = [self.get_index(dim) for dim in dims] - idx = utils.multiindex_from_product_levels(levels, names=dims) - variables[new_dim] = IndexVariable(new_dim, idx) + new_variables: dict[Hashable, Variable] = {} + stacked_var_names: list[Hashable] = [] + drop_indexes: list[Hashable] = [] - coord_names = set(self._coord_names) - set(dims) | {new_dim} - - indexes = {k: v for k, v in self.xindexes.items() if k not in dims} - indexes[new_dim] = PandasMultiIndex(idx, new_dim) + for name, var in self.variables.items(): + if any(d in var.dims for d in dims): + add_dims = [d for d in dims if d not in var.dims] + vdims = list(var.dims) + add_dims + shape = [self.dims[d] for d in vdims] + exp_var = var.set_dims(vdims, shape) + stacked_var = exp_var.stack(**{new_dim: dims}) + new_variables[name] = stacked_var + stacked_var_names.append(name) + else: + new_variables[name] = var.copy(deep=False) + + # drop indexes of stacked coordinates (if any) + for name in stacked_var_names: + drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore")) + + new_indexes = {} + new_coord_names = set(self._coord_names) + if create_index or create_index is None: + product_vars: dict[Any, Variable] = {} + for dim in dims: + idx, idx_vars = self._get_stack_index(dim, create_index=create_index) + if idx is not None: + product_vars.update(idx_vars) + + if len(product_vars) == len(dims): + idx = index_cls.stack(product_vars, new_dim) + new_indexes[new_dim] = idx + new_indexes.update({k: idx for k in product_vars}) + idx_vars = idx.create_variables(product_vars) + # keep consistent multi-index coordinate order + for k in idx_vars: + new_variables.pop(k, None) + new_variables.update(idx_vars) + new_coord_names.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes + new_variables, coord_names=new_coord_names, indexes=indexes ) def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: bool | None = True, + index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> Dataset: """ Stack any number of existing dimensions into a single new dimension. - New dimensions will be added at the end, and the corresponding + New dimensions will be added at the end, and by default the corresponding coordinate variables will be combined into a MultiIndex. Parameters @@ -3910,6 +4030,14 @@ def stack( ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type (must be an Xarray index that + implements `.stack()`). By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -3926,7 +4054,7 @@ def stack( dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): - result = result._stack_once(dims, new_dim) + result = result._stack_once(dims, new_dim, index_cls, create_index) return result def to_stacked_array( @@ -3997,9 +4125,9 @@ def to_stacked_array( array([[0, 1, 2, 6], [3, 4, 5, 7]]) Coordinates: - * z (z) MultiIndex - - variable (z) object 'a' 'a' 'a' 'b' - - y (z) object 'u' 'v' 'w' nan + * z (z) object MultiIndex + * variable (z) object 'a' 'a' 'a' 'b' + * y (z) object 'u' 'v' 'w' nan Dimensions without coordinates: x """ @@ -4035,32 +4163,30 @@ def ensure_stackable(val): stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] data_array = xr.concat(stackable_vars, dim=new_dim) - # coerce the levels of the MultiIndex to have the same type as the - # input dimensions. This code is messy, so it might be better to just - # input a dummy value for the singleton dimension. - # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = data_array.xindexes[new_dim].to_pandas_index() - levels = [idx.levels[0]] + [ - level.astype(self[level.name].dtype) for level in idx.levels[1:] - ] - new_idx = idx.set_levels(levels) - data_array[new_dim] = IndexVariable(new_dim, new_idx) - if name is not None: data_array.name = name return data_array - def _unstack_once(self, dim: Hashable, fill_value, sparse: bool = False) -> Dataset: - index = self.get_index(dim) - index = remove_unused_levels_categories(index) - + def _unstack_once( + self, + dim: Hashable, + index_and_vars: tuple[Index, dict[Hashable, Variable]], + fill_value, + sparse: bool = False, + ) -> Dataset: + index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() if k != dim} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + for name, idx in new_indexes.items(): + variables.update(idx.create_variables(index_vars)) for name, var in self.variables.items(): - if name != dim: + if name not in index_vars: if dim in var.dims: if isinstance(fill_value, Mapping): fill_value_ = fill_value[name] @@ -4068,55 +4194,66 @@ def _unstack_once(self, dim: Hashable, fill_value, sparse: bool = False) -> Data fill_value_ = fill_value variables[name] = var._unstack_once( - index=index, dim=dim, fill_value=fill_value_, sparse=sparse + index=clean_index, + dim=dim, + fill_value=fill_value_, + sparse=sparse, ) else: variables[name] = var - for name, lev in zip(index.names, index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) - variables[name] = idx_vars[name] - indexes[name] = idx - - coord_names = set(self._coord_names) - {dim} | set(index.names) + coord_names = set(self._coord_names) - {dim} | set(new_indexes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) - def _unstack_full_reindex(self, dim: Hashable, fill_value, sparse: bool) -> Dataset: - index = self.get_index(dim) - index = remove_unused_levels_categories(index) - full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) + def _unstack_full_reindex( + self, + dim: Hashable, + index_and_vars: tuple[Index, dict[Hashable, Variable]], + fill_value, + sparse: bool, + ) -> Dataset: + index, index_vars = index_and_vars + variables: dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + new_index_variables = {} + for name, idx in new_indexes.items(): + new_index_variables.update(idx.create_variables(index_vars)) + + new_dim_sizes = {k: v.size for k, v in new_index_variables.items()} + variables.update(new_index_variables) # take a shortcut in case the MultiIndex was not modified. - if index.equals(full_idx): + full_idx = pd.MultiIndex.from_product( + clean_index.levels, names=clean_index.names + ) + if clean_index.equals(full_idx): obj = self else: + # TODO: we may depreciate implicit re-indexing with a pandas.MultiIndex + xr_full_idx = PandasMultiIndex(full_idx, dim) + indexers = Indexes( + {k: xr_full_idx for k in index_vars}, + xr_full_idx.create_variables(index_vars), + ) obj = self._reindex( - {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse + indexers, copy=False, fill_value=fill_value, sparse=sparse ) - new_dim_names = index.names - new_dim_sizes = [lev.size for lev in index.levels] - - variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() if k != dim} - for name, var in obj.variables.items(): - if name != dim: + if name not in index_vars: if dim in var.dims: - new_dims = dict(zip(new_dim_names, new_dim_sizes)) - variables[name] = var.unstack({dim: new_dims}) + variables[name] = var.unstack({dim: new_dim_sizes}) else: variables[name] = var - for name, lev in zip(new_dim_names, index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) - variables[name] = idx_vars[name] - indexes[name] = idx - - coord_names = set(self._coord_names) - {dim} | set(new_dim_names) + coord_names = set(self._coord_names) - {dim} | set(new_dim_sizes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -4155,10 +4292,9 @@ def unstack( -------- Dataset.stack """ + if dim is None: - dims = [ - d for d in self.dims if isinstance(self.get_index(d), pd.MultiIndex) - ] + dims = list(self.dims) else: if isinstance(dim, str) or not isinstance(dim, Iterable): dims = [dim] @@ -4171,13 +4307,21 @@ def unstack( f"Dataset does not contain the dimensions: {missing_dims}" ) - non_multi_dims = [ - d for d in dims if not isinstance(self.get_index(d), pd.MultiIndex) - ] + # each specified dimension must have exactly one multi-index + stacked_indexes: dict[Any, tuple[Index, dict[Hashable, Variable]]] = {} + for d in dims: + idx, idx_vars = self._get_stack_index(d, multi=True) + if idx is not None: + stacked_indexes[d] = idx, idx_vars + + if dim is None: + dims = list(stacked_indexes) + else: + non_multi_dims = set(dims) - set(stacked_indexes) if non_multi_dims: raise ValueError( "cannot unstack dimensions that do not " - f"have a MultiIndex: {non_multi_dims}" + f"have exactly one multi-index: {tuple(non_multi_dims)}" ) result = self.copy(deep=False) @@ -4187,7 +4331,7 @@ def unstack( # We only check the non-index variables. # https://github.com/pydata/xarray/issues/5902 nonindexes = [ - self.variables[k] for k in set(self.variables) - set(self.xindexes) + self.variables[k] for k in set(self.variables) - set(self._indexes) ] # Notes for each of these cases: # 1. Dask arrays don't support assignment by index, which the fast unstack @@ -4209,9 +4353,13 @@ def unstack( for dim in dims: if needs_full_reindex: - result = result._unstack_full_reindex(dim, fill_value, sparse) + result = result._unstack_full_reindex( + dim, stacked_indexes[dim], fill_value, sparse + ) else: - result = result._unstack_once(dim, fill_value, sparse) + result = result._unstack_once( + dim, stacked_indexes[dim], fill_value, sparse + ) return result def update(self, other: CoercibleMapping) -> Dataset: @@ -4378,9 +4526,11 @@ def drop_vars( if errors == "raise": self._assert_all_in_dataset(names) + assert_no_index_corrupted(self.xindexes, names) + variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k not in names} + indexes = {k: v for k, v in self._indexes.items() if k not in names} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -5094,7 +5244,7 @@ def reduce( ) coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes @@ -5297,15 +5447,16 @@ def to_array(self, dim="variable", name=None): broadcast_vars = broadcast_variables(*data_vars) data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0) - coords = dict(self.coords) - coords[dim] = list(self.data_vars) - indexes = propagate_indexes(self._indexes) - dims = (dim,) + broadcast_vars[0].dims + variable = Variable(dims, data, self.attrs, fastpath=True) - return DataArray( - data, coords, dims, attrs=self.attrs, name=name, indexes=indexes - ) + coords = {k: v.variable for k, v in self.coords.items()} + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + new_dim_index = PandasIndex(list(self.data_vars), dim) + indexes[dim] = new_dim_index + coords.update(new_dim_index.create_variables()) + + return DataArray._construct_direct(variable, coords, name, indexes) def _normalize_dim_order( self, dim_order: list[Hashable] = None @@ -5517,7 +5668,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Datase # forwarding arguments to pandas.Series.to_numpy? arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] - obj = cls() + indexes: dict[Hashable, Index] = {} + index_vars: dict[Hashable, Variable] = {} if isinstance(idx, pd.MultiIndex): dims = tuple( @@ -5525,11 +5677,17 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Datase for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels): - obj[dim] = (dim, lev) + xr_idx = PandasIndex(lev, dim) + indexes[dim] = xr_idx + index_vars.update(xr_idx.create_variables()) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) - obj[index_name] = (dims, idx) + xr_idx = PandasIndex(idx, index_name) + indexes[index_name] = xr_idx + index_vars.update(xr_idx.create_variables()) + + obj = cls._construct_direct(index_vars, set(index_vars), indexes=indexes) if sparse: obj._set_sparse_data_from_dataframe(idx, arrays, dims) @@ -5888,10 +6046,13 @@ def diff(self, dim, n=1, label="upper"): else: raise ValueError("The 'label' argument has to be either 'upper' or 'lower'") + indexes, index_vars = isel_indexes(self.xindexes, kwargs_new) variables = {} for name, var in self.variables.items(): - if dim in var.dims: + if name in index_vars: + variables[name] = index_vars[name] + elif dim in var.dims: if name in self.data_vars: variables[name] = var.isel(**kwargs_end) - var.isel(**kwargs_start) else: @@ -5899,16 +6060,6 @@ def diff(self, dim, n=1, label="upper"): else: variables[name] = var - indexes = dict(self.xindexes) - if dim in indexes: - if isinstance(indexes[dim], PandasIndex): - # maybe optimize? (pandas index already indexed above with var.isel) - new_index = indexes[dim].index[kwargs_new[dim]] - if isinstance(new_index, pd.MultiIndex): - indexes[dim] = PandasMultiIndex(new_index, dim) - else: - indexes[dim] = PandasIndex(new_index, dim) - difference = self._replace_with_new_dims(variables, indexes=indexes) if n > 1: @@ -6047,29 +6198,27 @@ def roll( if invalid: raise ValueError(f"dimensions {invalid!r} do not exist") - unrolled_vars = () if roll_coords else self.coords + unrolled_vars: tuple[Hashable, ...] + + if roll_coords: + indexes, index_vars = roll_indexes(self.xindexes, shifts) + unrolled_vars = () + else: + indexes = dict(self._indexes) + index_vars = dict(self.xindexes.variables) + unrolled_vars = tuple(self.coords) variables = {} for k, var in self.variables.items(): - if k not in unrolled_vars: + if k in index_vars: + variables[k] = index_vars[k] + elif k not in unrolled_vars: variables[k] = var.roll( shifts={k: s for k, s in shifts.items() if k in var.dims} ) else: variables[k] = var - if roll_coords: - indexes: dict[Hashable, Index] = {} - idx: pd.Index - for k, idx in self.xindexes.items(): - (dim,) = self.variables[k].dims - if dim in shifts: - indexes[k] = roll_index(idx, shifts[dim]) - else: - indexes[k] = idx - else: - indexes = dict(self.xindexes) - return self._replace(variables, indexes=indexes) def sortby(self, variables, ascending=True): @@ -6322,7 +6471,7 @@ def quantile( # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None @@ -6555,7 +6704,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -7163,6 +7312,9 @@ def pad( promoted to ``float`` and padded with ``np.nan``. To avoid type promotion specify ``constant_values=np.nan`` + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> ds = xr.Dataset({"foo": ("x", range(5))}) @@ -7188,6 +7340,15 @@ def pad( coord_pad_options = {} variables = {} + + # keep indexes that won't be affected by pad and drop all other indexes + xindexes = self.xindexes + pad_dims = set(pad_width) + indexes = {} + for k, idx in xindexes.items(): + if not pad_dims.intersection(xindexes.get_all_dims(k)): + indexes[k] = idx + for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: @@ -7207,8 +7368,15 @@ def pad( mode=coord_pad_mode, **coord_pad_options, # type: ignore[arg-type] ) - - return self._replace_vars_and_dims(variables) + # reset default index of dimension coordinates + if (name,) == var.dims: + dim_var = {name: variables[name]} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) + indexes[name] = index + variables[name] = index_vars[name] + + return self._replace_with_new_dims(variables, indexes=indexes) def idxmin( self, diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 2a9f8a27815..81617ae38f9 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -2,6 +2,7 @@ """ import contextlib import functools +from collections import defaultdict from datetime import datetime, timedelta from itertools import chain, zip_longest from typing import Collection, Hashable, Optional @@ -266,10 +267,10 @@ def inline_sparse_repr(array): def inline_variable_array_repr(var, max_width): """Build a one-line summary of a variable's data.""" - if var._in_memory: - return format_array_flat(var, max_width) - elif hasattr(var._data, "_repr_inline_"): + if hasattr(var._data, "_repr_inline_"): return var._data._repr_inline_(max_width) + elif var._in_memory: + return format_array_flat(var, max_width) elif isinstance(var._data, dask_array_type): return inline_dask_repr(var.data) elif isinstance(var._data, sparse_array_type): @@ -282,68 +283,33 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( - name: Hashable, var, col_width: int, marker: str = " ", max_width: int = None + name: Hashable, var, col_width: int, max_width: int = None, is_index: bool = False ): """Summarize a variable in one line, e.g., for the Dataset.__repr__.""" + variable = getattr(var, "variable", var) + if max_width is None: max_width_options = OPTIONS["display_width"] if not isinstance(max_width_options, int): raise TypeError(f"`max_width` value of `{max_width}` is not a valid int") else: max_width = max_width_options + + marker = "*" if is_index else " " first_col = pretty_print(f" {marker} {name} ", col_width) - if var.dims: - dims_str = "({}) ".format(", ".join(map(str, var.dims))) + + if variable.dims: + dims_str = "({}) ".format(", ".join(map(str, variable.dims))) else: dims_str = "" - front_str = f"{first_col}{dims_str}{var.dtype} " + front_str = f"{first_col}{dims_str}{variable.dtype} " values_width = max_width - len(front_str) - values_str = inline_variable_array_repr(var, values_width) + values_str = inline_variable_array_repr(variable, values_width) return front_str + values_str -def _summarize_coord_multiindex(coord, col_width, marker): - first_col = pretty_print(f" {marker} {coord.name} ", col_width) - return f"{first_col}({str(coord.dims[0])}) MultiIndex" - - -def _summarize_coord_levels(coord, col_width, marker="-"): - if len(coord) > 100 and col_width < len(coord): - n_values = col_width - indices = list(range(0, n_values)) + list(range(-n_values, 0)) - subset = coord[indices] - else: - subset = coord - - return "\n".join( - summarize_variable( - lname, subset.get_level_variable(lname), col_width, marker=marker - ) - for lname in subset.level_names - ) - - -def summarize_datavar(name, var, col_width): - return summarize_variable(name, var.variable, col_width) - - -def summarize_coord(name: Hashable, var, col_width: int): - is_index = name in var.dims - marker = "*" if is_index else " " - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - return "\n".join( - [ - _summarize_coord_multiindex(coord, col_width, marker), - _summarize_coord_levels(coord, col_width), - ] - ) - return summarize_variable(name, var.variable, col_width, marker) - - def summarize_attr(key, value, col_width=None): """Summary for __repr__ - use ``X.attrs[key]`` for full value.""" # Indent key and add ':', then right-pad if col_width is not None @@ -359,23 +325,6 @@ def summarize_attr(key, value, col_width=None): EMPTY_REPR = " *empty*" -def _get_col_items(mapping): - """Get all column items to format, including both keys of `mapping` - and MultiIndex levels if any. - """ - from .variable import IndexVariable - - col_items = [] - for k, v in mapping.items(): - col_items.append(k) - var = getattr(v, "variable", v) - if isinstance(var, IndexVariable): - level_names = var.to_index_variable().level_names - if level_names is not None: - col_items += list(level_names) - return col_items - - def _calculate_col_width(col_items): max_name_length = max(len(str(s)) for s in col_items) if col_items else 0 col_width = max(max_name_length, 7) + 6 @@ -383,10 +332,21 @@ def _calculate_col_width(col_items): def _mapping_repr( - mapping, title, summarizer, expand_option_name, col_width=None, max_rows=None + mapping, + title, + summarizer, + expand_option_name, + col_width=None, + max_rows=None, + indexes=None, ): if col_width is None: col_width = _calculate_col_width(mapping) + + summarizer_kwargs = defaultdict(dict) + if indexes is not None: + summarizer_kwargs = {k: {"is_index": k in indexes} for k in mapping} + summary = [f"{title}:"] if mapping: len_mapping = len(mapping) @@ -396,15 +356,22 @@ def _mapping_repr( summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = calc_max_rows_first(max_rows) keys = list(mapping.keys()) - summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]] + summary += [ + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[:first_rows] + ] if max_rows > 1: last_rows = calc_max_rows_last(max_rows) summary += [pretty_print(" ...", col_width) + " ..."] summary += [ - summarizer(k, mapping[k], col_width) for k in keys[-last_rows:] + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[-last_rows:] ] else: - summary += [summarizer(k, v, col_width) for k, v in mapping.items()] + summary += [ + summarizer(k, v, col_width, **summarizer_kwargs[k]) + for k, v in mapping.items() + ] else: summary += [EMPTY_REPR] return "\n".join(summary) @@ -413,7 +380,7 @@ def _mapping_repr( data_vars_repr = functools.partial( _mapping_repr, title="Data variables", - summarizer=summarize_datavar, + summarizer=summarize_variable, expand_option_name="display_expand_data_vars", ) @@ -428,21 +395,25 @@ def _mapping_repr( def coords_repr(coords, col_width=None, max_rows=None): if col_width is None: - col_width = _calculate_col_width(_get_col_items(coords)) + col_width = _calculate_col_width(coords) return _mapping_repr( coords, title="Coordinates", - summarizer=summarize_coord, + summarizer=summarize_variable, expand_option_name="display_expand_coords", col_width=col_width, + indexes=coords.xindexes, max_rows=max_rows, ) def indexes_repr(indexes): - summary = [] - for k, v in indexes.items(): - summary.append(wrap_indent(repr(v), f"{k}: ")) + summary = ["Indexes:"] + if indexes: + for k, v in indexes.items(): + summary.append(wrap_indent(repr(v), f"{k}: ")) + else: + summary += [EMPTY_REPR] return "\n".join(summary) @@ -604,7 +575,7 @@ def array_repr(arr): if hasattr(arr, "coords"): if arr.coords: - col_width = _calculate_col_width(_get_col_items(arr.coords)) + col_width = _calculate_col_width(arr.coords) summary.append( coords_repr(arr.coords, col_width=col_width, max_rows=max_rows) ) @@ -624,7 +595,7 @@ def array_repr(arr): def dataset_repr(ds): summary = [f""] - col_width = _calculate_col_width(_get_col_items(ds.variables)) + col_width = _calculate_col_width(ds.variables) max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) @@ -655,9 +626,20 @@ def diff_dim_summary(a, b): return "" -def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): - def extra_items_repr(extra_keys, mapping, ab_side): - extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] +def _diff_mapping_repr( + a_mapping, + b_mapping, + compat, + title, + summarizer, + col_width=None, + a_indexes=None, + b_indexes=None, +): + def extra_items_repr(extra_keys, mapping, ab_side, kwargs): + extra_repr = [ + summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys + ] if extra_repr: header = f"{title} only on the {ab_side} object:" return [header] + extra_repr @@ -671,6 +653,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): diff_items = [] + a_summarizer_kwargs = defaultdict(dict) + if a_indexes is not None: + a_summarizer_kwargs = {k: {"is_index": k in a_indexes} for k in a_mapping} + b_summarizer_kwargs = defaultdict(dict) + if b_indexes is not None: + b_summarizer_kwargs = {k: {"is_index": k in b_indexes} for k in b_mapping} + for k in a_keys & b_keys: try: # compare xarray variable @@ -690,7 +679,8 @@ def extra_items_repr(extra_keys, mapping, ab_side): if not compatible: temp = [ - summarizer(k, vars[k], col_width) for vars in (a_mapping, b_mapping) + summarizer(k, a_mapping[k], col_width, **a_summarizer_kwargs[k]), + summarizer(k, b_mapping[k], col_width, **b_summarizer_kwargs[k]), ] if compat == "identical" and is_variable: @@ -712,19 +702,29 @@ def extra_items_repr(extra_keys, mapping, ab_side): if diff_items: summary += [f"Differing {title.lower()}:"] + diff_items - summary += extra_items_repr(a_keys - b_keys, a_mapping, "left") - summary += extra_items_repr(b_keys - a_keys, b_mapping, "right") + summary += extra_items_repr(a_keys - b_keys, a_mapping, "left", a_summarizer_kwargs) + summary += extra_items_repr( + b_keys - a_keys, b_mapping, "right", b_summarizer_kwargs + ) return "\n".join(summary) -diff_coords_repr = functools.partial( - _diff_mapping_repr, title="Coordinates", summarizer=summarize_coord -) +def diff_coords_repr(a, b, compat, col_width=None): + return _diff_mapping_repr( + a, + b, + compat, + "Coordinates", + summarize_variable, + col_width=col_width, + a_indexes=a.indexes, + b_indexes=b.indexes, + ) diff_data_vars_repr = functools.partial( - _diff_mapping_repr, title="Data variables", summarizer=summarize_datavar + _diff_mapping_repr, title="Data variables", summarizer=summarize_variable ) @@ -786,9 +786,7 @@ def diff_dataset_repr(a, b, compat): ) ] - col_width = _calculate_col_width( - set(_get_col_items(a.variables) + _get_col_items(b.variables)) - ) + col_width = _calculate_col_width(set(list(a.variables) + list(b.variables))) summary.append(diff_dim_summary(a, b)) summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 36c252f276e..db62466a8d3 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -31,12 +31,12 @@ def short_data_repr_html(array): return f"
{text}
" -def format_dims(dims, coord_names): +def format_dims(dims, dims_with_index): if not dims: return "" dim_css_map = { - k: " class='xr-has-index'" if k in coord_names else "" for k, v in dims.items() + dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims } dims_li = "".join( @@ -66,38 +66,7 @@ def _icon(icon_name): ) -def _summarize_coord_multiindex(name, coord): - preview = f"({', '.join(escape(l) for l in coord.level_names)})" - return summarize_variable( - name, coord, is_index=True, dtype="MultiIndex", preview=preview - ) - - -def summarize_coord(name, var): - is_index = name in var.dims - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - coords = {name: _summarize_coord_multiindex(name, coord)} - for lname in coord.level_names: - var = coord.get_level_variable(lname) - coords[lname] = summarize_variable(lname, var) - return coords - - return {name: summarize_variable(name, var, is_index)} - - -def summarize_coords(variables): - coords = {} - for k, v in variables.items(): - coords.update(**summarize_coord(k, v)) - - vars_li = "".join(f"
  • {v}
  • " for v in coords.values()) - - return f"
      {vars_li}
    " - - -def summarize_variable(name, var, is_index=False, dtype=None, preview=None): +def summarize_variable(name, var, is_index=False, dtype=None): variable = var.variable if hasattr(var, "variable") else var cssclass_idx = " class='xr-has-index'" if is_index else "" @@ -110,7 +79,7 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): data_id = "data-" + str(uuid.uuid4()) disabled = "" if len(var.attrs) else "disabled" - preview = preview or escape(inline_variable_array_repr(variable, 35)) + preview = escape(inline_variable_array_repr(variable, 35)) attrs_ul = summarize_attrs(var.attrs) data_repr = short_data_repr_html(variable) @@ -134,6 +103,17 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): ) +def summarize_coords(variables): + li_items = [] + for k, v in variables.items(): + li_content = summarize_variable(k, v, is_index=k in variables.xindexes) + li_items.append(f"
  • {li_content}
  • ") + + vars_li = "".join(li_items) + + return f"
      {vars_li}
    " + + def summarize_vars(variables): vars_li = "".join( f"
  • {summarize_variable(k, v)}
  • " @@ -184,7 +164,7 @@ def _mapping_section( def dim_section(obj): - dim_list = format_dims(obj.dims, list(obj.coords)) + dim_list = format_dims(obj.dims, obj.xindexes.dims) return collapsible_section( "Dimensions", inline_details=dim_list, enabled=False, collapsed=True @@ -265,15 +245,18 @@ def _obj_repr(obj, header_components, sections): def array_repr(arr): dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + if hasattr(arr, "xindexes"): + indexed_dims = arr.xindexes.dims + else: + indexed_dims = {} obj_type = f"xarray.{type(arr).__name__}" arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" - coord_names = list(arr.coords) if hasattr(arr, "coords") else [] header_components = [ f"
    {obj_type}
    ", f"
    {arr_name}
    ", - format_dims(dims, coord_names), + format_dims(dims, indexed_dims), ] sections = [array_section(arr)] diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d3ec824159c..db3dfbd384b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -9,7 +9,7 @@ from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from .concat import concat from .formatting import format_array_flat -from .indexes import propagate_indexes +from .indexes import create_default_index_implicit, filter_indexes_from_coords from .options import _get_keep_attrs from .pycompat import integer_types from .utils import ( @@ -20,7 +20,7 @@ peek_at, safe_cast_to_index, ) -from .variable import IndexVariable, Variable, as_variable +from .variable import IndexVariable, Variable def check_reduce_dims(reduce_dims, dimensions): @@ -54,6 +54,8 @@ def unique_value_groups(ar, sort=True): the corresponding value in `unique_values`. """ inverse, values = pd.factorize(ar, sort=sort) + if isinstance(values, pd.MultiIndex): + values.names = ar.names groups = [[] for _ in range(len(values))] for n, g in enumerate(inverse): if g >= 0: @@ -469,6 +471,7 @@ def _infer_concat_args(self, applied_example): (dim,) = coord.dims if isinstance(coord, _DummyGroup): coord = None + coord = getattr(coord, "variable", coord) return coord, dim, positions def _binary_op(self, other, f, reflexive=False): @@ -519,7 +522,7 @@ def _maybe_unstack(self, obj): for dim in self._inserted_dims: if dim in obj.coords: del obj.coords[dim] - obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims) + obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj def fillna(self, value): @@ -763,6 +766,8 @@ def _concat_shortcut(self, applied, dim, positions=None): # speed things up, but it's not very interpretable and there are much # faster alternatives (e.g., doing the grouped aggregation in a # compiled language) + # TODO: benbovy - explicit indexes: this fast implementation doesn't + # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) reordered = _maybe_reorder(stacked, dim, positions) return self._obj._replace_maybe_drop_dims(reordered) @@ -854,13 +859,11 @@ def _combine(self, applied, shortcut=False): if isinstance(combined, type(self._obj)): # only restore dimension order for arrays combined = self._restore_dim_order(combined) - # assign coord when the applied function does not return that coord + # assign coord and index when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - if shortcut: - coord_var = as_variable(coord) - combined._coords[coord.name] = coord_var - else: - combined.coords[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, coords=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined @@ -976,7 +979,9 @@ def _combine(self, applied): combined = _maybe_reorder(combined, dim, positions) # assign coord when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - combined[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, variables=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b66fbdf6504..e02e1f569b2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,44 +1,70 @@ +from __future__ import annotations + import collections.abc +import copy +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, Dict, + Generic, Hashable, Iterable, Iterator, Mapping, - Optional, Sequence, - Tuple, - Union, + TypeVar, + cast, ) import numpy as np import pandas as pd -from . import formatting, utils -from .indexing import ( - LazilyIndexedArray, - PandasIndexingAdapter, - PandasMultiIndexingAdapter, -) -from .utils import is_dict_like, is_scalar +from . import formatting, nputils, utils +from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter +from .types import T_Index +from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar if TYPE_CHECKING: - from .variable import IndexVariable, Variable + from .variable import Variable -IndexVars = Dict[Hashable, "IndexVariable"] +IndexVars = Dict[Any, "Variable"] class Index: """Base class inherited by all xarray-compatible indexes.""" @classmethod - def from_variables( - cls, variables: Mapping[Any, "Variable"] - ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover + def from_variables(cls, variables: Mapping[Any, Variable]) -> Index: + raise NotImplementedError() + + @classmethod + def concat( + cls: type[T_Index], + indexes: Sequence[T_Index], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> T_Index: + raise NotImplementedError() + + @classmethod + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Index: + raise NotImplementedError( + f"{cls!r} cannot be used for creating an index of stacked coordinates" + ) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: raise NotImplementedError() + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + if variables is not None: + # pass through + return dict(**variables) + else: + return {} + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a TypeError if this is not supported. @@ -47,27 +73,54 @@ def to_pandas_index(self) -> pd.Index: pandas.Index object. """ - raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") - def query( - self, labels: Dict[Hashable, Any] - ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover - raise NotImplementedError() + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Index | None: + return None + + def sel(self, labels: dict[Any, Any]) -> IndexSelResult: + raise NotImplementedError(f"{self!r} doesn't support label-based selection") + + def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: + raise NotImplementedError( + f"{self!r} doesn't support alignment with inner/outer join method" + ) + + def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") def equals(self, other): # pragma: no cover raise NotImplementedError() - def union(self, other): # pragma: no cover - raise NotImplementedError() + def roll(self, shifts: Mapping[Any, int]) -> Index | None: + return None - def intersection(self, other): # pragma: no cover - raise NotImplementedError() + def rename( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> Index: + return self - def copy(self, deep: bool = True): # pragma: no cover - raise NotImplementedError() + def __copy__(self) -> Index: + return self.copy(deep=False) + + def __deepcopy__(self, memo=None) -> Index: + # memo does nothing but is required for compatibility with + # copy.deepcopy + return self.copy(deep=True) + + def copy(self, deep: bool = True) -> Index: + cls = self.__class__ + copied = cls.__new__(cls) + if deep: + for k, v in self.__dict__.items(): + setattr(copied, k, copy.deepcopy(v)) + else: + copied.__dict__.update(self.__dict__) + return copied def __getitem__(self, indexer: Any): - # if not implemented, index will be dropped from the Dataset or DataArray raise NotImplementedError() @@ -134,6 +187,23 @@ def _is_nested_tuple(possible_tuple): ) +def normalize_label(value, dtype=None) -> np.ndarray: + if getattr(value, "ndim", 1) <= 1: + value = _asarray_tuplesafe(value) + if dtype is not None and dtype.kind == "f" and value.dtype.kind != "b": + # pd.Index built from coordinate with float precision != 64 + # see https://github.com/pydata/xarray/pull/3153 for details + # bypass coercing dtype for boolean indexers (ignore index) + # see https://github.com/pydata/xarray/issues/5727 + value = np.asarray(value, dtype=dtype) + return value + + +def as_scalar(value: np.ndarray): + # see https://github.com/pydata/xarray/pull/4292 for details + return value[()] if value.dtype.kind in "mM" else value.item() + + def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -147,16 +217,37 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): class PandasIndex(Index): """Wrap a pandas.Index as an xarray compatible index.""" - __slots__ = ("index", "dim") + index: pd.Index + dim: Hashable + coord_dtype: Any - def __init__(self, array: Any, dim: Hashable): - self.index = utils.safe_cast_to_index(array) + __slots__ = ("index", "dim", "coord_dtype") + + def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + index = utils.safe_cast_to_index(array).copy() + + if index.name is None: + index.name = dim + + self.index = index self.dim = dim - @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): - from .variable import IndexVariable + if coord_dtype is None: + coord_dtype = get_valid_numpy_dtype(index) + self.coord_dtype = coord_dtype + def _replace(self, index, dim=None, coord_dtype=None): + if dim is None: + dim = self.dim + if coord_dtype is None: + coord_dtype = self.coord_dtype + return type(self)(index, dim, coord_dtype) + + @classmethod + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasIndex: if len(variables) != 1: raise ValueError( f"PandasIndex only accepts one variable, found {len(variables)} variables" @@ -172,35 +263,113 @@ def from_variables(cls, variables: Mapping[Any, "Variable"]): dim = var.dims[0] - obj = cls(var.data, dim) + # TODO: (benbovy - explicit indexes): add __index__ to ExplicitlyIndexesNDArrayMixin? + # this could be eventually used by Variable.to_index() and would remove the need to perform + # the checks below. - data = PandasIndexingAdapter(obj.index) - index_var = IndexVariable( - dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True - ) + # preserve wrapped pd.Index (if any) + data = getattr(var._data, "array", var.data) + # multi-index level variable: get level index + if isinstance(var._data, PandasMultiIndexingAdapter): + level = var._data.level + if level is not None: + data = var._data.array.get_level_values(level) + + obj = cls(data, dim, coord_dtype=var.dtype) + assert not isinstance(obj.index, pd.MultiIndex) + obj.index.name = name + + return obj + + @staticmethod + def _concat_indexes(indexes, dim, positions=None) -> pd.Index: + new_pd_index: pd.Index + + if not indexes: + new_pd_index = pd.Index([]) + else: + if not all(idx.dim == dim for idx in indexes): + dims = ",".join({f"{idx.dim!r}" for idx in indexes}) + raise ValueError( + f"Cannot concatenate along dimension {dim!r} indexes with " + f"dimensions: {dims}" + ) + pd_indexes = [idx.index for idx in indexes] + new_pd_index = pd_indexes[0].append(pd_indexes[1:]) - return obj, {name: index_var} + if positions is not None: + indices = nputils.inverse_permutation(np.concatenate(positions)) + new_pd_index = new_pd_index.take(indices) + + return new_pd_index @classmethod - def from_pandas_index(cls, index: pd.Index, dim: Hashable): + def concat( + cls, + indexes: Sequence[PandasIndex], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> PandasIndex: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + coord_dtype = None + else: + coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) + + return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: from .variable import IndexVariable - if index.name is None: - name = dim - index = index.copy() - index.name = dim - else: - name = index.name + name = self.index.name + attrs: Mapping[Hashable, Any] | None + encoding: Mapping[Hashable, Any] | None - data = PandasIndexingAdapter(index) - index_var = IndexVariable(dim, data, fastpath=True) + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + encoding = var.encoding + else: + attrs = None + encoding = None - return cls(index, dim), {name: index_var} + data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) + var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) + return {name: var} def to_pandas_index(self) -> pd.Index: return self.index - def query(self, labels, method=None, tolerance=None): + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> PandasIndex | None: + from .variable import Variable + + indxr = indexers[self.dim] + if isinstance(indxr, Variable): + if indxr.dims != (self.dim,): + # can't preserve a index if result has new dimensions + return None + else: + indxr = indxr.data + if not isinstance(indxr, slice) and is_scalar(indxr): + # scalar indexer: drop index + return None + + return self._replace(self.index[indxr]) + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + from .dataarray import DataArray + from .variable import Variable + + if method is not None and not isinstance(method, str): + raise TypeError("``method`` must be a string") + assert len(labels) == 1 coord_name, label = next(iter(labels.items())) @@ -212,147 +381,423 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if label.dtype.kind in "mM" else label.item() + label_array = normalize_label(label, dtype=self.coord_dtype) + if label_array.ndim == 0: + label_value = as_scalar(label_array) if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( - "'method' is not a valid kwarg when indexing using a CategoricalIndex." + "'method' is not supported when indexing using a CategoricalIndex." ) if tolerance is not None: raise ValueError( - "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + "'tolerance' is not supported when indexing using a CategoricalIndex." ) indexer = self.index.get_loc(label_value) else: if method is not None: - indexer = get_indexer_nd(self.index, label, method, tolerance) + indexer = get_indexer_nd( + self.index, label_array, method, tolerance + ) if np.any(indexer < 0): raise KeyError( f"not all values found in index {coord_name!r}" ) else: indexer = self.index.get_loc(label_value) - elif label.dtype.kind == "b": - indexer = label + elif label_array.dtype.kind == "b": + indexer = label_array else: - indexer = get_indexer_nd(self.index, label, method, tolerance) + indexer = get_indexer_nd(self.index, label_array, method, tolerance) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") - return indexer, None + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + indexer = DataArray(indexer, coords=label._coords, dims=label.dims) - def equals(self, other): - return self.index.equals(other.index) + return IndexSelResult({self.dim: indexer}) - def union(self, other): - new_index = self.index.union(other.index) - return type(self)(new_index, self.dim) + def equals(self, other: Index): + if not isinstance(other, PandasIndex): + return False + return self.index.equals(other.index) and self.dim == other.dim - def intersection(self, other): - new_index = self.index.intersection(other.index) - return type(self)(new_index, self.dim) + def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + if how == "outer": + index = self.index.union(other.index) + else: + # how = "inner" + index = self.index.intersection(other.index) + + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + return type(self)(index, self.dim, coord_dtype=coord_dtype) + + def reindex_like( + self, other: PandasIndex, method=None, tolerance=None + ) -> dict[Hashable, Any]: + if not self.index.is_unique: + raise ValueError( + f"cannot reindex or align along dimension {self.dim!r} because the " + "(pandas) index has duplicate values" + ) + + return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + + def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: + shift = shifts[self.dim] % self.index.shape[0] + + if shift != 0: + new_pd_idx = self.index[-shift:].append(self.index[:-shift]) + else: + new_pd_idx = self.index[:] + + return self._replace(new_pd_idx) + + def rename(self, name_dict, dims_dict): + if self.index.name not in name_dict and self.dim not in dims_dict: + return self + + new_name = name_dict.get(self.index.name, self.index.name) + index = self.index.rename(new_name) + new_dim = dims_dict.get(self.dim, self.dim) + return self._replace(index, dim=new_dim) def copy(self, deep=True): - return type(self)(self.index.copy(deep=deep), self.dim) + if deep: + index = self.index.copy(deep=True) + else: + # index will be copied in constructor + index = self.index + return self._replace(index) def __getitem__(self, indexer: Any): - return type(self)(self.index[indexer], self.dim) - + return self._replace(self.index[indexer]) -def _create_variables_from_multiindex(index, dim, level_meta=None): - from .variable import IndexVariable - if level_meta is None: - level_meta = {} +def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal"): + """Check that all multi-index variable candidates are 1-dimensional and + either share the same (single) dimension or each have a different dimension. - variables = {} + """ + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") - dim_coord_adapter = PandasMultiIndexingAdapter(index) - variables[dim] = IndexVariable( - dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True - ) + dims = {var.dims for var in variables.values()} - for level in index.names: - meta = level_meta.get(level, {}) - data = PandasMultiIndexingAdapter( - index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter + if all_dims == "equal" and len(dims) > 1: + raise ValueError( + "unmatched dimensions for multi-index variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) ) - variables[level] = IndexVariable( - dim, - data, - attrs=meta.get("attrs"), - encoding=meta.get("encoding"), - fastpath=True, + + if all_dims == "different" and len(dims) < len(variables): + raise ValueError( + "conflicting dimensions for multi-index product variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) ) - return variables + +def remove_unused_levels_categories(index: pd.Index) -> pd.Index: + """ + Remove unused levels from MultiIndex and unused categories from CategoricalIndex + """ + if isinstance(index, pd.MultiIndex): + index = index.remove_unused_levels() + # if it contains CategoricalIndex, we need to remove unused categories + # manually. See https://github.com/pandas-dev/pandas/issues/30846 + if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + levels = [] + for i, level in enumerate(index.levels): + if isinstance(level, pd.CategoricalIndex): + level = level[index.codes[i]].remove_unused_categories() + else: + level = level[index.codes[i]] + levels.append(level) + # TODO: calling from_array() reorders MultiIndex levels. It would + # be best to avoid this, if possible, e.g., by using + # MultiIndex.remove_unused_levels() (which does not reorder) on the + # part of the MultiIndex that is not categorical, or by fixing this + # upstream in pandas. + index = pd.MultiIndex.from_arrays(levels, names=index.names) + elif isinstance(index, pd.CategoricalIndex): + index = index.remove_unused_categories() + return index class PandasMultiIndex(PandasIndex): - @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): - if any([var.ndim != 1 for var in variables.values()]): - raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") + """Wrap a pandas.MultiIndex as an xarray compatible index.""" - dims = {var.dims for var in variables.values()} - if len(dims) != 1: - raise ValueError( - "unmatched dimensions for variables " - + ",".join([str(k) for k in variables]) - ) + level_coords_dtype: dict[str, Any] + + __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") + + def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): + super().__init__(array, dim) + + # default index level names + names = [] + for i, idx in enumerate(self.index.levels): + name = idx.name or f"{dim}_level_{i}" + if name == dim: + raise ValueError( + f"conflicting multi-index level name {name!r} with dimension {dim!r}" + ) + names.append(name) + self.index.names = names + + if level_coords_dtype is None: + level_coords_dtype = { + idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels + } + self.level_coords_dtype = level_coords_dtype + + def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex: + if dim is None: + dim = self.dim + index.name = dim + if level_coords_dtype is None: + level_coords_dtype = self.level_coords_dtype + return type(self)(index, dim, level_coords_dtype) + + @classmethod + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasMultiIndex: + _check_dim_compat(variables) + dim = next(iter(variables.values())).dims[0] - dim = next(iter(dims))[0] index = pd.MultiIndex.from_arrays( [var.values for var in variables.values()], names=variables.keys() ) - obj = cls(index, dim) + index.name = dim + level_coords_dtype = {name: var.dtype for name, var in variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - level_meta = { - name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} - for name, var in variables.items() - } - index_vars = _create_variables_from_multiindex( - index, dim, level_meta=level_meta - ) + return obj - return obj, index_vars + @classmethod + def concat( # type: ignore[override] + cls, + indexes: Sequence[PandasMultiIndex], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> PandasMultiIndex: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + level_coords_dtype = None + else: + level_coords_dtype = {} + for name in indexes[0].level_coords_dtype: + level_coords_dtype[name] = np.result_type( + *[idx.level_coords_dtype[name] for idx in indexes] + ) + + return cls(new_pd_index, dim=dim, level_coords_dtype=level_coords_dtype) + + @classmethod + def stack( + cls, variables: Mapping[Any, Variable], dim: Hashable + ) -> PandasMultiIndex: + """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a + new dimension. + + Level variables must have a dimension distinct from each other. + + Keeps levels the same (doesn't refactorize them) so that it gives back the original + labels after a stack/unstack roundtrip. + + """ + _check_dim_compat(variables, all_dims="different") + + level_indexes = [utils.safe_cast_to_index(var) for var in variables.values()] + for name, idx in zip(variables, level_indexes): + if isinstance(idx, pd.MultiIndex): + raise ValueError( + f"cannot create a multi-index along stacked dimension {dim!r} " + f"from variable {name!r} that wraps a multi-index" + ) + + split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] + + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + level_coords_dtype = {k: var.dtype for k, var in variables.items()} + + return cls(index, dim, level_coords_dtype=level_coords_dtype) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: + clean_index = remove_unused_levels_categories(self.index) + + new_indexes: dict[Hashable, Index] = {} + for name, lev in zip(clean_index.names, clean_index.levels): + idx = PandasIndex( + lev.copy(), name, coord_dtype=self.level_coords_dtype[name] + ) + new_indexes[name] = idx + + return new_indexes, clean_index @classmethod - def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): - index_vars = _create_variables_from_multiindex(index, dim) - return cls(index, dim), index_vars + def from_variables_maybe_expand( + cls, + dim: Hashable, + current_variables: Mapping[Any, Variable], + variables: Mapping[Any, Variable], + ) -> tuple[PandasMultiIndex, IndexVars]: + """Create a new multi-index maybe by expanding an existing one with + new variables as index levels. + + The index and its corresponding coordinates may be created along a new dimension. + """ + names: list[Hashable] = [] + codes: list[list[int]] = [] + levels: list[list[int]] = [] + level_variables: dict[Any, Variable] = {} + + _check_dim_compat({**current_variables, **variables}) + + if len(current_variables) > 1: + # expand from an existing multi-index + data = cast( + PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data + ) + current_index = data.array + names.extend(current_index.names) + codes.extend(current_index.codes) + levels.extend(current_index.levels) + for name in current_index.names: + level_variables[name] = current_variables[name] + + elif len(current_variables) == 1: + # expand from one 1D variable (no multi-index): convert it to an index level + var = next(iter(current_variables.values())) + new_var_name = f"{dim}_level_0" + names.append(new_var_name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[new_var_name] = var + + for name, var in variables.items(): + names.append(name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[name] = var + + index = pd.MultiIndex(levels, codes, names=names) + level_coords_dtype = {k: var.dtype for k, var in level_variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = obj.create_variables(level_variables) + + return obj, index_vars + + def keep_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex | PandasIndex: + """Keep only the provided levels and return a new multi-index with its + corresponding coordinates. + + """ + index = self.index.droplevel( + [k for k in self.index.names if k not in level_variables] + ) + + if isinstance(index, pd.MultiIndex): + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + else: + return PandasIndex( + index, self.dim, coord_dtype=self.level_coords_dtype[index.name] + ) + + def reorder_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex: + """Re-arrange index levels using input order and return a new multi-index with + its corresponding coordinates. + + """ + index = self.index.reorder_levels(level_variables.keys()) + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from .variable import IndexVariable + + if variables is None: + variables = {} + + index_vars: IndexVars = {} + for name in (self.dim,) + self.index.names: + if name == self.dim: + level = None + dtype = None + else: + level = name + dtype = self.level_coords_dtype[name] + + var = variables.get(name, None) + if var is not None: + attrs = var.attrs + encoding = var.encoding + else: + attrs = {} + encoding = {} + + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + index_vars[name] = IndexVariable( + self.dim, + data, + attrs=attrs, + encoding=encoding, + fastpath=True, + ) + + return index_vars + + def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: + from .dataarray import DataArray + from .variable import Variable - def query(self, labels, method=None, tolerance=None): if method is not None or tolerance is not None: raise ValueError( "multi-index does not support ``method`` and ``tolerance``" ) new_index = None + scalar_coord_values = {} # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): - is_nested_vals = _is_nested_tuple(tuple(labels.values())) - if len(labels) == self.index.nlevels and not is_nested_vals: - indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) + label_values = {} + for k, v in labels.items(): + label_array = normalize_label(v, dtype=self.level_coords_dtype[k]) + try: + label_values[k] = as_scalar(label_array) + except ValueError: + # label should be an item not an array-like + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + + has_slice = any([isinstance(v, slice) for v in label_values.values()]) + + if len(label_values) == self.index.nlevels and not has_slice: + indexer = self.index.get_loc( + tuple(label_values[k] for k in self.index.names) + ) else: - for k, v in labels.items(): - # index should be an item (i.e. Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - f"available along coordinate {k!r} (multi-index level)" - ) indexer, new_index = self.index.get_loc_level( - tuple(labels.values()), level=tuple(labels.keys()) + tuple(label_values.values()), level=tuple(label_values.keys()) ) + scalar_coord_values.update(label_values) # GH2619. Raise a KeyError if nothing is chosen if indexer.dtype.kind == "b" and indexer.sum() == 0: raise KeyError(f"{labels} not found") @@ -376,7 +821,7 @@ def query(self, labels, method=None, tolerance=None): raise ValueError( f"invalid multi-index level names {invalid_levels}" ) - return self.query(label) + return self.sel(label) elif isinstance(label, slice): indexer = _query_slice(self.index, label, coord_name) @@ -387,94 +832,383 @@ def query(self, labels, method=None, tolerance=None): elif len(label) == self.index.nlevels: indexer = self.index.get_loc(label) else: - indexer, new_index = self.index.get_loc_level( - label, level=list(range(len(label))) - ) + levels = [self.index.names[i] for i in range(len(label))] + indexer, new_index = self.index.get_loc_level(label, level=levels) + scalar_coord_values.update({k: v for k, v in zip(levels, label)}) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - indexer, new_index = self.index.get_loc_level(label.item(), level=0) - elif label.dtype.kind == "b": - indexer = label + label_array = normalize_label(label) + if label_array.ndim == 0: + label_value = as_scalar(label_array) + indexer, new_index = self.index.get_loc_level(label_value, level=0) + scalar_coord_values[self.index.names[0]] = label_value + elif label_array.dtype.kind == "b": + indexer = label_array else: - if label.ndim > 1: + if label_array.ndim > 1: raise ValueError( "Vectorized selection is not available along " f"coordinate {coord_name!r} with a multi-index" ) - indexer = get_indexer_nd(self.index, label) + indexer = get_indexer_nd(self.index, label_array) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + # do not include label-indexer DataArray coordinates that conflict + # with the level names of this index + coords = { + k: v + for k, v in label._coords.items() + if k not in self.index.names + } + indexer = DataArray(indexer, coords=coords, dims=label.dims) + if new_index is not None: if isinstance(new_index, pd.MultiIndex): - new_index, new_vars = PandasMultiIndex.from_pandas_index( - new_index, self.dim + level_coords_dtype = { + k: self.level_coords_dtype[k] for k in new_index.names + } + new_index = self._replace( + new_index, level_coords_dtype=level_coords_dtype ) + dims_dict = {} + drop_coords = [] else: - new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim) - return indexer, (new_index, new_vars) + new_index = PandasIndex( + new_index, + new_index.name, + coord_dtype=self.level_coords_dtype[new_index.name], + ) + dims_dict = {self.dim: new_index.index.name} + drop_coords = [self.dim] + + # variable(s) attrs and encoding metadata are propagated + # when replacing the indexes in the resulting xarray object + new_vars = new_index.create_variables() + indexes = cast(Dict[Any, Index], {k: new_index for k in new_vars}) + + # add scalar variable for each dropped level + variables = new_vars + for name, val in scalar_coord_values.items(): + variables[name] = Variable([], val) + + return IndexSelResult( + {self.dim: indexer}, + indexes=indexes, + variables=variables, + drop_indexes=list(scalar_coord_values), + drop_coords=drop_coords, + rename_dims=dims_dict, + ) + + else: + return IndexSelResult({self.dim: indexer}) + + def join(self, other, how: str = "inner"): + if how == "outer": + # bug in pandas? need to reset index.name + other_index = other.index.copy() + other_index.name = None + index = self.index.union(other_index) + index.name = self.dim else: - return indexer, None + # how = "inner" + index = self.index.intersection(other.index) + level_coords_dtype = { + k: np.result_type(lvl_dtype, other.level_coords_dtype[k]) + for k, lvl_dtype in self.level_coords_dtype.items() + } + + return type(self)(index, self.dim, level_coords_dtype=level_coords_dtype) + + def rename(self, name_dict, dims_dict): + if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: + return self + + # pandas 1.3.0: could simply do `self.index.rename(names_dict)` + new_names = [name_dict.get(k, k) for k in self.index.names] + index = self.index.rename(new_names) + + new_dim = dims_dict.get(self.dim, self.dim) + new_level_coords_dtype = { + k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + } + return self._replace( + index, dim=new_dim, level_coords_dtype=new_level_coords_dtype + ) + + +def create_default_index_implicit( + dim_variable: Variable, + all_variables: Mapping | Iterable[Hashable] | None = None, +) -> tuple[PandasIndex, IndexVars]: + """Create a default index from a dimension variable. + + Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, + otherwise create a PandasIndex (note that this will become obsolete once we + depreciate implcitly passing a pandas.MultiIndex as a coordinate). -def remove_unused_levels_categories(index: pd.Index) -> pd.Index: - """ - Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ - if isinstance(index, pd.MultiIndex): - index = index.remove_unused_levels() - # if it contains CategoricalIndex, we need to remove unused categories - # manually. See https://github.com/pandas-dev/pandas/issues/30846 - if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): - levels = [] - for i, level in enumerate(index.levels): - if isinstance(level, pd.CategoricalIndex): - level = level[index.codes[i]].remove_unused_categories() - else: - level = level[index.codes[i]] - levels.append(level) - # TODO: calling from_array() reorders MultiIndex levels. It would - # be best to avoid this, if possible, e.g., by using - # MultiIndex.remove_unused_levels() (which does not reorder) on the - # part of the MultiIndex that is not categorical, or by fixing this - # upstream in pandas. - index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() - return index + if all_variables is None: + all_variables = {} + if not isinstance(all_variables, Mapping): + all_variables = {k: None for k in all_variables} + + name = dim_variable.dims[0] + array = getattr(dim_variable._data, "array", None) + index: PandasIndex + + if isinstance(array, pd.MultiIndex): + index = PandasMultiIndex(array, name) + index_vars = index.create_variables() + # check for conflict between level names and variable names + duplicate_names = [k for k in index_vars if k in all_variables and k != name] + if duplicate_names: + # dirty workaround for an edge case where both the dimension + # coordinate and the level coordinates are given for the same + # multi-index object => do not raise an error + # TODO: remove this check when removing the multi-index dimension coordinate + if len(duplicate_names) < len(index.index.names): + conflict = True + else: + duplicate_vars = [all_variables[k] for k in duplicate_names] + conflict = any( + v is None or not dim_variable.equals(v) for v in duplicate_vars + ) + + if conflict: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" + ) + else: + dim_var = {name: dim_variable} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) + return index, index_vars -class Indexes(collections.abc.Mapping): - """Immutable proxy for Dataset or DataArrary indexes.""" - __slots__ = ("_indexes",) +# generic type that represents either a pandas or an xarray index +T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex", Index, pd.Index) + + +class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): + """Immutable proxy for Dataset or DataArrary indexes. + + Keys are coordinate names and values may correspond to either pandas or + xarray indexes. + + Also provides some utility methods. + + """ + + _indexes: dict[Any, T_PandasOrXarrayIndex] + _variables: dict[Any, Variable] + + __slots__ = ( + "_indexes", + "_variables", + "_dims", + "__coord_name_id", + "__id_index", + "__id_coord_names", + ) - def __init__(self, indexes: Mapping[Any, Union[pd.Index, Index]]) -> None: - """Not for public consumption. + def __init__( + self, + indexes: dict[Any, T_PandasOrXarrayIndex], + variables: dict[Any, Variable], + ): + """Constructor not for public consumption. Parameters ---------- - indexes : Dict[Any, pandas.Index] + indexes : dict Indexes held by this object. + variables : dict + Indexed coordinate variables in this object. + """ self._indexes = indexes + self._variables = variables + + self._dims: Mapping[Hashable, int] | None = None + self.__coord_name_id: dict[Any, int] | None = None + self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None + self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None + + @property + def _coord_name_id(self) -> dict[Any, int]: + if self.__coord_name_id is None: + self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()} + return self.__coord_name_id + + @property + def _id_index(self) -> dict[int, T_PandasOrXarrayIndex]: + if self.__id_index is None: + self.__id_index = {id(idx): idx for idx in self.get_unique()} + return self.__id_index + + @property + def _id_coord_names(self) -> dict[int, tuple[Hashable, ...]]: + if self.__id_coord_names is None: + id_coord_names: Mapping[int, list[Hashable]] = defaultdict(list) + for k, v in self._coord_name_id.items(): + id_coord_names[v].append(k) + self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()} + + return self.__id_coord_names + + @property + def variables(self) -> Mapping[Hashable, Variable]: + return Frozen(self._variables) + + @property + def dims(self) -> Mapping[Hashable, int]: + from .variable import calculate_dimensions + + if self._dims is None: + self._dims = calculate_dimensions(self._variables) + + return Frozen(self._dims) + + def get_unique(self) -> list[T_PandasOrXarrayIndex]: + """Return a list of unique indexes, preserving order.""" + + unique_indexes: list[T_PandasOrXarrayIndex] = [] + seen: set[T_PandasOrXarrayIndex] = set() + + for index in self._indexes.values(): + if index not in seen: + unique_indexes.append(index) + seen.add(index) + + return unique_indexes + + def is_multi(self, key: Hashable) -> bool: + """Return True if ``key`` maps to a multi-coordinate index, + False otherwise. + """ + return len(self._id_coord_names[self._coord_name_id[key]]) > 1 + + def get_all_coords( + self, key: Hashable, errors: str = "raise" + ) -> dict[Hashable, Variable]: + """Return all coordinates having the same index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, optional + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + coords : dict + A dictionary of all coordinate variables having the same index. + + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if key not in self._indexes: + if errors == "raise": + raise ValueError(f"no index found for {key!r} coordinate") + else: + return {} + + all_coord_names = self._id_coord_names[self._coord_name_id[key]] + return {k: self._variables[k] for k in all_coord_names} + + def get_all_dims( + self, key: Hashable, errors: str = "raise" + ) -> Mapping[Hashable, int]: + """Return all dimensions shared by an index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, optional + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + dims : dict + A dictionary of all dimensions shared by an index. + + """ + from .variable import calculate_dimensions + + return calculate_dimensions(self.get_all_coords(key, errors=errors)) + + def group_by_index( + self, + ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]: + """Returns a list of unique indexes and their corresponding coordinates.""" + + index_coords = [] + + for i in self._id_index: + index = self._id_index[i] + coords = {k: self._variables[k] for k in self._id_coord_names[i]} + index_coords.append((index, coords)) + + return index_coords - def __iter__(self) -> Iterator[pd.Index]: + def to_pandas_indexes(self) -> Indexes[pd.Index]: + """Returns an immutable proxy for Dataset or DataArrary pandas indexes. + + Raises an error if this proxy contains indexes that cannot be coerced to + pandas.Index objects. + + """ + indexes: dict[Hashable, pd.Index] = {} + + for k, idx in self._indexes.items(): + if isinstance(idx, pd.Index): + indexes[k] = idx + elif isinstance(idx, Index): + indexes[k] = idx.to_pandas_index() + + return Indexes(indexes, self._variables) + + def copy_indexes( + self, deep: bool = True + ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: + """Return a new dictionary with copies of indexes, preserving + unique indexes. + + """ + new_indexes = {} + new_index_vars = {} + for idx, coords in self.group_by_index(): + new_idx = idx.copy(deep=deep) + idx_vars = idx.create_variables(coords) + new_indexes.update({k: new_idx for k in coords}) + new_index_vars.update(idx_vars) + + return new_indexes, new_index_vars + + def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]: return iter(self._indexes) - def __len__(self): + def __len__(self) -> int: return len(self._indexes) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self._indexes - def __getitem__(self, key) -> pd.Index: + def __getitem__(self, key) -> T_PandasOrXarrayIndex: return self._indexes[key] def __repr__(self): @@ -482,8 +1216,8 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, "Variable"], dims: Iterable -) -> Dict[Hashable, Index]: + coords: Mapping[Any, Variable], dims: Iterable +) -> dict[Hashable, Index]: """Default indexes for a Dataset/DataArray. Parameters @@ -498,76 +1232,167 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return {key: coords[key]._to_xindex() for key in dims if key in coords} + indexes: dict[Hashable, Index] = {} + coord_names = set(coords) + for name, var in coords.items(): + if name in dims: + index, index_vars = create_default_index_implicit(var, coords) + if set(index_vars) <= coord_names: + indexes.update({k: index for k in index_vars}) + + return indexes -def isel_variable_and_index( - name: Hashable, - variable: "Variable", - index: Index, - indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]], -) -> Tuple["Variable", Optional[Index]]: - """Index a Variable and an Index together. - If the index cannot be indexed, return None (it will be dropped). +def indexes_equal( + index: Index, + other_index: Index, + variable: Variable, + other_variable: Variable, + cache: dict[tuple[int, int], bool | None] = None, +) -> bool: + """Check if two indexes are equal, possibly with cached results. - (note: not compatible yet with xarray flexible indexes). + If the two indexes are not of the same type or they do not implement + equality, fallback to coordinate labels equality check. """ - from .variable import Variable + if cache is None: + # dummy cache + cache = {} + + key = (id(index), id(other_index)) + equal: bool | None = None + + if key not in cache: + if type(index) is type(other_index): + try: + equal = index.equals(other_index) + except NotImplementedError: + equal = None + else: + cache[key] = equal + else: + equal = None + else: + equal = cache[key] - if not indexers: - # nothing to index - return variable.copy(deep=False), index + if equal is None: + equal = variable.equals(other_variable) - if len(variable.dims) > 1: - raise NotImplementedError( - "indexing multi-dimensional variable with indexes is not supported yet" - ) + return cast(bool, equal) - new_variable = variable.isel(indexers) - if new_variable.dims != (name,): - # can't preserve a index if result has new dimensions - return new_variable, None +def indexes_all_equal( + elements: Sequence[tuple[Index, dict[Hashable, Variable]]] +) -> bool: + """Check if indexes are all equal. - # we need to compute the new index - (dim,) = variable.dims - indexer = indexers[dim] - if isinstance(indexer, Variable): - indexer = indexer.data - try: - new_index = index[indexer] - except NotImplementedError: - new_index = None + If they are not of the same type or they do not implement this check, check + if their coordinate variables are all equal instead. - return new_variable, new_index + """ + def check_variables(): + variables = [e[1] for e in elements] + return any( + not variables[0][k].equals(other_vars[k]) + for other_vars in variables[1:] + for k in variables[0] + ) -def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: - """Roll an pandas.Index.""" - pd_index = index.to_pandas_index() - count %= pd_index.shape[0] - if count != 0: - new_idx = pd_index[-count:].append(pd_index[:-count]) + indexes = [e[0] for e in elements] + same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) + if same_type: + try: + not_equal = any( + not indexes[0].equals(other_idx) for other_idx in indexes[1:] + ) + except NotImplementedError: + not_equal = check_variables() else: - new_idx = pd_index[:] - return PandasIndex(new_idx, index.dim) + not_equal = check_variables() + + return not not_equal + + +def _apply_indexes( + indexes: Indexes[Index], + args: Mapping[Any, Any], + func: str, +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()} + new_index_variables: dict[Hashable, Variable] = {} + + for index, index_vars in indexes.group_by_index(): + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in args.items() if k in index_dims} + if index_args: + new_index = getattr(index, func)(index_args) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + return new_indexes, new_index_variables -def propagate_indexes( - indexes: Optional[Dict[Hashable, Index]], exclude: Optional[Any] = None -) -> Optional[Dict[Hashable, Index]]: - """Creates new indexes dict from existing dict optionally excluding some dimensions.""" - if exclude is None: - exclude = () - if is_scalar(exclude): - exclude = (exclude,) +def isel_indexes( + indexes: Indexes[Index], + indexers: Mapping[Any, Any], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_indexes(indexes, indexers, "isel") - if indexes is not None: - new_indexes = {k: v for k, v in indexes.items() if k not in exclude} - else: - new_indexes = None # type: ignore[assignment] - return new_indexes +def roll_indexes( + indexes: Indexes[Index], + shifts: Mapping[Any, int], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_indexes(indexes, shifts, "roll") + + +def filter_indexes_from_coords( + indexes: Mapping[Any, Index], + filtered_coord_names: set, +) -> dict[Hashable, Index]: + """Filter index items given a (sub)set of coordinate names. + + Drop all multi-coordinate related index items for any key missing in the set + of coordinate names. + + """ + filtered_indexes: dict[Any, Index] = dict(**indexes) + + index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) + for name, idx in indexes.items(): + index_coord_names[id(idx)].add(name) + + for idx_coord_names in index_coord_names.values(): + if not idx_coord_names <= filtered_coord_names: + for k in idx_coord_names: + del filtered_indexes[k] + + return filtered_indexes + + +def assert_no_index_corrupted( + indexes: Indexes[Index], + coord_names: set[Hashable], +) -> None: + """Assert removing coordinates will not corrupt indexes.""" + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of coordinate names to remove + for index, index_coords in indexes.group_by_index(): + common_names = set(index_coords) & coord_names + if common_names and len(common_names) != len(index_coords): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coords) + raise ValueError( + f"cannot remove coordinate(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{index}" + ) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 17d026baa59..c797e6652de 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1,10 +1,23 @@ import enum import functools import operator -from collections import defaultdict +from collections import Counter, defaultdict from contextlib import suppress +from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import numpy as np import pandas as pd @@ -19,7 +32,176 @@ is_duck_dask_array, sparse_array_type, ) -from .utils import maybe_cast_to_coords_dtype +from .types import T_Xarray +from .utils import either_dict_or_kwargs, get_valid_numpy_dtype + +if TYPE_CHECKING: + from .indexes import Index + from .variable import Variable + + +@dataclass +class IndexSelResult: + """Index query results. + + Attributes + ---------- + dim_indexers: dict + A dictionary where keys are array dimensions and values are + location-based indexers. + indexes: dict, optional + New indexes to replace in the resulting DataArray or Dataset. + variables : dict, optional + New variables to replace in the resulting DataArray or Dataset. + drop_coords : list, optional + Coordinate(s) to drop in the resulting DataArray or Dataset. + drop_indexes : list, optional + Index(es) to drop in the resulting DataArray or Dataset. + rename_dims : dict, optional + A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to + rename in the resulting DataArray or Dataset. + + """ + + dim_indexers: Dict[Any, Any] + indexes: Dict[Any, "Index"] = field(default_factory=dict) + variables: Dict[Any, "Variable"] = field(default_factory=dict) + drop_coords: List[Hashable] = field(default_factory=list) + drop_indexes: List[Hashable] = field(default_factory=list) + rename_dims: Dict[Any, Hashable] = field(default_factory=dict) + + def as_tuple(self): + """Unlike ``dataclasses.astuple``, return a shallow copy. + + See https://stackoverflow.com/a/51802661 + + """ + return ( + self.dim_indexers, + self.indexes, + self.variables, + self.drop_coords, + self.drop_indexes, + self.rename_dims, + ) + + +def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult: + all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) + duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} + + if duplicate_dims: + # TODO: this message is not right when combining indexe(s) queries with + # location-based indexing on a dimension with no dimension-coordinate (failback) + fmt_dims = [ + f"{dim!r}: {count} indexes involved" + for dim, count in duplicate_dims.items() + ] + raise ValueError( + "Xarray does not support label-based selection with more than one index " + "over the following dimension(s):\n" + + "\n".join(fmt_dims) + + "\nSuggestion: use a multi-index for each of those dimension(s)." + ) + + dim_indexers = {} + indexes = {} + variables = {} + drop_coords = [] + drop_indexes = [] + rename_dims = {} + + for res in results: + dim_indexers.update(res.dim_indexers) + indexes.update(res.indexes) + variables.update(res.variables) + drop_coords += res.drop_coords + drop_indexes += res.drop_indexes + rename_dims.update(res.rename_dims) + + return IndexSelResult( + dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims + ) + + +def group_indexers_by_index( + obj: T_Xarray, + indexers: Mapping[Any, Any], + options: Mapping[str, Any], +) -> List[Tuple["Index", Dict[Any, Any]]]: + """Returns a list of unique indexes and their corresponding indexers.""" + unique_indexes = {} + grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict) + + for key, label in indexers.items(): + index: "Index" = obj.xindexes.get(key, None) + + if index is not None: + index_id = id(index) + unique_indexes[index_id] = index + grouped_indexers[index_id][key] = label + elif key in obj.coords: + raise KeyError(f"no index found for coordinate {key!r}") + elif key not in obj.dims: + raise KeyError(f"{key!r} is not a valid dimension or coordinate") + elif len(options): + raise ValueError( + f"cannot supply selection options {options!r} for dimension {key!r}" + "that has no associated coordinate or index" + ) + else: + # key is a dimension without a "dimension-coordinate" + # failback to location-based selection + # TODO: depreciate this implicit behavior and suggest using isel instead? + unique_indexes[None] = None + grouped_indexers[None][key] = label + + return [(unique_indexes[k], grouped_indexers[k]) for k in unique_indexes] + + +def map_index_queries( + obj: T_Xarray, + indexers: Mapping[Any, Any], + method=None, + tolerance=None, + **indexers_kwargs: Any, +) -> IndexSelResult: + """Execute index queries from a DataArray / Dataset and label-based indexers + and return the (merged) query results. + + """ + from .dataarray import DataArray + + # TODO benbovy - flexible indexes: remove when custom index options are available + if method is None and tolerance is None: + options = {} + else: + options = {"method": method, "tolerance": tolerance} + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") + grouped_indexers = group_indexers_by_index(obj, indexers, options) + + results = [] + for index, labels in grouped_indexers: + if index is None: + # forward dimension indexers with no index/coordinate + results.append(IndexSelResult(labels)) + else: + results.append(index.sel(labels, **options)) # type: ignore[call-arg] + + merged = merge_sel_results(results) + + # drop dimension coordinates found in dimension indexers + # (also drop multi-index if any) + # (.sel() already ensures alignment) + for k, v in merged.dim_indexers.items(): + if isinstance(v, DataArray): + if k in v._indexes: + v = v.reset_index(k) + drop_coords = [name for name in v._coords if name in merged.dim_indexers] + merged.dim_indexers[k] = v.drop_vars(drop_coords) + + return merged def expanded_indexer(key, ndim): @@ -56,80 +238,6 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): - # TODO: benbovy - flexible indexes: indexers are still grouped by dimension - # - Make xarray.Index hashable so that it can be used as key in a mapping? - indexes = {} - grouped_indexers = defaultdict(dict) - - # TODO: data_obj.xindexes should eventually return the PandasIndex instance - # for each multi-index levels - xindexes = dict(data_obj.xindexes) - for level, dim in data_obj._level_coords.items(): - xindexes[level] = xindexes[dim] - - for key, label in indexers.items(): - try: - index = xindexes[key] - coord = data_obj.coords[key] - dim = coord.dims[0] - if dim not in indexes: - indexes[dim] = index - - label = maybe_cast_to_coords_dtype(label, coord.dtype) - grouped_indexers[dim][key] = label - - except KeyError: - if key in data_obj.coords: - raise KeyError(f"no index found for coordinate {key}") - elif key not in data_obj.dims: - raise KeyError(f"{key} is not a valid dimension or coordinate") - # key is a dimension without coordinate: we'll reuse the provided labels - elif method is not None or tolerance is not None: - raise ValueError( - "cannot supply ``method`` or ``tolerance`` " - "when the indexed dimension does not have " - "an associated coordinate." - ) - grouped_indexers[None][key] = label - - return indexes, grouped_indexers - - -def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): - """Given an xarray data object and label based indexers, return a mapping - of equivalent location based indexers. Also return a mapping of updated - pandas index objects (in case of multi-index level drop). - """ - if method is not None and not isinstance(method, str): - raise TypeError("``method`` must be a string") - - pos_indexers = {} - new_indexes = {} - - indexes, grouped_indexers = group_indexers_by_index( - data_obj, indexers, method, tolerance - ) - - forward_pos_indexers = grouped_indexers.pop(None, None) - if forward_pos_indexers is not None: - for dim, label in forward_pos_indexers.items(): - pos_indexers[dim] = label - - for dim, index in indexes.items(): - labels = grouped_indexers[dim] - idxr, new_idx = index.query(labels, method=method, tolerance=tolerance) - pos_indexers[dim] = idxr - if new_idx is not None: - new_indexes[dim] = new_idx - - # TODO: benbovy - flexible indexes: support the following cases: - # - an index query returns positional indexers over multiple dimensions - # - check/combine positional indexers returned by multiple indexes over the same dimension - - return pos_indexers, new_indexes - - def _normalize_slice(sl, size): """Ensure that given slice only contains positive start and stop values (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])""" @@ -1272,18 +1380,9 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None): self.array = utils.safe_cast_to_index(array) if dtype is None: - if isinstance(array, pd.PeriodIndex): - dtype_ = np.dtype("O") - elif hasattr(array, "categories"): - # category isn't a real numpy dtype - dtype_ = array.categories.dtype - elif not utils.is_valid_numpy_dtype(array.dtype): - dtype_ = np.dtype("O") - else: - dtype_ = array.dtype + self._dtype = get_valid_numpy_dtype(array) else: - dtype_ = np.dtype(dtype) # type: ignore[assignment] - self._dtype = dtype_ + self._dtype = np.dtype(dtype) # type: ignore[assignment] @property def dtype(self) -> np.dtype: @@ -1303,6 +1402,26 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: def shape(self) -> Tuple[int]: return (len(self.array),) + def _convert_scalar(self, item): + if item is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + item = np.datetime64("NaT", "ns") + elif isinstance(item, timedelta): + item = np.timedelta64(getattr(item, "value", item), "ns") + elif isinstance(item, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + item = np.asarray(item.to_datetime64()) + elif self.dtype != object: + item = np.asarray(item, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. + return utils.to_0d_array(item) + def __getitem__( self, indexer ) -> Union[ @@ -1319,34 +1438,14 @@ def __getitem__( (key,) = key if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(self.array.values)[indexer] + return NumpyIndexingAdapter(np.asarray(self))[indexer] result = self.array[key] if isinstance(result, pd.Index): - result = type(self)(result, dtype=self.dtype) + return type(self)(result, dtype=self.dtype) else: - # result is a scalar - if result is pd.NaT: - # work around the impossibility of casting NaT with asarray - # note: it probably would be better in general to return - # pd.Timestamp rather np.than datetime64 but this is easier - # (for now) - result = np.datetime64("NaT", "ns") - elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, "value", result), "ns") - elif isinstance(result, pd.Timestamp): - # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 - # numpy fails to convert pd.Timestamp to np.datetime64[ns] - result = np.asarray(result.to_datetime64()) - elif self.dtype != object: - result = np.asarray(result, dtype=self.dtype) - - # as for numpy.ndarray indexing, we always want the result to be - # a NumPy array. - result = utils.to_0d_array(result) - - return result + return self._convert_scalar(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1382,11 +1481,9 @@ def __init__( array: pd.MultiIndex, dtype: DTypeLike = None, level: Optional[str] = None, - adapter: Optional[PandasIndexingAdapter] = None, ): super().__init__(array, dtype) self.level = level - self.adapter = adapter def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if self.level is not None: @@ -1394,16 +1491,47 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: else: return super().__array__(dtype) - @functools.lru_cache(1) + def _convert_scalar(self, item): + if isinstance(item, tuple) and self.level is not None: + idx = tuple(self.array.names).index(self.level) + item = item[idx] + return super()._convert_scalar(item) + def __getitem__(self, indexer): - if self.adapter is None: - return super().__getitem__(indexer) - else: - return self.adapter.__getitem__(indexer) + result = super().__getitem__(indexer) + if isinstance(result, type(self)): + result.level = self.level + + return result def __repr__(self) -> str: if self.level is None: return super().__repr__() else: - props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + props = ( + f"(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + ) return f"{type(self).__name__}{props}" + + def _repr_inline_(self, max_width) -> str: + # special implementation to speed-up the repr for big multi-indexes + if self.level is None: + return "MultiIndex" + else: + from .formatting import format_array_flat + + if self.size > 100 and max_width < self.size: + n_values = max_width + indices = np.concatenate( + [np.arange(0, n_values), np.arange(-n_values, 0)] + ) + subset = self[OuterIndexer((indices,))] + else: + subset = self + + return format_array_flat(np.asarray(subset), max_width) + + def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter": + # see PandasIndexingAdapter.copy + array = self.array.copy(deep=True) if deep else self.array + return type(self)(array, self._dtype, self.level) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index e5407ae79c3..b428d4ae958 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from typing import ( TYPE_CHECKING, AbstractSet, @@ -19,9 +20,15 @@ from . import dtypes from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .indexes import Index, PandasIndex +from .indexes import ( + Index, + Indexes, + create_default_index_implicit, + filter_indexes_from_coords, + indexes_equal, +) from .utils import Frozen, compat_dict_union, dict_equiv, equivalent -from .variable import Variable, as_variable, assert_unique_multiindex_level_names +from .variable import Variable, as_variable, calculate_dimensions if TYPE_CHECKING: from .coordinates import Coordinates @@ -32,9 +39,9 @@ ArrayLike = Any VariableLike = Union[ ArrayLike, - Tuple[DimsLike, ArrayLike], - Tuple[DimsLike, ArrayLike, Mapping], - Tuple[DimsLike, ArrayLike, Mapping, Mapping], + tuple[DimsLike, ArrayLike], + tuple[DimsLike, ArrayLike, Mapping], + tuple[DimsLike, ArrayLike, Mapping, Mapping], ] XarrayValue = Union[DataArray, Variable, VariableLike] DatasetLike = Union[Dataset, Mapping[Any, XarrayValue]] @@ -165,11 +172,44 @@ def _assert_compat_valid(compat): MergeElement = Tuple[Variable, Optional[Index]] +def _assert_prioritized_valid( + grouped: dict[Hashable, list[MergeElement]], + prioritized: Mapping[Any, MergeElement], +) -> None: + """Make sure that elements given in prioritized will not corrupt any + index given in grouped. + """ + prioritized_names = set(prioritized) + grouped_by_index: dict[int, list[Hashable]] = defaultdict(list) + indexes: dict[int, Index] = {} + + for name, elements_list in grouped.items(): + for (_, index) in elements_list: + if index is not None: + grouped_by_index[id(index)].append(name) + indexes[id(index)] = index + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of names given in prioritized + for index_id, index_coord_names in grouped_by_index.items(): + index_names = set(index_coord_names) + common_names = index_names & prioritized_names + if common_names and len(common_names) != len(index_names): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coord_names) + raise ValueError( + f"cannot set or update variable(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{indexes[index_id]!r}" + ) + + def merge_collected( grouped: dict[Hashable, list[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", - combine_attrs="override", + combine_attrs: str | None = "override", + equals: dict[Hashable, bool] = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -179,6 +219,8 @@ def merge_collected( prioritized : mapping compat : str Type of equality check to use when checking for conflicts. + equals : mapping, optional + corresponding to result of compat test Returns ------- @@ -188,11 +230,15 @@ def merge_collected( """ if prioritized is None: prioritized = {} + if equals is None: + equals = {} _assert_compat_valid(compat) + _assert_prioritized_valid(grouped, prioritized) merged_vars: dict[Hashable, Variable] = {} merged_indexes: dict[Hashable, Index] = {} + index_cmp_cache: dict[tuple[int, int], bool | None] = {} for name, elements_list in grouped.items(): if name in prioritized: @@ -206,17 +252,19 @@ def merge_collected( for variable, index in elements_list if index is not None ] - if indexed_elements: # TODO(shoyer): consider adjusting this logic. Are we really # OK throwing away variable without an index in favor of # indexed variables, without even checking if values match? variable, index = indexed_elements[0] - for _, other_index in indexed_elements[1:]: - if not index.equals(other_index): + for other_var, other_index in indexed_elements[1:]: + if not indexes_equal( + index, other_index, variable, other_var, index_cmp_cache + ): raise MergeError( - f"conflicting values for index {name!r} on objects to be " - f"combined:\nfirst value: {index!r}\nsecond value: {other_index!r}" + f"conflicting values/indexes on objects to be combined fo coordinate {name!r}\n" + f"first index: {index!r}\nsecond index: {other_index!r}\n" + f"first variable: {variable!r}\nsecond variable: {other_var!r}\n" ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: @@ -234,7 +282,9 @@ def merge_collected( else: variables = [variable for variable, _ in elements_list] try: - merged_vars[name] = unique_variable(name, variables, compat) + merged_vars[name] = unique_variable( + name, variables, compat, equals.get(name, None) + ) except MergeError: if compat != "minimal": # we need more than "minimal" compatibility (for which @@ -251,6 +301,7 @@ def merge_collected( def collect_variables_and_indexes( list_of_mappings: list[DatasetLike], + indexes: Mapping[Any, Any] | None = None, ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. @@ -260,15 +311,21 @@ def collect_variables_and_indexes( - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in an xarray.Variable - or an xarray.DataArray + + If a mapping of indexes is given, those indexes are assigned to all variables + with a matching key/name. + """ from .dataarray import DataArray from .dataset import Dataset - grouped: dict[Hashable, list[tuple[Variable, Index | None]]] = {} + if indexes is None: + indexes = {} + + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) def append(name, variable, index): - values = grouped.setdefault(name, []) - values.append((variable, index)) + grouped[name].append((variable, index)) def append_all(variables, indexes): for name, variable in variables.items(): @@ -276,27 +333,26 @@ def append_all(variables, indexes): for mapping in list_of_mappings: if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping.xindexes) + append_all(mapping.variables, mapping._indexes) continue for name, variable in mapping.items(): if isinstance(variable, DataArray): coords = variable._coords.copy() # use private API for speed - indexes = dict(variable.xindexes) + indexes = dict(variable._indexes) # explicitly overwritten variables should take precedence coords.pop(name, None) indexes.pop(name, None) append_all(coords, indexes) variable = as_variable(variable, name=name) - - if variable.dims == (name,): - idx_variable = variable.to_index_variable() - index = variable._to_xindex() - append(name, idx_variable, index) + if name in indexes: + append(name, variable, indexes[name]) + elif variable.dims == (name,): + idx, idx_vars = create_default_index_implicit(variable) + append_all(idx_vars, {k: idx for k in idx_vars}) else: - index = None - append(name, variable, index) + append(name, variable, None) return grouped @@ -305,14 +361,14 @@ def collect_from_coordinates( list_of_coords: list[Coordinates], ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes to be merged from Coordinate objects.""" - grouped: dict[Hashable, list[tuple[Variable, Index | None]]] = {} + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) for coords in list_of_coords: variables = coords.variables indexes = coords.xindexes for name, variable in variables.items(): - value = grouped.setdefault(name, []) - value.append((variable, indexes.get(name))) + grouped[name].append((variable, indexes.get(name))) + return grouped @@ -342,7 +398,14 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) + # TODO: indexes should probably be filtered in collected elements + # before merging them + merged_coords, merged_indexes = merge_collected( + filtered, prioritized, combine_attrs=combine_attrs + ) + merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords)) + + return merged_coords, merged_indexes def determine_coords( @@ -471,26 +534,56 @@ def merge_coords( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected(collected, prioritized, compat=compat) - assert_unique_multiindex_level_names(variables) return variables, out_indexes -def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): +def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" - objects = [data, coords] + indexes, coords = _create_indexes_from_coords(coords, data_vars) + objects = [data_vars, coords] explicit_coords = coords.keys() - indexes = dict(_extract_indexes_from_coords(coords)) return merge_core( - objects, compat, join, explicit_coords=explicit_coords, indexes=indexes + objects, + compat, + join, + explicit_coords=explicit_coords, + indexes=Indexes(indexes, coords), ) -def _extract_indexes_from_coords(coords): - """Yields the name & index of valid indexes from a mapping of coords""" - for name, variable in coords.items(): - variable = as_variable(variable, name=name) +def _create_indexes_from_coords(coords, data_vars=None): + """Maybe create default indexes from a mapping of coordinates. + + Return those indexes and updated coordinates. + """ + all_variables = dict(coords) + if data_vars is not None: + all_variables.update(data_vars) + + indexes = {} + updated_coords = {} + + # this is needed for backward compatibility: when a pandas multi-index + # is given as data variable, it is promoted as index / level coordinates + # TODO: depreciate this implicit behavior + index_vars = { + k: v + for k, v in all_variables.items() + if k in coords or isinstance(v, pd.MultiIndex) + } + + for name, obj in index_vars.items(): + variable = as_variable(obj, name=name) + if variable.dims == (name,): - yield name, variable._to_xindex() + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + updated_coords.update(idx_vars) + all_variables.update(idx_vars) + else: + updated_coords[name] = obj + + return indexes, updated_coords def assert_valid_explicit_coords(variables, dims, explicit_coords): @@ -566,7 +659,7 @@ class _MergeResult(NamedTuple): variables: dict[Hashable, Variable] coord_names: set[Hashable] dims: dict[Hashable, int] - indexes: dict[Hashable, pd.Index] + indexes: dict[Hashable, Index] attrs: dict[Hashable, Any] @@ -621,7 +714,7 @@ def merge_core( MergeError if the merge cannot be done successfully. """ from .dataarray import DataArray - from .dataset import Dataset, calculate_dimensions + from .dataset import Dataset _assert_compat_valid(compat) @@ -629,13 +722,11 @@ def merge_core( aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - collected = collect_variables_and_indexes(aligned) - + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs ) - assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) @@ -870,7 +961,7 @@ def merge( >>> xr.merge([x, y, z], join="exact") Traceback (most recent call last): ... - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' where ... Raises ------ @@ -976,18 +1067,9 @@ def dataset_update_method(dataset: Dataset, other: CoercibleMapping) -> _MergeRe if coord_names: other[key] = value.drop_vars(coord_names) - # use ds.coords and not ds.indexes, else str coords are cast to object - # TODO: benbovy - flexible indexes: make it work with any xarray index - indexes = {} - for key, index in dataset.xindexes.items(): - if isinstance(index, PandasIndex): - indexes[key] = dataset.coords[key] - else: - indexes[key] = index - return merge_core( [dataset, other], priority_arg=1, - indexes=indexes, # type: ignore + indexes=dataset.xindexes, combine_attrs="override", ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index c1776145e21..3d33631bebd 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -321,12 +321,10 @@ def interp_na( if not is_scalar(max_gap): raise ValueError("max_gap must be a scalar.") - # TODO: benbovy - flexible indexes: update when CFTimeIndex (and DatetimeIndex?) - # has its own class inheriting from xarray.Index if ( - dim in self.xindexes + dim in self._indexes and isinstance( - self.xindexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + self._indexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) ) and use_coordinate ): diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3f6bb34a36e..fd1f3f9e999 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -292,7 +292,7 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.xindexes.items(): + for name, index in result._indexes.items(): if name in expected["shapes"]: if result.sizes[name] != expected["shapes"][name]: raise ValueError( @@ -359,27 +359,27 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0].xindexes) + input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.xindexes) + input_indexes.update(arg._indexes) if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template.xindexes) + template_indexes = set(template._indexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.xindexes[k] for k in new_indexes}) + indexes.update({k: template._indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template.xindexes) + indexes = dict(template._indexes) if isinstance(template, DataArray): output_chunks = dict( zip(template.dims, template.chunks) # type: ignore[arg-type] @@ -558,7 +558,7 @@ def subset_dataset_to_block( attrs=template.attrs, ) - for index in result.xindexes: + for index in result._indexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding @@ -568,7 +568,7 @@ def subset_dataset_to_block( for dim in dims: if dim in output_chunks: var_chunks.append(output_chunks[dim]) - elif dim in result.xindexes: + elif dim in result._indexes: var_chunks.append((result.sizes[dim],)) elif dim in template.dims: # new unindexed dimension diff --git a/xarray/core/types.py b/xarray/core/types.py index 9f0f9eee54c..3f368501b25 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -9,6 +9,7 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy + from .indexes import Index from .npcompat import ArrayLike from .variable import Variable @@ -21,6 +22,7 @@ T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Index = TypeVar("T_Index", bound="Index") # Maybe we rename this to T_Data or something less Fortran-y? T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") diff --git a/xarray/core/utils.py b/xarray/core/utils.py index da3a196a621..a0f5bfdcf27 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -22,7 +22,6 @@ Mapping, MutableMapping, MutableSet, - Sequence, TypeVar, cast, ) @@ -69,10 +68,24 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: return index -def maybe_cast_to_coords_dtype(label, coords_dtype): - if coords_dtype.kind == "f" and not isinstance(label, slice): - label = np.asarray(label, dtype=coords_dtype) - return label +def get_valid_numpy_dtype(array: np.ndarray | pd.Index): + """Return a numpy compatible dtype from either + a numpy array or a pandas.Index. + + Used for wrapping a pandas.Index as an xarray,Variable. + + """ + if isinstance(array, pd.PeriodIndex): + dtype = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype = array.categories.dtype # type: ignore[union-attr] + elif not is_valid_numpy_dtype(array.dtype): + dtype = np.dtype("O") + else: + dtype = array.dtype + + return dtype def maybe_coerce_to_str(index, original_coords): @@ -105,9 +118,14 @@ def safe_cast_to_index(array: Any) -> pd.Index: if isinstance(array, pd.Index): index = array elif hasattr(array, "to_index"): + # xarray Variable index = array.to_index() elif hasattr(array, "to_pandas_index"): + # xarray Index index = array.to_pandas_index() + elif hasattr(array, "array") and isinstance(array.array, pd.Index): + # xarray PandasIndexingAdapter + index = array.array else: kwargs = {} if hasattr(array, "dtype") and array.dtype.kind == "O": @@ -116,33 +134,6 @@ def safe_cast_to_index(array: Any) -> pd.Index: return _maybe_cast_to_cftimeindex(index) -def multiindex_from_product_levels( - levels: Sequence[pd.Index], names: Sequence[str] = None -) -> pd.MultiIndex: - """Creating a MultiIndex from a product without refactorizing levels. - - Keeping levels the same gives back the original labels when we unstack. - - Parameters - ---------- - levels : sequence of pd.Index - Values for each MultiIndex level. - names : sequence of str, optional - Names for each level. - - Returns - ------- - pandas.MultiIndex - """ - if any(not isinstance(lev, pd.Index) for lev in levels): - raise TypeError("levels must be a list of pd.Index objects") - - split_labels, levels = zip(*[lev.factorize() for lev in levels]) - labels_mesh = np.meshgrid(*split_labels, indexing="ij") - labels = [x.ravel() for x in labels_mesh] - return pd.MultiIndex(levels, labels, sortorder=0, names=names) - - def maybe_wrap_array(original, new_array): """Wrap a transformed array with __array_wrap__ if it can be done safely. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c8d46d20d46..a21cf8c2d97 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -4,7 +4,6 @@ import itertools import numbers import warnings -from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING, Any, Hashable, Mapping, Sequence @@ -17,7 +16,6 @@ from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexes import PandasIndex, PandasMultiIndex from .indexing import ( BasicIndexer, OuterIndexer, @@ -531,18 +529,6 @@ def to_index_variable(self): to_coord = utils.alias(to_index_variable, "to_coord") - def _to_xindex(self): - # temporary function used internally as a replacement of to_index() - # returns an xarray Index instance instead of a pd.Index instance - index_var = self.to_index_variable() - index = index_var.to_index() - dim = index_var.dims[0] - - if isinstance(index, pd.MultiIndex): - return PandasMultiIndex(index, dim) - else: - return PandasIndex(index, dim) - def to_index(self): """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() @@ -3007,37 +2993,27 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def assert_unique_multiindex_level_names(variables): - """Check for uniqueness of MultiIndex level names in all given - variables. +def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: + """Calculate the dimensions corresponding to a set of variables. - Not public API. Used for checking consistency of DataArray and Dataset - objects. + Returns dictionary mapping from dimension names to sizes. Raises ValueError + if any of the dimension sizes conflict. """ - level_names = defaultdict(list) - all_level_names = set() - for var_name, var in variables.items(): - if isinstance(var._data, PandasIndexingAdapter): - idx_level_names = var.to_index_variable().level_names - if idx_level_names is not None: - for n in idx_level_names: - level_names[n].append(f"{n!r} ({var_name})") - if idx_level_names: - all_level_names.update(idx_level_names) - - for k, v in level_names.items(): - if k in variables: - v.append(f"({k})") - - duplicate_names = [v for v in level_names.values() if len(v) > 1] - if duplicate_names: - conflict_str = "\n".join(", ".join(v) for v in duplicate_names) - raise ValueError(f"conflicting MultiIndex level name(s):\n{conflict_str}") - # Check confliction between level names and dimensions GH:2299 - for k, v in variables.items(): - for d in v.dims: - if d in all_level_names: + dims: dict[Hashable, int] = {} + last_used = {} + scalar_vars = {k for k, v in variables.items() if not v.dims} + for k, var in variables.items(): + for dim, size in zip(var.dims, var.shape): + if dim in scalar_vars: + raise ValueError( + f"dimension {dim!r} already exists as a scalar variable" + ) + if dim not in dims: + dims[dim] = size + last_used[dim] = k + elif dims[dim] != size: raise ValueError( - "conflicting level / dimension names. {} " - "already exists as a level name.".format(d) + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" ) + return dims diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f09d1eb1853..d942f6656ba 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from ..core.indexes import PandasMultiIndex from ..core.options import OPTIONS from ..core.pycompat import DuckArrayModule from ..core.utils import is_scalar @@ -383,11 +384,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): _assert_valid_xy(darray, x, "x") _assert_valid_xy(darray, y, "y") - if ( - all(k in darray._level_coords for k in (x, y)) - and darray._level_coords[x] == darray._level_coords[y] - ): - raise ValueError("x and y cannot be levels of the same MultiIndex") + if darray._indexes.get(x, 1) is darray._indexes.get(y, 2): + if isinstance(darray._indexes[x], PandasMultiIndex): + raise ValueError("x and y cannot be levels of the same MultiIndex") return x, y @@ -398,11 +397,13 @@ def _assert_valid_xy(darray, xy, name): """ # MultiIndex cannot be plotted; no point in allowing them here - multiindex = {darray._level_coords[lc] for lc in darray._level_coords} + multiindex_dims = { + idx.dim + for idx in darray.xindexes.get_unique() + if isinstance(idx, PandasMultiIndex) + } - valid_xy = ( - set(darray.dims) | set(darray.coords) | set(darray._level_coords) - ) - multiindex + valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims if xy not in valid_xy: valid_xy_str = "', '".join(sorted(valid_xy)) diff --git a/xarray/testing.py b/xarray/testing.py index 4369b828daf..0df34a60e73 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -4,11 +4,12 @@ from typing import Hashable, Set, Union import numpy as np +import pandas as pd from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import Index, default_indexes +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -251,7 +252,9 @@ def assert_chunks_equal(a, b): assert left.chunks == right.chunks -def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): +def _assert_indexes_invariants_checks( + indexes, possible_coord_variables, dims, check_default=True +): assert isinstance(indexes, dict), indexes assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() @@ -262,11 +265,42 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): } assert indexes.keys() <= index_vars, (set(indexes), index_vars) - # Note: when we support non-default indexes, these checks should be opt-in - # only! - defaults = default_indexes(possible_coord_variables, dims) - assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) - assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) + # check pandas index wrappers vs. coordinate data adapters + for k, index in indexes.items(): + if isinstance(index, PandasIndex): + pd_index = index.index + var = possible_coord_variables[k] + assert (index.dim,) == var.dims, (pd_index, var) + if k == index.dim: + # skip multi-index levels here (checked below) + assert index.coord_dtype == var.dtype, (index.coord_dtype, var.dtype) + assert isinstance(var._data.array, pd.Index), var._data.array + # TODO: check identity instead of equality? + assert pd_index.equals(var._data.array), (pd_index, var) + if isinstance(index, PandasMultiIndex): + pd_index = index.index + for name in index.index.names: + assert name in possible_coord_variables, (pd_index, index_vars) + var = possible_coord_variables[name] + assert (index.dim,) == var.dims, (pd_index, var) + assert index.level_coords_dtype[name] == var.dtype, ( + index.level_coords_dtype[name], + var.dtype, + ) + assert isinstance(var._data.array, pd.MultiIndex), var._data.array + assert pd_index.equals(var._data.array), (pd_index, var) + # check all all levels are in `indexes` + assert name in indexes, (name, set(indexes)) + # index identity is used to find unique indexes in `indexes` + assert index is indexes[name], (pd_index, indexes[name].index) + + if check_default: + defaults = default_indexes(possible_coord_variables, dims) + assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + assert all(v.equals(defaults[k]) for k, v in indexes.items()), ( + indexes, + defaults, + ) def _assert_variable_invariants(var: Variable, name: Hashable = None): @@ -285,7 +319,7 @@ def _assert_variable_invariants(var: Variable, name: Hashable = None): assert isinstance(var._attrs, (type(None), dict)), name_or_empty + (var._attrs,) -def _assert_dataarray_invariants(da: DataArray): +def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._variable, Variable), da._variable _assert_variable_invariants(da._variable) @@ -302,10 +336,12 @@ def _assert_dataarray_invariants(da: DataArray): _assert_variable_invariants(v, k) if da._indexes is not None: - _assert_indexes_invariants_checks(da._indexes, da._coords, da.dims) + _assert_indexes_invariants_checks( + da._indexes, da._coords, da.dims, check_default=check_default_indexes + ) -def _assert_dataset_invariants(ds: Dataset): +def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): assert isinstance(ds._variables, dict), type(ds._variables) assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables for k, v in ds._variables.items(): @@ -336,13 +372,17 @@ def _assert_dataset_invariants(ds: Dataset): } if ds._indexes is not None: - _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims) + _assert_indexes_invariants_checks( + ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes + ) assert isinstance(ds._encoding, (type(None), dict)) assert isinstance(ds._attrs, (type(None), dict)) -def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]): +def _assert_internal_invariants( + xarray_obj: Union[DataArray, Dataset, Variable], check_default_indexes: bool +): """Validate that an xarray object satisfies its own internal invariants. This exists for the benefit of xarray's own test suite, but may be useful @@ -352,9 +392,13 @@ def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]) if isinstance(xarray_obj, Variable): _assert_variable_invariants(xarray_obj) elif isinstance(xarray_obj, DataArray): - _assert_dataarray_invariants(xarray_obj) + _assert_dataarray_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) elif isinstance(xarray_obj, Dataset): - _assert_dataset_invariants(xarray_obj) + _assert_dataset_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) else: raise TypeError( "{} is not a supported type for xarray invariant checks".format( diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 00fec07f793..7872fec2e62 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -177,25 +177,25 @@ def assert_no_warnings(): # invariants -def assert_equal(a, b): +def assert_equal(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_equal(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_identical(a, b): +def assert_identical(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_identical(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_allclose(a, b, **kwargs): +def assert_allclose(a, b, check_default_indexes=True, **kwargs): __tracebackhide__ = True xarray.testing.assert_allclose(a, b, **kwargs) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) def create_test_data(seed=None, add_attrs=True): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e0bc0b10437..a66ec4f174c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3290,7 +3290,9 @@ def test_open_mfdataset_exact_join_raises_error(self, combine, concat_dim, opt): with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): if combine == "by_coords": files.reverse() - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align objects.*join.*exact.*" + ): open_mfdataset( files, data_vars=opt, diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 6ba0c6a9be2..7e50e0d8b53 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -392,7 +392,7 @@ def test_combine_nested_join(self, join, expected): def test_combine_nested_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact"): combine_nested(objs, concat_dim="x", join="exact") def test_empty_input(self): @@ -747,7 +747,7 @@ def test_combine_coords_join(self, join, expected): def test_combine_coords_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): combine_nested(objs, concat_dim="x", join="exact") @pytest.mark.parametrize( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index dac3c17b1f1..6a86738ab2f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -960,7 +960,7 @@ def test_dataset_join() -> None: ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) # by default, cannot have different labels - with pytest.raises(ValueError, match=r"indexes .* are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): apply_ufunc(operator.add, ds0, ds1) with pytest.raises(TypeError, match=r"must supply"): apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") @@ -1892,7 +1892,7 @@ def test_dot_align_coords(use_dask) -> None: xr.testing.assert_allclose(expected, actual) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.dot(da_a, da_b) # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 8a37df62261..8abede64761 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -7,6 +7,7 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes, merge +from xarray.core.indexes import PandasIndex from . import ( InaccessibleArray, @@ -259,7 +260,7 @@ def test_concat_join_kwarg(self) -> None: coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -465,7 +466,7 @@ def test_concat_dim_is_variable(self) -> None: def test_concat_multiindex(self) -> None: x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) - expected = Dataset({"x": x}) + expected = Dataset(coords={"x": x}) actual = concat( [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" ) @@ -639,7 +640,7 @@ def test_concat_join_kwarg(self) -> None: coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -783,3 +784,27 @@ def test_concat_typing_check() -> None: match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): concat([da, ds], dim="foo") + + +def test_concat_not_all_indexes() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + # ds2.x has no default index + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + + with pytest.raises( + ValueError, match=r"'x' must have either an index or no index in all datasets.*" + ): + concat([ds1, ds2], dim="x") + + +def test_concat_index_not_same_dim() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + # TODO: use public API for setting a non-default index, when available + ds2._indexes["x"] = PandasIndex([3, 4], "y") + + with pytest.raises( + ValueError, + match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*", + ): + concat([ds1, ds2], dim="x") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 42d8df57cb7..872c0c6f1db 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -51,14 +51,14 @@ def assertLazyAnd(self, expected, actual, test): if isinstance(actual, Dataset): for k, v in actual.variables.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) elif isinstance(actual, DataArray): assert isinstance(actual.data, da.Array) for k, v in actual.coords.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) @@ -1226,7 +1226,7 @@ def sumda(da1, da2): with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) - with pytest.raises(ValueError, match=r"indexes along dimension 'x' are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*index.*are not equal"): xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) # reduction diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 55c68b7ff6b..ed9df9a871c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -24,7 +24,7 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.indexes import Index, PandasIndex, propagate_indexes +from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.utils import is_scalar from xarray.tests import ( ReturnItem, @@ -93,9 +93,9 @@ def test_repr_multiindex(self): array([0, 1, 2, 3]) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2""" + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2""" ) assert expected == repr(self.mda) @@ -111,9 +111,9 @@ def test_repr_multiindex_long(self): array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' - - level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' + * level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" ) assert expected == repr(mda_long) @@ -769,11 +769,6 @@ def test_contains(self): assert 1 in data_array assert 3 not in data_array - def test_attr_sources_multiindex(self): - # make sure attr-style access for multi-index levels - # returns DataArray objects - assert isinstance(self.mda.level_1, DataArray) - def test_pickle(self): data = DataArray(np.random.random((3, 3)), dims=("id", "time")) roundtripped = pickle.loads(pickle.dumps(data)) @@ -901,7 +896,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"], stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): da.isel( x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), @@ -1007,6 +1002,21 @@ def test_sel_float(self): assert_equal(expected_scalar, actual_scalar) assert_equal(expected_16, actual_16) + def test_sel_float_multiindex(self): + # regression test https://github.com/pydata/xarray/issues/5691 + # test multi-index created from coordinates, one with dtype=float32 + lvl1 = ["a", "a", "b", "b"] + lvl2 = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + da = xr.DataArray( + [1, 2, 3, 4], dims="x", coords={"lvl1": ("x", lvl1), "lvl2": ("x", lvl2)} + ) + da = da.set_index(x=["lvl1", "lvl2"]) + + actual = da.sel(lvl1="a", lvl2=0.1) + expected = da.isel(x=0) + + assert_equal(actual, expected) + def test_sel_no_index(self): array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0)) @@ -1261,7 +1271,7 @@ def test_selection_multiindex_from_level(self): data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) assert data.dims == ("xy",) actual = data.sel(y="a") - expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop_vars("y") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y") assert_equal(actual, expected) def test_virtual_default_coords(self): @@ -1311,13 +1321,15 @@ def test_coords(self): assert expected == actual del da.coords["x"] - da._indexes = propagate_indexes(da._indexes, exclude="x") + da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords)) expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda["level_1"] = np.arange(4) - self.mda.coords["level_1"] = np.arange(4) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = ("x", np.arange(4)) + self.mda.coords["level_1"] = ("x", np.arange(4)) def test_coords_to_index(self): da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) @@ -1422,14 +1434,22 @@ def test_reset_coords(self): with pytest.raises(ValueError, match=r"cannot remove index"): data.reset_coords("y") + # non-dimension index coordinate + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + data = DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x", name="foo") + with pytest.raises(ValueError, match=r"cannot remove index"): + data.reset_coords("lvl1") + def test_assign_coords(self): array = DataArray(10) actual = array.assign_coords(c=42) expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda.assign_coords(level_1=range(4)) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda.assign_coords(level_1=("x", range(4))) # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") @@ -1437,6 +1457,8 @@ def test_assign_coords(self): da["x"] = [0, 1, 2, 3] # size conflict with pytest.raises(ValueError): da.coords["x"] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) @@ -1453,6 +1475,12 @@ def test_set_coords_update_index(self): actual.coords["x"] = ["a", "b", "c"] assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) + def test_set_coords_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = range(4) + def test_coords_replacement_alignment(self): # regression test for GH725 arr = DataArray([0, 1, 2], dims=["abc"]) @@ -1473,6 +1501,12 @@ def test_coords_delitem_delete_indexes(self): del arr.coords["x"] assert "x" not in arr.xindexes + def test_coords_delitem_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del self.mda.coords["level_1"] + def test_broadcast_like(self): arr1 = DataArray( np.ones((2, 3)), @@ -1624,10 +1658,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # as kwargs array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") @@ -1635,10 +1666,7 @@ def test_swap_dims(self): actual = array.swap_dims(x="y") assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) @@ -1647,10 +1675,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) def test_expand_dims_error(self): array = DataArray( @@ -1843,49 +1868,52 @@ def test_set_index(self): array2d.set_index(x="level") # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r".*variable\(s\) do not exist"): obj.set_index(x="level_4") - assert str(excinfo.value) == "level_4 is not the name of an existing variable." def test_reset_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", self.mindex.values) expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(self.mindex.names) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(["x", "level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert list(obj.xindexes) == ["level_2"] - coords = { - "x": ("x", self.mindex.droplevel("level_1")), - "level_1": ("x", self.mindex.get_level_values("level_1")), - } expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index(["level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert list(obj.xindexes) == ["level_2"] + assert type(obj.xindexes["level_2"]) is PandasIndex - expected = DataArray(self.mda.values, dims="x") + coords = {k: v for k, v in coords.items() if k != "x"} + expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x", drop=True) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) array = self.mda.copy() array = array.reset_index(["x"], drop=True) - assert_identical(array, expected) + assert_identical(array, expected, check_default_indexes=False) # single index array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") - expected = DataArray([1, 2], coords={"x_": ("x", ["a", "b"])}, dims="x") - assert_identical(array.reset_index("x"), expected) + obj = array.reset_index("x") + assert_identical(obj, array, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) da = DataArray([1, 0], [coord_1]) - expected = DataArray([1, 0], {"coord_1_": coord_1}, dims=["coord_1"]) obj = da.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, da, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): midx = self.mindex.reorder_levels(["level_2", "level_1"]) @@ -2157,11 +2185,15 @@ def test_dataset_math(self): assert_identical(actual, expected) def test_stack_unstack(self): - orig = DataArray([[0, 1], [2, 3]], dims=["x", "y"], attrs={"foo": 2}) + orig = DataArray( + [[0, 1], [2, 3]], + dims=["x", "y"], + attrs={"foo": 2}, + ) assert_identical(orig, orig.unstack()) # test GH3000 - a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() + a = orig[:0, :1].stack(dim=("x", "y")).indexes["dim"] b = pd.MultiIndex( levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], codes=[[], []], @@ -2176,16 +2208,21 @@ def test_stack_unstack(self): assert_identical(orig, actual) dims = ["a", "b", "c", "d", "e"] - orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + coords = { + "a": [0], + "b": [1, 2], + "c": [3, 4, 5], + "d": [6, 7], + "e": [8], + } + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), coords=coords, dims=dims) stacked = orig.stack(ab=["a", "b"], cd=["c", "d"]) unstacked = stacked.unstack(["ab", "cd"]) - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) unstacked = stacked.unstack() - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 @@ -2317,6 +2354,19 @@ def test_drop_coordinates(self): actual = renamed.drop_vars("foo", errors="ignore") assert_identical(actual, renamed) + def test_drop_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + self.mda.drop_vars("level_1") + + def test_drop_all_multiindex_levels(self): + dim_levels = ["x", "level_1", "level_2"] + actual = self.mda.drop_vars(dim_levels) + # no error, multi-index dropped + for key in dim_levels: + assert key not in actual.xindexes + def test_drop_index_labels(self): arr = DataArray(np.random.randn(2, 3), coords={"y": [0, 1, 2]}, dims=["x", "y"]) actual = arr.drop_sel(y=[0, 1]) @@ -2707,7 +2757,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): align(left.isel(x=0).expand_dims("x"), right, join="override") @pytest.mark.parametrize( @@ -2726,7 +2778,9 @@ def test_align_override(self): ], ) def test_align_override_error(self, darrays): - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(*darrays, join="override") def test_align_exclude(self): @@ -2782,10 +2836,16 @@ def test_align_mixed_indexes(self): assert_identical(result1, array_with_coord) def test_align_without_indexes_errors(self): - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): align( DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], coords=[("x", [0, 1])]), @@ -3584,7 +3644,9 @@ def test_dot_align_coords(self): dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da.dot(dm) da_aligned, dm_aligned = xr.align(da, dm, join="inner") @@ -3635,7 +3697,9 @@ def test_matmul_align_coords(self): assert_identical(result, expected) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da_a @ da_b def test_binary_op_propagate_indexes(self): @@ -6618,7 +6682,7 @@ def test_clip(da): assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1])) # Unclear whether we want this work, OK to adjust the test when we have decided. - with pytest.raises(ValueError, match="arguments without labels along dimension"): + with pytest.raises(ValueError, match="cannot reindex or align along dimension.*"): result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1])) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8d6c4f96857..35831b82fa6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -239,9 +239,9 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2 + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -259,9 +259,9 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: - * x (x) MultiIndex - - a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2 + * x (x) object MultiIndex + * a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -740,7 +740,9 @@ def test_coords_setitem_with_new_dimension(self): def test_coords_setitem_multiindex(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.coords["level_1"] = range(4) def test_coords_set(self): @@ -1132,7 +1134,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"].drop_vars(["dim2"]), stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): data.isel( dim1=DataArray( [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} @@ -1482,6 +1484,16 @@ def test_sel_drop(self): selected = data.sel(x=0, drop=True) assert_identical(expected, selected) + def test_sel_drop_mindex(self): + midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) + data = Dataset(coords={"x": midx}) + + actual = data.sel(foo="a", drop=True) + assert "foo" not in actual.coords + + actual = data.sel(foo="a", drop=False) + assert_equal(actual.foo, DataArray("a", coords={"foo": "a"})) + def test_isel_drop(self): data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) expected = Dataset({"foo": 1}) @@ -1680,7 +1692,7 @@ def test_sel_method(self): with pytest.raises(TypeError, match=r"``method``"): # this should not pass silently - data.sel(method=data) + data.sel(dim2=1, method=data) # cannot pass method if there is no associated coordinate with pytest.raises(ValueError, match=r"cannot supply"): @@ -1820,8 +1832,10 @@ def test_reindex(self): data.reindex("foo") # invalid dimension - with pytest.raises(ValueError, match=r"invalid reindex dim"): - data.reindex(invalid=0) + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"): + # data.reindex(invalid=0) # out of order expected = data.sel(dim2=data["dim2"][:5:-1]) @@ -2053,7 +2067,7 @@ def test_align_exact(self): assert_identical(left1, left) assert_identical(left2, left) - with pytest.raises(ValueError, match=r"indexes .* not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.align(left, right, join="exact") def test_align_override(self): @@ -2075,7 +2089,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(left.isel(x=0).expand_dims("x"), right, join="override") def test_align_exclude(self): @@ -2155,11 +2171,15 @@ def test_align_non_unique(self): def test_align_str_dtype(self): - a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]}) - b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]}) + a = Dataset({"foo": ("x", [0, 1])}, coords={"x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) - expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]}) - expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]}) + expected_a = Dataset( + {"foo": ("x", [0, 1, np.NaN])}, coords={"x": ["a", "b", "c"]} + ) + expected_b = Dataset( + {"foo": ("x", [np.NaN, 1, 2])}, coords={"x": ["a", "b", "c"]} + ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -2340,6 +2360,14 @@ def test_drop_variables(self): actual = data.drop({"time", "not_found_here"}, errors="ignore") assert_identical(expected, actual) + def test_drop_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + data.drop_vars("level_1") + def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) @@ -2647,12 +2675,14 @@ def test_rename_dims(self): expected = Dataset( {"x": ("x_new", [0, 1, 2]), "y": ("x_new", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x") dims_dict = {"x": "x_new"} actual = original.rename_dims(dims_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_dims(**dims_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError dims_dict_bad = {"x_bad": "x_new"} @@ -2667,25 +2697,56 @@ def test_rename_vars(self): expected = Dataset( {"x_new": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x_new") name_dict = {"x": "x_new"} actual = original.rename_vars(name_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_vars(**name_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError names_dict_bad = {"x_bad": "x_new"} with pytest.raises(ValueError): original.rename_vars(names_dict_bad) - def test_rename_multiindex(self): - mindex = pd.MultiIndex.from_tuples( - [([1, 2]), ([3, 4])], names=["level0", "level1"] - ) - data = Dataset({}, {"x": mindex}) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - data.rename({"x": "level0"}) + def test_rename_dimension_coord(self) -> None: + # rename a dimension corodinate to a non-dimension coordinate + # should preserve index + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + + actual = original.rename_vars({"x": "x_new"}) + assert "x_new" in actual.xindexes + + actual_2 = original.rename_dims({"x": "x_new"}) + assert "x" in actual_2.xindexes + + def test_rename_multiindex(self) -> None: + mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + original = Dataset({}, {"x": mindex}) + expected = Dataset({}, {"x": mindex.rename(["a", "c"])}) + + actual = original.rename({"b": "c"}) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"'a' conflicts"): + original.rename({"x": "a"}) + with pytest.raises(ValueError, match=r"'x' conflicts"): + original.rename({"a": "x"}) + with pytest.raises(ValueError, match=r"'b' conflicts"): + original.rename({"a": "b"}) + + def test_rename_perserve_attrs_encoding(self) -> None: + # test propagate attrs/encoding to new variable(s) created from Index object + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + expected = Dataset(coords={"y": ("y", [0, 1, 2])}) + for ds, dim in zip([original, expected], ["x", "y"]): + ds[dim].attrs = {"foo": "bar"} + ds[dim].encoding = {"foo": "bar"} + + actual = original.rename({"x": "y"}) + assert_identical(actual, expected) @requires_cftime def test_rename_does_not_change_CFTimeIndex_type(self): @@ -2745,10 +2806,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) roundtripped = actual.swap_dims({"y": "x"}) assert_identical(original.set_coords("y"), roundtripped) @@ -2779,10 +2837,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) def test_expand_dims_error(self): original = Dataset( @@ -2974,33 +3029,49 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) + # ensure pre-existing indexes involved are removed + # (level_2 should be a coordinate with no index) + ds = create_test_multiindex() + coords = {"x": coords["level_1"], "level_2": coords["level_2"]} + expected = Dataset({}, coords=coords) + + obj = ds.set_index(x="level_1") + assert_identical(obj, expected) + # ensure set_index with no existing index and a single data var given # doesn't return multi-index ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) expected = Dataset(coords={"x": [0, 1, 2]}) assert_identical(ds.set_index(x="x_var"), expected) - # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r"bar variable\(s\) do not exist"): ds.set_index(foo="bar") - assert str(excinfo.value) == "bar is not the name of an existing variable." + + with pytest.raises(ValueError, match=r"dimension mismatch.*"): + ds.set_index(y="x_var") def test_reset_index(self): ds = create_test_multiindex() mindex = ds["x"].to_index() indexes = [mindex.get_level_values(n) for n in mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", mindex.values) expected = Dataset({}, coords=coords) obj = ds.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + + ds = Dataset(coords={"y": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match=r".*not coordinates with an index"): + ds.reset_index("y") def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) ds = Dataset({}, {"coord_1": coord_1}) - expected = Dataset({}, {"coord_1_": coord_1}) obj = ds.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, ds, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): ds = create_test_multiindex() @@ -3008,6 +3079,10 @@ def test_reorder_levels(self): midx = mindex.reorder_levels(["level_2", "level_1"]) expected = Dataset({}, coords={"x": midx}) + # check attrs propagated + ds["level_1"].attrs["foo"] = "bar" + expected["level_1"].attrs["foo"] = "bar" + reindexed = ds.reorder_levels(x=["level_2", "level_1"]) assert_identical(reindexed, expected) @@ -3017,15 +3092,22 @@ def test_reorder_levels(self): def test_stack(self): ds = Dataset( - {"a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), "y": ["a", "b"]} + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, ) exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) expected = Dataset( - {"a": ("z", [0, 0, 1, 1]), "b": ("z", [0, 1, 2, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, ) + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "x", "y"] actual = ds.stack(z=[...]) assert_identical(expected, actual) @@ -3040,17 +3122,85 @@ def test_stack(self): exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) expected = Dataset( - {"a": ("z", [0, 1, 0, 1]), "b": ("z", [0, 2, 1, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 2, 1, 3])}, + coords={"z": exp_index}, ) + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["y", "x"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "y", "x"] + + @pytest.mark.parametrize( + "create_index,expected_keys", + [ + (True, ["z", "x", "y"]), + (False, []), + (None, ["z", "x", "y"]), + ], + ) + def test_stack_create_index(self, create_index, expected_keys) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ) + + actual = ds.stack(z=["x", "y"], create_index=create_index) + assert list(actual.xindexes) == expected_keys + + # TODO: benbovy (flexible indexes) - test error multiple indexes found + # along dimension + create_index=True + + def test_stack_multi_index(self) -> None: + # multi-index on a dimension to stack is discarded too + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + ds = xr.Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, + coords={"x": midx, "y": [0, 1]}, + ) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, + coords={ + "x": ("z", np.repeat(midx.values, 2)), + "lvl1": ("z", np.repeat(midx.get_level_values("lvl1"), 2)), + "lvl2": ("z", np.repeat(midx.get_level_values("lvl2"), 2)), + "y": ("z", [0, 1, 0, 1] * 2), + }, + ) + actual = ds.stack(z=["x", "y"], create_index=False) + assert_identical(expected, actual) + assert len(actual.xindexes) == 0 + + with pytest.raises(ValueError, match=r"cannot create.*wraps a multi-index"): + ds.stack(z=["x", "y"], create_index=True) + + def test_stack_non_dim_coords(self): + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ).rename_vars(x="xx") + + exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, + ) + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "xx", "y"] def test_unstack(self): index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) - ds = Dataset({"b": ("z", [0, 1, 2, 3]), "z": index}) + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords={"z": index}) expected = Dataset( {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} ) + + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + for dim in ["z", ["z"], None]: actual = ds.unstack(dim) assert_identical(actual, expected) @@ -3059,7 +3209,7 @@ def test_unstack_errors(self): ds = Dataset({"x": [1, 2, 3]}) with pytest.raises(ValueError, match=r"does not contain the dimensions"): ds.unstack("foo") - with pytest.raises(ValueError, match=r"do not have a MultiIndex"): + with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): ds.unstack("x") def test_unstack_fill_value(self): @@ -3147,12 +3297,11 @@ def test_stack_unstack_fast(self): def test_stack_unstack_slow(self): ds = Dataset( - { + data_vars={ "a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), - "x": [0, 1], - "y": ["a", "b"], - } + }, + coords={"x": [0, 1], "y": ["a", "b"]}, ) stacked = ds.stack(z=["x", "y"]) actual = stacked.isel(z=slice(None, None, -1)).unstack("z") @@ -3187,8 +3336,6 @@ def test_to_stacked_array_dtype_dims(self): D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - # TODO: benbovy - flexible indexes: update when MultiIndex has its own class - # inherited from xarray.Index assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") @@ -3259,6 +3406,14 @@ def test_update_overwrite_coords(self): expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 3}) assert_identical(data, expected) + def test_update_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + data.update({"level_1": range(4)}) + def test_update_auto_align(self): ds = Dataset({"x": ("t", [3, 4])}, {"t": [0, 1]}) @@ -3359,34 +3514,6 @@ def test_virtual_variable_same_name(self): expected = DataArray(times.time, [("time", times)], name="time") assert_identical(actual, expected) - def test_virtual_variable_multiindex(self): - # access multi-index levels as virtual variables - data = create_test_multiindex() - expected = DataArray( - ["a", "a", "b", "b"], - name="level_1", - coords=[data["x"].to_index()], - dims="x", - ) - assert_identical(expected, data["level_1"]) - - # combine multi-index level and datetime - dr_index = pd.date_range("1/1/2011", periods=4, freq="H") - mindex = pd.MultiIndex.from_arrays( - [["a", "a", "b", "b"], dr_index], names=("level_str", "level_date") - ) - data = Dataset({}, {"x": mindex}) - expected = DataArray( - mindex.get_level_values("level_date").hour, - name="hour", - coords=[mindex], - dims="x", - ) - assert_identical(expected, data["level_date.hour"]) - - # attribute style access - assert_identical(data.level_str, data["level_str"]) - def test_time_season(self): ds = Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] @@ -3427,7 +3554,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=r"already exists as a scalar"): data1["newvar"] = ("scalar", [3, 4, 5]) # can't resize a used dimension - with pytest.raises(ValueError, match=r"arguments without labels"): + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): data1["dim1"] = data1["dim1"][:5] # override an existing value data1["A"] = 3 * data2["A"] @@ -3464,7 +3591,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3]}] data3["var2"] = data3["var2"].T - err_msg = "indexes along dimension 'dim2' are not equal" + err_msg = r"cannot align objects.*not equal along these coordinates.*" with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3, 4]}] err_msg = "Dataset assignment only accepts DataArrays, Datasets, and scalars." @@ -3680,22 +3807,39 @@ def test_assign_attrs(self): def test_assign_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) - # raise an Error when any level name is used as dimension GH:2299 - with pytest.raises(ValueError): - data["y"] = ("level_1", [0, 1]) + + def test_assign_all_multiindex_coords(self): + data = create_test_multiindex() + actual = data.assign(x=range(4), level_1=range(4), level_2=range(4)) + # no error but multi-index dropped in favor of single indexes for each level + assert ( + actual.xindexes["x"] + is not actual.xindexes["level_1"] + is not actual.xindexes["level_2"] + ) def test_merge_multiindex_level(self): data = create_test_multiindex() - other = Dataset({"z": ("level_1", [0, 1])}) # conflict dimension - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", [0, 1])}) + with pytest.raises(ValueError, match=r".*conflicting dimension sizes.*"): data.merge(other) - other = Dataset({"level_1": ("x", [0, 1])}) # conflict variable name - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", range(4))}) + with pytest.raises( + ValueError, match=r"unable to determine.*coordinates or not.*" + ): data.merge(other) + # `other` Dataset coordinates are ignored (bug or feature?) + other = Dataset(coords={"level_1": ("x", range(4))}) + assert_identical(data.merge(other), data) + def test_setitem_original_non_unique_index(self): # regression test for GH943 original = Dataset({"data": ("x", np.arange(5))}, coords={"x": [0, 1, 2, 0, 1]}) @@ -3727,7 +3871,9 @@ def test_setitem_both_non_unique_index(self): def test_setitem_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data["level_1"] = range(4) def test_delitem(self): @@ -3745,6 +3891,13 @@ def test_delitem(self): del actual["y"] assert_identical(expected, actual) + def test_delitem_multiindex_level(self): + data = create_test_multiindex() + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del data["level_1"] + def test_squeeze(self): data = Dataset({"foo": (["x", "y", "z"], [[[1], [2]]])}) for args in [[], [["x"]], [["x", "z"]]]: @@ -4373,7 +4526,7 @@ def test_where_other(self): with pytest.raises(ValueError, match=r"cannot set"): ds.where(ds > 1, other=0, drop=True) - with pytest.raises(ValueError, match=r"indexes .* are not equal"): + with pytest.raises(ValueError, match=r"cannot align .* are not equal"): ds.where(ds > 1, ds.isel(x=slice(3))) with pytest.raises(ValueError, match=r"exact match required"): @@ -5505,7 +5658,7 @@ def test_ipython_key_completion(self): assert sorted(actual) == sorted(expected) # MultiIndex - ds_midx = ds.stack(dim12=["dim1", "dim2"]) + ds_midx = ds.stack(dim12=["dim2", "dim3"]) actual = ds_midx._ipython_key_completions_() expected = [ "var1", diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 529382279de..105cec7e850 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -558,9 +558,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: display_expand_attrs=False, ): actual = formatting.dataset_repr(ds) - col_width = formatting._calculate_col_width( - formatting._get_col_items(ds.variables) - ) + col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( ds, col_width=col_width + 1, max_rows=display_max_rows diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 4ee80f65027..c67619e18c7 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -64,27 +64,27 @@ def test_short_data_repr_html_dask(dask_dataarray) -> None: def test_format_dims_no_dims() -> None: dims: Dict = {} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert formatted == "" def test_format_dims_unsafe_dim_name() -> None: dims = {"": 3, "y": 2} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert "<x>" in formatted def test_format_dims_non_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["time"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["time"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" not in formatted def test_format_dims_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["x"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["x"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" in formatted @@ -119,14 +119,6 @@ def test_repr_of_dataarray(dataarray) -> None: ) -def test_summary_of_multiindex_coord(multiindex) -> None: - idx = multiindex.x.variable.to_index_variable() - formatted = fh._summarize_coord_multiindex("foo", idx) - assert "(level_1, level_2)" in formatted - assert "MultiIndex" in formatted - assert "foo" in formatted - - def test_repr_of_multiindex(multiindex) -> None: formatted = fh.dataset_repr(multiindex) assert "(x)" in formatted diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 4b6da82bdc7..f866b68dfa8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -519,10 +519,16 @@ def test_groupby_drops_nans() -> None: actual = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) expected = ( - stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset() + stacked.variable.where(stacked.id.notnull()) + .rename({"xy": "id"}) + .to_dataset() + .reset_index("id", drop=True) + .drop_vars(["lon", "lat"]) + .assign(id=stacked.id.values) + .dropna("id") + .transpose(*actual.dims) ) - expected["id"] = stacked.id.values - assert_identical(actual, expected.dropna("id").transpose(*actual.dims)) + assert_identical(actual, expected) # reduction operation along a different dimension actual = grouped.mean("time") diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 18f76df765d..7edcaa15105 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -1,10 +1,21 @@ +import copy +from typing import Any, Dict, List + import numpy as np import pandas as pd import pytest import xarray as xr -from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe -from xarray.core.variable import IndexVariable +from xarray.core.indexes import ( + Index, + Indexes, + PandasIndex, + PandasMultiIndex, + _asarray_tuplesafe, +) +from xarray.core.variable import IndexVariable, Variable + +from . import assert_equal, assert_identical def test_asarray_tuplesafe() -> None: @@ -19,23 +30,110 @@ def test_asarray_tuplesafe() -> None: assert res[1] == (1,) +class CustomIndex(Index): + def __init__(self, dims) -> None: + self.dims = dims + + +class TestIndex: + @pytest.fixture + def index(self) -> CustomIndex: + return CustomIndex({"x": 2}) + + def test_from_variables(self) -> None: + with pytest.raises(NotImplementedError): + Index.from_variables({}) + + def test_concat(self) -> None: + with pytest.raises(NotImplementedError): + Index.concat([], "x") + + def test_stack(self) -> None: + with pytest.raises(NotImplementedError): + Index.stack({}, "x") + + def test_unstack(self, index) -> None: + with pytest.raises(NotImplementedError): + index.unstack() + + def test_create_variables(self, index) -> None: + assert index.create_variables() == {} + assert index.create_variables({"x": "var"}) == {"x": "var"} + + def test_to_pandas_index(self, index) -> None: + with pytest.raises(TypeError): + index.to_pandas_index() + + def test_isel(self, index) -> None: + assert index.isel({}) is None + + def test_sel(self, index) -> None: + with pytest.raises(NotImplementedError): + index.sel({}) + + def test_join(self, index) -> None: + with pytest.raises(NotImplementedError): + index.join(CustomIndex({"y": 2})) + + def test_reindex_like(self, index) -> None: + with pytest.raises(NotImplementedError): + index.reindex_like(CustomIndex({"y": 2})) + + def test_equals(self, index) -> None: + with pytest.raises(NotImplementedError): + index.equals(CustomIndex({"y": 2})) + + def test_roll(self, index) -> None: + assert index.roll({}) is None + + def test_rename(self, index) -> None: + assert index.rename({}, {}) is index + + @pytest.mark.parametrize("deep", [True, False]) + def test_copy(self, index, deep) -> None: + copied = index.copy(deep=deep) + assert isinstance(copied, CustomIndex) + assert copied is not index + + copied.dims["x"] = 3 + if deep: + assert copied.dims != index.dims + assert copied.dims != copy.deepcopy(index).dims + else: + assert copied.dims is index.dims + assert copied.dims is copy.copy(index).dims + + def test_getitem(self, index) -> None: + with pytest.raises(NotImplementedError): + index[:] + + class TestPandasIndex: def test_constructor(self) -> None: pd_idx = pd.Index([1, 2, 3]) index = PandasIndex(pd_idx, "x") - assert index.index is pd_idx + assert index.index.equals(pd_idx) + # makes a shallow copy + assert index.index is not pd_idx assert index.dim == "x" + # test no name set for pd.Index + pd_idx.name = None + index = PandasIndex(pd_idx, "x") + assert index.index.name == "x" + def test_from_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) var = xr.Variable( - "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64} ) - index, index_vars = PandasIndex.from_variables({"x": var}) - xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + index = PandasIndex.from_variables({"x": var}) assert index.dim == "x" - assert index.index.equals(index_vars["x"].to_index()) + assert index.index.equals(pd.Index(data)) + assert index.coord_dtype == data.dtype var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises(ValueError, match=r".*only accepts one variable.*"): @@ -46,94 +144,206 @@ def test_from_variables(self) -> None: ): PandasIndex.from_variables({"foo": var2}) - def test_from_pandas_index(self) -> None: - pd_idx = pd.Index([1, 2, 3], name="foo") - - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - - assert index.dim == "x" - assert index.index is pd_idx - assert index.index.name == "foo" - xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) - - # test no name set for pd.Index - pd_idx.name = None - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - assert "x" in index_vars - assert index.index is not pd_idx - assert index.index.name == "x" + def test_from_variables_index_adapter(self) -> None: + # test index type is preserved when variable wraps a pd.Index + data = pd.Series(["foo", "bar"], dtype="category") + pd_idx = pd.Index(data) + var = xr.Variable("x", pd_idx) + + index = PandasIndex.from_variables({"x": var}) + assert isinstance(index.index, pd.CategoricalIndex) + + def test_concat_periods(self): + periods = pd.period_range("2000-01-01", periods=10) + indexes = [PandasIndex(periods[:5], "t"), PandasIndex(periods[5:], "t")] + expected = PandasIndex(periods, "t") + actual = PandasIndex.concat(indexes, dim="t") + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + positions = [list(range(5)), list(range(5, 10))] + actual = PandasIndex.concat(indexes, dim="t", positions=positions) + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_concat_str_dtype(self, dtype) -> None: + + a = PandasIndex(np.array(["a"], dtype=dtype), "x", coord_dtype=dtype) + b = PandasIndex(np.array(["b"], dtype=dtype), "x", coord_dtype=dtype) + expected = PandasIndex( + np.array(["a", "b"], dtype=dtype), "x", coord_dtype=dtype + ) - def to_pandas_index(self): + actual = PandasIndex.concat([a, b], "x") + assert actual.equals(expected) + assert np.issubdtype(actual.coord_dtype, dtype) + + def test_concat_empty(self) -> None: + idx = PandasIndex.concat([], "x") + assert idx.coord_dtype is np.dtype("O") + + def test_concat_dim_error(self) -> None: + indexes = [PandasIndex([0, 1], "x"), PandasIndex([2, 3], "y")] + + with pytest.raises(ValueError, match=r"Cannot concatenate.*dimensions.*"): + PandasIndex.concat(indexes, "x") + + def test_create_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) + pd_idx = pd.Index(data, name="foo") + index = PandasIndex(pd_idx, "x", coord_dtype=data.dtype) + index_vars = { + "foo": IndexVariable( + "x", data, attrs={"unit": "m"}, encoding={"fill_value": 0.0} + ) + } + + actual = index.create_variables(index_vars) + assert_identical(actual["foo"], index_vars["foo"]) + assert actual["foo"].dtype == index_vars["foo"].dtype + assert actual["foo"].dtype == index.coord_dtype + + def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") index = PandasIndex(pd_idx, "x") - assert index.to_pandas_index() is pd_idx + assert index.to_pandas_index() is index.index - def test_query(self) -> None: + def test_sel(self) -> None: # TODO: add tests that aren't just for edge cases index = PandasIndex(pd.Index([1, 2, 3]), "x") with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"does not have a MultiIndex"): - index.query({"x": {"one": 0}}) + index.sel({"x": {"one": 0}}) + + def test_sel_boolean(self) -> None: + # index should be ignored and indexer dtype should not be coerced + # see https://github.com/pydata/xarray/issues/5727 + index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") + actual = index.sel({"x": [False, True, False, True]}) + expected_dim_indexers = {"x": [False, True, False, True]} + np.testing.assert_array_equal( + actual.dim_indexers["x"], expected_dim_indexers["x"] + ) - def test_query_datetime(self) -> None: + def test_sel_datetime(self) -> None: index = PandasIndex( pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) - actual = index.query({"x": "2001-01-01"}) - expected = (1, None) - assert actual == expected + actual = index.sel({"x": "2001-01-01"}) + expected_dim_indexers = {"x": 1} + assert actual.dim_indexers == expected_dim_indexers - actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) - assert actual == expected + actual = index.sel({"x": index.to_pandas_index().to_numpy()[1]}) + assert actual.dim_indexers == expected_dim_indexers - def test_query_unsorted_datetime_index_raises(self) -> None: + def test_sel_unsorted_datetime_index_raises(self) -> None: index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") with pytest.raises(KeyError): # pandas will try to convert this into an array indexer. We should # raise instead, so we can be sure the result of indexing with a # slice is always a view. - index.query({"x": slice("2001", "2002")}) + index.sel({"x": slice("2001", "2002")}) def test_equals(self) -> None: index1 = PandasIndex([1, 2, 3], "x") index2 = PandasIndex([1, 2, 3], "x") assert index1.equals(index2) is True - def test_union(self) -> None: - index1 = PandasIndex([1, 2, 3], "x") - index2 = PandasIndex([4, 5, 6], "y") - actual = index1.union(index2) - assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6])) - assert actual.dim == "x" + def test_join(self) -> None: + index1 = PandasIndex(["a", "aa", "aaa"], "x", coord_dtype=" None: - index1 = PandasIndex([1, 2, 3], "x") - index2 = PandasIndex([2, 3, 4], "y") - actual = index1.intersection(index2) - assert actual.index.equals(pd.Index([2, 3])) - assert actual.dim == "x" + expected = PandasIndex(["aa", "aaa"], "x") + actual = index1.join(index2) + print(actual.index) + assert actual.equals(expected) + assert actual.coord_dtype == " None: + index1 = PandasIndex([0, 1, 2], "x") + index2 = PandasIndex([1, 2, 3, 4], "x") + + expected = {"x": [1, 2, -1, -1]} + actual = index1.reindex_like(index2) + assert actual.keys() == expected.keys() + np.testing.assert_array_equal(actual["x"], expected["x"]) + + index3 = PandasIndex([1, 1, 2], "x") + with pytest.raises(ValueError, match=r".*index has duplicate values"): + index3.reindex_like(index2) + + def test_rename(self) -> None: + index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32) + + # shortcut + new_index = index.rename({}, {}) + assert new_index is index + + new_index = index.rename({"a": "b"}, {}) + assert new_index.index.name == "b" + assert new_index.dim == "x" + assert new_index.coord_dtype == np.int32 + + new_index = index.rename({}, {"x": "y"}) + assert new_index.index.name == "a" + assert new_index.dim == "y" + assert new_index.coord_dtype == np.int32 def test_copy(self) -> None: - expected = PandasIndex([1, 2, 3], "x") + expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32) actual = expected.copy() assert actual.index.equals(expected.index) assert actual.index is not expected.index assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype def test_getitem(self) -> None: pd_idx = pd.Index([1, 2, 3]) - expected = PandasIndex(pd_idx, "x") + expected = PandasIndex(pd_idx, "x", coord_dtype=np.int32) actual = expected[1:] assert actual.index.equals(pd_idx[1:]) assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype class TestPandasMultiIndex: + def test_constructor(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + + index = PandasMultiIndex(pd_idx, "x") + + assert index.dim == "x" + assert index.index.equals(pd_idx) + assert index.index.names == ("foo", "bar") + assert index.index.name == "x" + assert index.level_coords_dtype == { + "foo": foo_data.dtype, + "bar": bar_data.dtype, + } + + with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): + PandasMultiIndex(pd_idx, "foo") + + # default level names + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) + index = PandasMultiIndex(pd_idx, "x") + assert index.index.names == ("x_level_0", "x_level_1") + def test_from_variables(self) -> None: v_level1 = xr.Variable( "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} @@ -142,18 +352,15 @@ def test_from_variables(self) -> None: "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} ) - index, index_vars = PandasMultiIndex.from_variables( + index = PandasMultiIndex.from_variables( {"level1": v_level1, "level2": v_level2} ) expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data]) assert index.dim == "x" assert index.index.equals(expected_idx) - - assert list(index_vars) == ["x", "level1", "level2"] - xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) - xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"]) - xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"]) + assert index.index.name == "x" + assert index.index.names == ["level1", "level2"] var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( @@ -162,35 +369,267 @@ def test_from_variables(self) -> None: PandasMultiIndex.from_variables({"var": var}) v_level3 = xr.Variable("y", [4, 5, 6]) - with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"): + with pytest.raises( + ValueError, match=r"unmatched dimensions for multi-index variables.*" + ): PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) - def test_from_pandas_index(self) -> None: - pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar")) + def test_concat(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [[0, 1, 2], ["a", "b"]], names=("foo", "bar") + ) + level_coords_dtype = {"foo": np.int32, "bar": " None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 3, 2])), + } + + index = PandasMultiIndex.stack(prod_vars, "z") + + assert index.dim == "z" + assert index.index.names == ["x", "y"] + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + ) + + with pytest.raises( + ValueError, match=r"conflicting dimensions for multi-index product.*" + ): + PandasMultiIndex.stack( + {"x": xr.Variable("x", ["a", "b"]), "x2": xr.Variable("x", [1, 2])}, + "z", + ) + + def test_stack_non_unique(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 1, 2])), + } - def test_query(self) -> None: + index = PandasMultiIndex.stack(prod_vars, "z") + + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + ) + np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + + def test_unstack(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2, 3]], names=["one", "two"] + ) + index = PandasMultiIndex(pd_midx, "x") + + new_indexes, new_pd_idx = index.unstack() + assert list(new_indexes) == ["one", "two"] + assert new_indexes["one"].equals(PandasIndex(["a", "b"], "one")) + assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) + assert new_pd_idx.equals(pd_midx) + + def test_create_variables(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + index_vars = { + "x": IndexVariable("x", pd_idx), + "foo": IndexVariable("x", foo_data, attrs={"unit": "m"}), + "bar": IndexVariable("x", bar_data, encoding={"fill_value": 0}), + } + + index = PandasMultiIndex(pd_idx, "x") + actual = index.create_variables(index_vars) + + for k, expected in index_vars.items(): + assert_identical(actual[k], expected) + assert actual[k].dtype == expected.dtype + if k != "x": + assert actual[k].dtype == index.level_coords_dtype[k] + + def test_sel(self) -> None: index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) + # test tuples inside slice are considered as scalar indexer values - assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) + actual = index.sel({"x": slice(("a", 1), ("b", 2))}) + expected_dim_indexers = {"x": slice(0, 4)} + assert actual.dim_indexers == expected_dim_indexers with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): - index.query({"one": 0, "x": "a"}) + index.sel({"one": 0, "x": "a"}) with pytest.raises(ValueError, match=r"invalid multi-index level names"): - index.query({"x": {"three": 0}}) + index.sel({"x": {"three": 0}}) with pytest.raises(IndexError): - index.query({"x": (slice(None), 1, "no_level")}) + index.sel({"x": (slice(None), 1, "no_level")}) + + def test_join(self): + midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two")) + level_coords_dtype = {"one": " None: + level_coords_dtype = {"one": " None: + level_coords_dtype = {"one": "U<1", "two": np.int32} + expected = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), + "x", + level_coords_dtype=level_coords_dtype, + ) + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + assert actual.level_coords_dtype == expected.level_coords_dtype + + +class TestIndexes: + @pytest.fixture + def unique_indexes(self) -> List[PandasIndex]: + x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x") + y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y") + z_pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["one", "two"] + ) + z_midx = PandasMultiIndex(z_pd_midx, "z") + + return [x_idx, y_idx, z_midx] + + @pytest.fixture + def indexes(self, unique_indexes) -> Indexes[Index]: + x_idx, y_idx, z_midx = unique_indexes + indexes: Dict[Any, Index] = { + "x": x_idx, + "y": y_idx, + "z": z_midx, + "one": z_midx, + "two": z_midx, + } + variables: Dict[Any, Variable] = {} + for idx in unique_indexes: + variables.update(idx.create_variables()) + + return Indexes(indexes, variables) + + def test_interface(self, unique_indexes, indexes) -> None: + x_idx = unique_indexes[0] + assert list(indexes) == ["x", "y", "z", "one", "two"] + assert len(indexes) == 5 + assert "x" in indexes + assert indexes["x"] is x_idx + + def test_variables(self, indexes) -> None: + assert tuple(indexes.variables) == ("x", "y", "z", "one", "two") + + def test_dims(self, indexes) -> None: + assert indexes.dims == {"x": 3, "y": 3, "z": 4} + + def test_get_unique(self, unique_indexes, indexes) -> None: + assert indexes.get_unique() == unique_indexes + + def test_is_multi(self, indexes) -> None: + assert indexes.is_multi("one") is True + assert indexes.is_multi("x") is False + + def test_get_all_coords(self, indexes) -> None: + expected = { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + } + assert indexes.get_all_coords("one") == expected + + with pytest.raises(ValueError, match="errors must be.*"): + indexes.get_all_coords("x", errors="invalid") + + with pytest.raises(ValueError, match="no index found.*"): + indexes.get_all_coords("no_coord") + + assert indexes.get_all_coords("no_coord", errors="ignore") == {} + + def test_get_all_dims(self, indexes) -> None: + expected = {"z": 4} + assert indexes.get_all_dims("one") == expected + + def test_group_by_index(self, unique_indexes, indexes): + expected = [ + (unique_indexes[0], {"x": indexes.variables["x"]}), + (unique_indexes[1], {"y": indexes.variables["y"]}), + ( + unique_indexes[2], + { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + }, + ), + ] + + assert indexes.group_by_index() == expected + + def test_to_pandas_indexes(self, indexes) -> None: + pd_indexes = indexes.to_pandas_indexes() + assert isinstance(pd_indexes, Indexes) + assert all([isinstance(idx, pd.Index) for idx in pd_indexes.values()]) + assert indexes.variables == pd_indexes.variables + + def test_copy_indexes(self, indexes) -> None: + copied, index_vars = indexes.copy_indexes() + + assert copied.keys() == indexes.keys() + for new, original in zip(copied.values(), indexes.values()): + assert new.equals(original) + # check unique index objects preserved + assert copied["z"] is copied["one"] is copied["two"] + + assert index_vars.keys() == indexes.variables.keys() + for new, original in zip(index_vars.values(), indexes.variables.values()): + assert_identical(new, original) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 533f4a0cd62..0b40bd18223 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1,4 +1,5 @@ import itertools +from typing import Any import numpy as np import pandas as pd @@ -6,6 +7,8 @@ from xarray import DataArray, Dataset, Variable from xarray.core import indexing, nputils +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.types import T_Xarray from . import IndexerMaker, ReturnItem, assert_array_equal @@ -62,31 +65,74 @@ def test_group_indexers_by_index(self) -> None: ) data.coords["y2"] = ("y", [2.0, 3.0]) - indexes, grouped_indexers = indexing.group_indexers_by_index( - data, {"z": 0, "one": "a", "two": 1, "y": 0} + grouped_indexers = indexing.group_indexers_by_index( + data, {"z": 0, "one": "a", "two": 1, "y": 0}, {} ) - assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]} - assert grouped_indexers == { - "x": {"one": "a", "two": 1}, - "y": {"y": 0}, - None: {"z": 0}, - } - - with pytest.raises(KeyError, match=r"no index found for coordinate y2"): - indexing.group_indexers_by_index(data, {"y2": 2.0}) - with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): - indexing.group_indexers_by_index(data, {"w": "a"}) + + for idx, indexers in grouped_indexers: + if idx is None: + assert indexers == {"z": 0} + elif idx.equals(data.xindexes["x"]): + assert indexers == {"one": "a", "two": 1} + elif idx.equals(data.xindexes["y"]): + assert indexers == {"y": 0} + assert len(grouped_indexers) == 3 + + with pytest.raises(KeyError, match=r"no index found for coordinate 'y2'"): + indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) + with pytest.raises( + KeyError, match=r"'w' is not a valid dimension or coordinate" + ): + indexing.group_indexers_by_index(data, {"w": "a"}, {}) with pytest.raises(ValueError, match=r"cannot supply.*"): - indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") + indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) + + def test_map_index_queries(self) -> None: + def create_sel_results( + x_indexer, + x_index, + other_vars, + drop_coords, + drop_indexes, + rename_dims, + ): + dim_indexers = {"x": x_indexer} + index_vars = x_index.create_variables() + indexes = {k: x_index for k in index_vars} + variables = {} + variables.update(index_vars) + variables.update(other_vars) + + return indexing.IndexSelResult( + dim_indexers=dim_indexers, + indexes=indexes, + variables=variables, + drop_coords=drop_coords, + drop_indexes=drop_indexes, + rename_dims=rename_dims, + ) + + def test_indexer( + data: T_Xarray, + x: Any, + expected: indexing.IndexSelResult, + ) -> None: + results = indexing.map_index_queries(data, {"x": x}) + + assert results.dim_indexers.keys() == expected.dim_indexers.keys() + assert_array_equal(results.dim_indexers["x"], expected.dim_indexers["x"]) - def test_remap_label_indexers(self) -> None: - def test_indexer(data, x, expected_pos, expected_idx=None) -> None: - pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x}) - idx, _ = new_idx_vars.get("x", (None, None)) - if idx is not None: - idx = idx.to_pandas_index() - assert_array_equal(pos.get("x"), expected_pos) - assert_array_equal(idx, expected_idx) + assert results.indexes.keys() == expected.indexes.keys() + for k in results.indexes: + assert results.indexes[k].equals(expected.indexes[k]) + + assert results.variables.keys() == expected.variables.keys() + for k in results.variables: + assert_array_equal(results.variables[k], expected.variables[k]) + + assert set(results.drop_coords) == set(expected.drop_coords) + assert set(results.drop_indexes) == set(expected.drop_indexes) + assert results.rename_dims == expected.rename_dims data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -94,50 +140,96 @@ def test_indexer(data, x, expected_pos, expected_idx=None) -> None: ) mdata = DataArray(range(8), [("x", mindex)]) - test_indexer(data, 1, 0) - test_indexer(data, np.int32(1), 0) - test_indexer(data, Variable([], 1), 0) - test_indexer(mdata, ("a", 1, -1), 0) - test_indexer( - mdata, - ("a", 1), + test_indexer(data, 1, indexing.IndexSelResult({"x": 0})) + test_indexer(data, np.int32(1), indexing.IndexSelResult({"x": 0})) + test_indexer(data, Variable([], 1), indexing.IndexSelResult({"x": 0})) + test_indexer(mdata, ("a", 1, -1), indexing.IndexSelResult({"x": 0})) + + expected = create_sel_results( [True, True, False, False, False, False, False, False], - [-1, -2], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, ) - test_indexer( - mdata, - "a", + test_indexer(mdata, ("a", 1), expected) + + expected = create_sel_results( slice(0, 4, None), - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) - test_indexer( - mdata, - ("a",), + test_indexer(mdata, "a", expected) + + expected = create_sel_results( [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) - test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7]) - test_indexer(mdata, slice("a", "b"), slice(0, 8, None)) - test_indexer(mdata, slice(("a", 1), ("b", 1)), slice(0, 6, None)) - test_indexer(mdata, {"one": "a", "two": 1, "three": -1}, 0) + test_indexer(mdata, ("a",), expected) + test_indexer( - mdata, - {"one": "a", "two": 1}, - [True, True, False, False, False, False, False, False], - [-1, -2], + mdata, [("a", 1, -1), ("b", 2, -2)], indexing.IndexSelResult({"x": [0, 7]}) + ) + test_indexer( + mdata, slice("a", "b"), indexing.IndexSelResult({"x": slice(0, 8, None)}) ) test_indexer( mdata, - {"one": "a", "three": -1}, - [True, False, True, False, False, False, False, False], - [1, 2], + slice(("a", 1), ("b", 1)), + indexing.IndexSelResult({"x": slice(0, 6, None)}), ) test_indexer( mdata, - {"one": "a"}, + {"one": "a", "two": 1, "three": -1}, + indexing.IndexSelResult({"x": 0}), + ) + + expected = create_sel_results( + [True, True, False, False, False, False, False, False], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, + ) + test_indexer(mdata, {"one": "a", "two": 1}, expected) + + expected = create_sel_results( + [True, False, True, False, False, False, False, False], + PandasIndex(pd.Index([1, 2]), "two"), + {"one": Variable((), "a"), "three": Variable((), -1)}, + ["x"], + ["one", "three"], + {"x": "two"}, + ) + test_indexer(mdata, {"one": "a", "three": -1}, expected) + + expected = create_sel_results( [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) + test_indexer(mdata, {"one": "a"}, expected) def test_read_only_view(self) -> None: diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 555a29b1952..6dca04ed069 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -242,7 +242,7 @@ def test_merge_error(self): def test_merge_alignment_error(self): ds = xr.Dataset(coords={"x": [1, 2]}) other = xr.Dataset(coords={"x": [2, 3]}) - with pytest.raises(ValueError, match=r"indexes .* not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.merge([ds, other], join="exact") def test_merge_wrong_input_error(self): diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index a083c50c3d1..bc3cc367c0e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -147,10 +147,10 @@ def strip_units(obj): new_obj = xr.Dataset(data_vars=data_vars, coords=coords) elif isinstance(obj, xr.DataArray): - data = array_strip_units(obj.data) + data = array_strip_units(obj.variable._data) coords = { strip_units(name): ( - (value.dims, array_strip_units(value.data)) + (value.dims, array_strip_units(value.variable._data)) if isinstance(value.data, Quantity) else value # to preserve multiindexes ) @@ -198,8 +198,7 @@ def attach_units(obj, units): name: ( (value.dims, array_attach_units(value.data, units.get(name) or 1)) if name in units - # to preserve multiindexes - else value + else (value.dims, value.data) ) for name, value in obj.coords.items() } @@ -3610,7 +3609,10 @@ def test_stacking_stacked(self, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_identical(expected, actual) + if func.name == "reset_index": + assert_identical(expected, actual, check_default_indexes=False) + else: + assert_identical(expected, actual) @pytest.mark.skip(reason="indexes don't support units") def test_to_unstacked_dataset(self, dtype): @@ -4719,7 +4721,10 @@ def test_stacking_stacked(self, variant, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "reset_index": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.xfail( reason="stacked dimension's labels have to be hashable, but is a numpy.array" @@ -5503,7 +5508,10 @@ def test_content_manipulation(self, func, variant, dtype): actual = func(ds) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "rename_dims": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ce796e9de49..0a720b23d3b 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -89,31 +89,6 @@ def test_safe_cast_to_index_datetime_datetime(): assert isinstance(actual, pd.Index) -def test_multiindex_from_product_levels(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 3, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 3, 2]) - - other = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) - np.testing.assert_array_equal(result.values, other.values) - - -def test_multiindex_from_product_levels_non_unique(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 1, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 2]) - - class TestArrayEquiv: def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional