Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible index variables #8124

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Indexes,
PandasIndex,
PandasMultiIndex,
create_index_variables,
indexes_all_equal,
safe_cast_to_index,
)
Expand Down Expand Up @@ -425,7 +426,7 @@ def align_indexes(self) -> None:
elif self.join == "right":
joined_index_vars = matching_index_vars[-1]
else:
joined_index_vars = joined_index.create_variables()
joined_index_vars = create_index_variables(joined_index)
else:
joined_index = matching_indexes[0]
joined_index_vars = matching_index_vars[0]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from xarray.core import dtypes, utils
from xarray.core.alignment import align, reindex_variables
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
from xarray.core.indexes import Index, PandasIndex, create_index_variables
from xarray.core.merge import (
_VALID_COMPAT,
collect_variables_and_indexes,
Expand Down Expand Up @@ -619,7 +619,7 @@ def get_indexes(name):
# index created from a scalar coordinate
idx_vars = {name: datasets[0][name].variable}
result_indexes.update({k: combined_idx for k in idx_vars})
combined_idx_vars = combined_idx.create_variables(idx_vars)
combined_idx_vars = create_index_variables(combined_idx, idx_vars)
for k, v in combined_idx_vars.items():
v.attrs = merge_attrs(
[ds.variables[k].attrs for ds in datasets],
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(
else:
variables = {}
for name, data in coords.items():
var = as_variable(data, name=name)
var = as_variable(data, name=name, auto_convert=False)
if var.dims == (name,) and indexes is None:
index, index_vars = create_default_index_implicit(var, list(coords))
default_indexes.update({k: index for k in index_vars})
Expand Down Expand Up @@ -930,9 +930,12 @@ def create_coords_with_default_indexes(
if isinstance(obj, DataArray):
dataarray_coords.append(obj.coords)

variable = as_variable(obj, name=name)
variable = as_variable(obj, name=name, auto_convert=False)

if variable.dims == (name,):
# still needed to convert to IndexVariable first due to some
# pandas multi-index edge cases.
variable = variable.to_index_variable()
idx, idx_vars = create_default_index_implicit(variable, all_variables)
indexes.update({k: idx for k in idx_vars})
variables.update(idx_vars)
Expand Down
10 changes: 7 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def _infer_coords_and_dims(
dims = list(coords.keys())
else:
for n, (dim, coord) in enumerate(zip(dims, coords)):
coord = as_variable(coord, name=dims[n]).to_index_variable()
coord = as_variable(
coord, name=dims[n], auto_convert=False
).to_index_variable()
dims[n] = coord.name
dims = tuple(dims)
elif len(dims) != len(shape):
Expand All @@ -183,10 +185,12 @@ def _infer_coords_and_dims(
new_coords = {}
if utils.is_dict_like(coords):
for k, v in coords.items():
new_coords[k] = as_variable(v, name=k)
new_coords[k] = as_variable(v, name=k, auto_convert=False)
if new_coords[k].dims == (k,):
new_coords[k] = new_coords[k].to_index_variable()
elif coords is not None:
for dim, coord in zip(dims, coords):
var = as_variable(coord, name=dim)
var = as_variable(coord, name=dim, auto_convert=False)
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

Expand Down
14 changes: 8 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
PandasMultiIndex,
assert_no_index_corrupted,
create_default_index_implicit,
create_index_variables,
filter_indexes_from_coords,
isel_indexes,
remove_unused_levels_categories,
Expand Down Expand Up @@ -4106,11 +4107,12 @@ def _rename_indexes(
new_index = index.rename(name_dict, dims_dict)
new_coord_names = [name_dict.get(k, k) for k in coord_names]
indexes.update({k: new_index for k in new_coord_names})
new_index_vars = new_index.create_variables(
new_index_vars = create_index_variables(
new_index,
{
new: self._variables[old]
for old, new in zip(coord_names, new_coord_names)
}
},
)
variables.update(new_index_vars)

Expand Down Expand Up @@ -4940,7 +4942,7 @@ def set_xindex(

index = index_cls.from_variables(coord_vars, options=options)

new_coord_vars = index.create_variables(coord_vars)
new_coord_vars = create_index_variables(index, coord_vars)

# special case for setting a pandas multi-index from level coordinates
# TODO: remove it once we deprecate pandas multi-index dimension (tuple
Expand Down Expand Up @@ -5134,7 +5136,7 @@ def _stack_once(
idx = index_cls.stack(product_vars, new_dim)
new_indexes[new_dim] = idx
new_indexes.update({k: idx for k in product_vars})
idx_vars = idx.create_variables(product_vars)
idx_vars = create_index_variables(idx, product_vars)
# keep consistent multi-index coordinate order
for k in idx_vars:
new_variables.pop(k, None)
Expand Down Expand Up @@ -5326,7 +5328,7 @@ def _unstack_once(
indexes.update(new_indexes)

for name, idx in new_indexes.items():
variables.update(idx.create_variables(index_vars))
variables.update(create_index_variables(idx, index_vars))

for name, var in self.variables.items():
if name not in index_vars:
Expand Down Expand Up @@ -5367,7 +5369,7 @@ def _unstack_full_reindex(

new_index_variables = {}
for name, idx in new_indexes.items():
new_index_variables.update(idx.create_variables(index_vars))
new_index_variables.update(create_index_variables(idx, index_vars))

new_dim_sizes = {k: v.size for k, v in new_index_variables.items()}
variables.update(new_index_variables)
Expand Down
36 changes: 33 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from xarray.core import formatting, nputils, utils
from xarray.core.indexing import (
IndexedCoordinateArray,
IndexSelResult,
PandasIndexingAdapter,
PandasMultiIndexingAdapter,
Expand Down Expand Up @@ -1332,6 +1333,29 @@ def rename(self, name_dict, dims_dict):
)


def create_index_variables(
    index: Index, variables: Mapping[Any, Variable] | None = None
) -> IndexVars:
    """Build the coordinate variables of ``index`` and protect their data.

    The variables are obtained from ``index.create_variables(variables)``.
    Each returned variable that is not an ``IndexVariable`` and whose
    ``_data`` is not already an ``IndexedCoordinateArray`` gets its ``_data``
    replaced, in place, by an ``IndexedCoordinateArray`` wrapping the
    previous data.

    Parameters
    ----------
    index : Index
        Index whose ``create_variables`` method is called.
    variables : mapping, optional
        Forwarded unchanged to ``index.create_variables``.

    Returns
    -------
    IndexVars
        The mapping returned by ``index.create_variables``, after the
        wrapping described above has been applied to its values.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from xarray.core.variable import IndexVariable

    created = index.create_variables(variables)

    for coord_var in created.values():
        # IndexVariable already has safety guards that prevent updating its
        # values (a special case for PandasIndex that will likely be removed,
        # eventually), so it is left untouched.
        if isinstance(coord_var, IndexVariable):
            continue
        # Already wrapped: do not stack a second wrapper on top.
        if isinstance(coord_var._data, IndexedCoordinateArray):
            continue
        coord_var._data = IndexedCoordinateArray(coord_var._data)

    return created


def create_default_index_implicit(
dim_variable: Variable,
all_variables: Mapping | Iterable[Hashable] | None = None,
Expand All @@ -1349,7 +1373,13 @@ def create_default_index_implicit(
all_variables = {k: None for k in all_variables}

name = dim_variable.dims[0]
array = getattr(dim_variable._data, "array", None)
data = dim_variable._data
if isinstance(data, PandasIndexingAdapter):
array = data.array
elif isinstance(data, IndexedCoordinateArray):
array = getattr(data.array, "array", None)
else:
array = None
index: PandasIndex

if isinstance(array, pd.MultiIndex):
Expand Down Expand Up @@ -1631,7 +1661,7 @@ def copy_indexes(
convert_new_idx = False

new_idx = idx._copy(deep=deep, memo=memo)
idx_vars = idx.create_variables(coords)
idx_vars = create_index_variables(idx, coords)

if convert_new_idx:
new_idx = cast(PandasIndex, new_idx).index
Expand Down Expand Up @@ -1779,7 +1809,7 @@ def _apply_indexes(
new_index = getattr(index, func)(index_args)
if new_index is not None:
new_indexes.update({k: new_index for k in index_vars})
new_index_vars = new_index.create_variables(index_vars)
new_index_vars = create_index_variables(new_index, index_vars)
new_index_variables.update(new_index_vars)
else:
for k in index_vars:
Expand Down
29 changes: 29 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,35 @@ def transpose(self, order):
return self.array.transpose(order)


class IndexedCoordinateArray(ExplicitlyIndexedNDArrayMixin):
    """Wrapper around the data of an Xarray indexed coordinate that rejects
    item assignment, so that the coordinate keeps synced with its index.

    Reads are delegated to the wrapped array (stored in ``array`` after being
    passed through ``as_indexable``). Indexing returns a new instance of the
    same wrapper type; ``__setitem__`` always raises ``TypeError``.
    """

    __slots__ = ("array",)

    def __init__(self, array):
        # Store the indexable form of the input rather than the raw object.
        self.array = as_indexable(array)

    def get_duck_array(self):
        # Pure delegation to the wrapped array.
        wrapped = self.array
        return wrapped.get_duck_array()

    def __getitem__(self, key):
        # Re-wrap the selection so the result is guarded by this class too.
        selected = self.array[key]
        return type(self)(_wrap_numpy_scalars(selected))

    def transpose(self, order):
        # NOTE(review): the wrapped array's result is returned as-is, i.e.
        # not re-wrapped in this class (unlike __getitem__) -- confirm this
        # is intended.
        return self.array.transpose(order)

    def __setitem__(self, key, value):
        # Unconditionally refuse in-place updates.
        message = (
            "cannot modify the values of an indexed coordinate in-place "
            "as it may corrupt its index. "
            "Please use DataArray.assign_coords, Dataset.assign_coords "
            "or Dataset.assign as appropriate."
        )
        raise TypeError(message)


class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def append_all(variables, indexes):
indexes_.pop(name, None)
append_all(coords_, indexes_)

variable = as_variable(variable, name=name)
variable = as_variable(variable, name=name, auto_convert=False)
if name in indexes:
append(name, variable, indexes[name])
elif variable.dims == (name,):
Expand Down
30 changes: 25 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from xarray.core.common import AbstractArray
from xarray.core.indexing import (
BasicIndexer,
IndexedCoordinateArray,
OuterIndexer,
PandasIndexingAdapter,
VectorizedIndexer,
Expand All @@ -45,6 +46,7 @@
decode_numpy_dict_values,
drop_dims_from_indexers,
either_dict_or_kwargs,
emit_user_level_warning,
ensure_us_time_resolution,
infix_dims,
is_duck_array,
Expand Down Expand Up @@ -86,7 +88,7 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def as_variable(obj, name=None) -> Variable | IndexVariable:
def as_variable(obj, name=None, auto_convert=True) -> Variable | IndexVariable:
"""Convert an object into a Variable.

Parameters
Expand All @@ -106,6 +108,9 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
along a dimension of this given name.
- Variables with name matching one of their dimensions are converted
into `IndexVariable` objects.
auto_convert : bool, optional
For internal use only! If True, convert a "dimension" variable into
an IndexVariable object (deprecated).

Returns
-------
Expand Down Expand Up @@ -156,9 +161,15 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
f"explicit list of dimensions: {obj!r}"
)

if name is not None and name in obj.dims and obj.ndim == 1:
# automatically convert the Variable into an Index
obj = obj.to_index_variable()
if auto_convert:
if name is not None and name in obj.dims and obj.ndim == 1:
# automatically convert the Variable into an Index
emit_user_level_warning(
f"variable {name!r} with name matching its dimension will not be "
"automatically converted into an `IndexVariable` object in the future.",
FutureWarning,
)
obj = obj.to_index_variable()

return obj

Expand Down Expand Up @@ -430,6 +441,13 @@ def data(self) -> Any:

@data.setter
def data(self, data):
if isinstance(self._data, IndexedCoordinateArray):
raise ValueError(
"Cannot assign to the .data attribute of an indexed coordinate "
"as it may corrupt its index. "
"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate."
)

data = as_compatible_data(data)
if data.shape != self.shape:
raise ValueError(
Expand Down Expand Up @@ -829,8 +847,10 @@ def _broadcast_indexes_vectorized(self, key):
variable = (
value
if isinstance(value, Variable)
else as_variable(value, name=dim)
else as_variable(value, name=dim, auto_convert=False)
)
if variable.dims == (dim,):
variable = variable.to_index_variable()
if variable.dtype.kind == "b": # boolean indexing case
(variable,) = variable._nonzero()

Expand Down
18 changes: 13 additions & 5 deletions xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
from xarray.core.indexing import IndexedCoordinateArray
from xarray.core.variable import IndexVariable, Variable

__all__ = (
Expand Down Expand Up @@ -272,9 +273,11 @@ def _assert_indexes_invariants_checks(
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
k
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable) or isinstance(v._data, IndexedCoordinateArray)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
assert indexes.keys() == index_vars, (set(indexes), index_vars)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
Expand Down Expand Up @@ -340,9 +343,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

if check_default_indexes:
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

for k, v in da._coords.items():
_assert_variable_invariants(v, k)

Expand Down