Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to T_DataArray in .coords #7285

Merged
merged 17 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast
from typing import TYPE_CHECKING, Any, Hashable, Iterator, List, Mapping, Sequence

import numpy as np
import pandas as pd
Expand All @@ -14,18 +14,27 @@
from .variable import Variable, calculate_dimensions

if TYPE_CHECKING:
from .common import DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataArray

# Used as the key corresponding to a DataArray's variable when converting
# arbitrary DataArray objects to datasets
_THIS_ARRAY = ReprObject("<this-array>")

# TODO: Remove when min python version >= 3.9:
GenericAlias = type(List[int])

class Coordinates(Mapping[Hashable, "DataArray"]):
__slots__ = ()

def __getitem__(self, key: Hashable) -> DataArray:
class Coordinates(Mapping[Hashable, "T_DataArray"]):
_data: DataWithCoords
__slots__ = ("_data",)

# TODO: Remove when min python version >= 3.9:
__class_getitem__ = classmethod(GenericAlias)

def __getitem__(self, key: Hashable) -> T_DataArray:
raise NotImplementedError()

def __setitem__(self, key: Hashable, value: Any) -> None:
Expand Down Expand Up @@ -238,6 +247,8 @@ class DatasetCoordinates(Coordinates):
objects.
"""

_data: Dataset

__slots__ = ("_data",)

def __init__(self, dataset: Dataset):
Expand Down Expand Up @@ -278,7 +289,7 @@ def variables(self) -> Mapping[Hashable, Variable]:
def __getitem__(self, key: Hashable) -> DataArray:
if key in self._data.data_vars:
raise KeyError(key)
return cast("DataArray", self._data[key])
return self._data[key]

def to_dataset(self) -> Dataset:
"""Convert these coordinates into a new Dataset"""
Expand Down Expand Up @@ -334,16 +345,18 @@ def _ipython_key_completions_(self):
]


class DataArrayCoordinates(Coordinates):
class DataArrayCoordinates(Coordinates["T_DataArray"]):
"""Dictionary like container for DataArray coordinates.

Essentially a dict with keys given by the array's
dimensions and the values given by corresponding DataArray objects.
"""

_data: T_DataArray

__slots__ = ("_data",)

def __init__(self, dataarray: DataArray):
def __init__(self, dataarray: T_DataArray) -> None:
self._data = dataarray

@property
Expand All @@ -366,7 +379,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]:
def _names(self) -> set[Hashable]:
return set(self._data._coords)

def __getitem__(self, key: Hashable) -> DataArray:
def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)
Comment on lines +382 to 383
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xarray/core/coordinates.py:372: error: A function returning TypeVar should receive at least one argument containing the same Typevar  [type-var]
xarray/core/coordinates.py:372: note: Consider using the upper bound "DataArray" instead

I don't really understand this error: why should the method need at least one argument containing the same TypeVar? The TypeVar is already stored in self.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I meant by making Coordinates a generic class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm not following. Feel free to push to this PR or link me to an example you think is similar.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, never mind. This does not seem to work in Python 3.8...
Does anyone have an idea how to solve this? Haha.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work, though? What does reveal_type(da.coords) give?
Or reveal_type(da.coords["x"])?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get this on main:

da = xr.DataArray()
reveal_type(da)  # note: Revealed type is "Any"

Do you as well?

Copy link
Collaborator

@headtr1ck headtr1ck Nov 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I get note: Revealed type is "xarray.core.dataarray.DataArray" as expected.
(python 3.9.13 and mypy 0.990)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I use:

mypy                      0.982           py310h8d17308_0    conda-forge
mypy_extensions           0.4.3           py310h5588dad_5    conda-forge
...
python                    3.10.6          h9a09f29_0_cpython    conda-forge


def _update_coords(
Expand Down Expand Up @@ -452,7 +465,7 @@ def drop_coords(


def assert_coordinate_consistent(
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3993,8 +3993,8 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
"""
d = self.variable.to_dict(data=data)
d.update({"coords": {}, "name": self.name})
for k in self.coords:
d["coords"][k] = self.coords[k].variable.to_dict(data=data)
for k, coord in self.coords.items():
d["coords"][k] = coord.variable.to_dict(data=data)
if encoding:
d["encoding"] = dict(self.encoding)
return d
Expand Down