diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index c539536a294..35421c7f71e 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -691,12 +691,7 @@ def _update_coords( self._data._variables = variables self._data._coord_names.update(new_coord_names) self._data._dims = dims - - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._indexes.update(indexes) def _drop_coords(self, coord_names): # should drop indexed coordinates only @@ -777,12 +772,7 @@ def _update_coords( "cannot add coordinates with new dimensions to a DataArray" ) self._data._coords = coords - - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._indexes.update(indexes) def _drop_coords(self, coord_names): # should drop indexed coordinates only diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f1a0cb9dc34..28d3b9dcffb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,9 +64,11 @@ PandasIndex, PandasMultiIndex, assert_no_index_corrupted, + chunk_indexes, create_default_index_implicit, filter_indexes_from_coords, isel_indexes, + load_indexes, remove_unused_levels_categories, roll_indexes, ) @@ -816,6 +818,11 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset: -------- dask.compute """ + # apply Index.load, collect new indexes and variables and replace the existing ones + # new index variables may still be lazy: load them here after + indexes, index_variables = load_indexes(self.xindexes, kwargs) + self.coords._update_coords(index_variables, indexes) + # access .data to coerce everything to numpy or dask arrays lazy_data = { k: v._data for k, v in 
self.variables.items() if is_chunked_array(v._data) @@ -2641,21 +2648,36 @@ def chunk( if from_array_kwargs is None: from_array_kwargs = {} - variables = { - k: _maybe_chunk( - k, - v, - chunks, - token, - lock, - name_prefix, - inline_array=inline_array, - chunked_array_type=chunkmanager, - from_array_kwargs=from_array_kwargs.copy(), - ) - for k, v in self.variables.items() - } - return self._replace(variables) + # apply Index.chunk, collect new indexes and variables + indexes, index_variables = chunk_indexes( + self.xindexes, + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs, + ) + + variables = {} + for k, v in self.variables.items(): + if k in index_variables: + variables[k] = index_variables[k] + else: + variables[k] = _maybe_chunk( + k, + v, + chunks, + token, + lock, + name_prefix, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), + ) + + return self._replace(variables=variables, indexes=indexes) def _validate_indexers( self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise" diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b5e396963a1..f3f45b124f5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -4,7 +4,7 @@ import copy from collections import defaultdict from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast import numpy as np import pandas as pd @@ -15,6 +15,7 @@ PandasIndexingAdapter, PandasMultiIndexingAdapter, ) +from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.utils import ( Frozen, emit_user_level_warning, @@ -54,8 +55,6 @@ class Index: corresponding operation on a :py:meth:`Dataset` or :py:meth:`DataArray` either will raise a 
``NotImplementedError`` or will simply drop/pass/copy the index from/to the result. - - Do not use this class directly for creating index objects. """ @classmethod @@ -321,6 +320,55 @@ def equals(self: T_Index, other: T_Index) -> bool: """ raise NotImplementedError() + def load(self, **kwargs) -> Index | None: + """Method called when calling :py:meth:`Dataset.load` or + :py:meth:`Dataset.compute` (or DataArray equivalent methods). + + The default implementation will simply drop the index by returning + ``None``. + + Possible re-implementations in subclasses are: + + - For an index with coordinate data fully in memory like a ``PandasIndex``: + return itself + - For an index with lazy coordinate data (e.g., a dask array): + return an index object of another type like ``PandasIndex`` + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + """ + return None + + def chunk( + self, + chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]], + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + ) -> Index | None: + """Method called when calling :py:meth:`Dataset.chunk` or + :py:meth:`DataArray.chunk`. + + The default implementation will simply drop the index by returning + ``None``. + + Possible re-implementations in subclasses are: + + - For an index with coordinate data fully in memory like a ``PandasIndex``: + return itself (do not chunk) + - For an index with lazy coordinate data (e.g., a dask array): + rebuild the index with an internal lookup structure that is + in sync with the new chunks + + For more details about the parameters, see :py:meth:`Dataset.chunk`. + """ + return None + def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: """Roll this index by an offset along one or more dimensions. 
@@ -821,6 +869,23 @@ def reindex_like( return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + def load(self: T_PandasIndex, **kwargs) -> T_PandasIndex: + # both index and coordinate(s) already loaded in-memory + return self + + def chunk( + self: T_PandasIndex, + chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]], + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + ) -> T_PandasIndex: + # skip chunk + return self + def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: shift = shifts[self.dim] % self.index.shape[0] @@ -1764,19 +1829,46 @@ def check_variables(): return not not_equal -def _apply_indexes( +def _apply_index_method( indexes: Indexes[Index], - args: Mapping[Any, Any], - func: str, + method_name: str, + dim_args: Mapping | None = None, + kwargs: Mapping | None = None, ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + """Utility function that applies a given Index method to an Indexes + collection and that returns new collections of indexes and coordinate + variables. + + Index method calls and arguments are filtered according to ``dim_args`` if + it is not None. Otherwise, the method is called unconditionally for each + index. + + ``kwargs`` is passed to every call of the index method. 
+ + """ + if kwargs is None: + kwargs = {} + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()} new_index_variables: dict[Hashable, Variable] = {} for index, index_vars in indexes.group_by_index(): - index_dims = {d for var in index_vars.values() for d in var.dims} - index_args = {k: v for k, v in args.items() if k in index_dims} - if index_args: - new_index = getattr(index, func)(index_args) + func = getattr(index, method_name) + + if dim_args is None: + new_index = func(**kwargs) + skip_index = False + else: + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in dim_args.items() if k in index_dims} + if index_args: + new_index = func(index_args, **kwargs) + skip_index = False + else: + new_index = None + skip_index = True + + if not skip_index: if new_index is not None: new_indexes.update({k: new_index for k in index_vars}) new_index_vars = new_index.create_variables(index_vars) @@ -1790,16 +1882,31 @@ def _apply_indexes( def isel_indexes( indexes: Indexes[Index], - indexers: Mapping[Any, Any], + indexers: Mapping, ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - return _apply_indexes(indexes, indexers, "isel") + return _apply_index_method(indexes, "isel", dim_args=indexers) def roll_indexes( indexes: Indexes[Index], shifts: Mapping[Any, int], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - return _apply_indexes(indexes, shifts, "roll") + return _apply_index_method(indexes, "roll", dim_args=shifts) + + +def load_indexes( + indexes: Indexes[Index], + kwargs: Mapping, +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_index_method(indexes, "load", kwargs=kwargs) + + +def chunk_indexes( + indexes: Indexes[Index], + chunks: Mapping, + **kwargs, +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_index_method(indexes, "chunk", dim_args=chunks, kwargs=kwargs) def filter_indexes_from_coords(