Skip to content

Commit

Permalink
Switch to T_DataArray in .coords (#7285)
Browse files Browse the repository at this point in the history
* Switch to T_DataArray in .coords

* Update coordinates.py

* Update coordinates.py

* mypy understands the type from items better apparanetly

* Update coordinates.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve DataArrayCoords generic type

* fix import

* Test adding __class_getitem__

* Update coordinates.py

* Test adding a _data slot.

* Adding class_getitem seems to work.

* test mypy on 3.8

* Update ci-additional.yaml

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michael Niklas <mick.niklas@gmail.com>
  • Loading branch information
3 people authored Nov 22, 2022
1 parent 5344ccb commit ff6793d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
31 changes: 22 additions & 9 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast
from typing import TYPE_CHECKING, Any, Hashable, Iterator, List, Mapping, Sequence

import numpy as np
import pandas as pd
Expand All @@ -14,18 +14,27 @@
from .variable import Variable, calculate_dimensions

if TYPE_CHECKING:
from .common import DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataArray

# Used as the key corresponding to a DataArray's variable when converting
# arbitrary DataArray objects to datasets
_THIS_ARRAY = ReprObject("<this-array>")

# TODO: Remove when min python version >= 3.9:
GenericAlias = type(List[int])

class Coordinates(Mapping[Hashable, "DataArray"]):
__slots__ = ()

def __getitem__(self, key: Hashable) -> DataArray:
class Coordinates(Mapping[Hashable, "T_DataArray"]):
_data: DataWithCoords
__slots__ = ("_data",)

# TODO: Remove when min python version >= 3.9:
__class_getitem__ = classmethod(GenericAlias)

def __getitem__(self, key: Hashable) -> T_DataArray:
raise NotImplementedError()

def __setitem__(self, key: Hashable, value: Any) -> None:
Expand Down Expand Up @@ -238,6 +247,8 @@ class DatasetCoordinates(Coordinates):
objects.
"""

_data: Dataset

__slots__ = ("_data",)

def __init__(self, dataset: Dataset):
Expand Down Expand Up @@ -278,7 +289,7 @@ def variables(self) -> Mapping[Hashable, Variable]:
def __getitem__(self, key: Hashable) -> DataArray:
if key in self._data.data_vars:
raise KeyError(key)
return cast("DataArray", self._data[key])
return self._data[key]

def to_dataset(self) -> Dataset:
"""Convert these coordinates into a new Dataset"""
Expand Down Expand Up @@ -334,16 +345,18 @@ def _ipython_key_completions_(self):
]


class DataArrayCoordinates(Coordinates):
class DataArrayCoordinates(Coordinates["T_DataArray"]):
"""Dictionary like container for DataArray coordinates.
Essentially a dict with keys given by the array's
dimensions and the values given by corresponding DataArray objects.
"""

_data: T_DataArray

__slots__ = ("_data",)

def __init__(self, dataarray: DataArray):
def __init__(self, dataarray: T_DataArray) -> None:
self._data = dataarray

@property
Expand All @@ -366,7 +379,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]:
def _names(self) -> set[Hashable]:
return set(self._data._coords)

def __getitem__(self, key: Hashable) -> DataArray:
def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
Expand Down Expand Up @@ -452,7 +465,7 @@ def drop_coords(


def assert_coordinate_consistent(
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3993,8 +3993,8 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
"""
d = self.variable.to_dict(data=data)
d.update({"coords": {}, "name": self.name})
for k in self.coords:
d["coords"][k] = self.coords[k].variable.to_dict(data=data)
for k, coord in self.coords.items():
d["coords"][k] = coord.variable.to_dict(data=data)
if encoding:
d["encoding"] = dict(self.encoding)
return d
Expand Down

0 comments on commit ff6793d

Please sign in to comment.