From ff6793d975ef4d1d8d5d32b8ad6f4f44e02dda9b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 22 Nov 2022 18:02:08 +0100 Subject: [PATCH] Switch to T_DataArray in .coords (#7285) * Switch to T_DataArray in .coords * Update coordinates.py * Update coordinates.py * mypy understands the type from items better apparanetly * Update coordinates.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve DataArrayCoords generic type * fix import * Test adding __class_getitem__ * Update coordinates.py * Test adding a _data slot. * Adding class_getitem seems to work. * test mypy on 3.8 * Update ci-additional.yaml Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- xarray/core/coordinates.py | 31 ++++++++++++++++++++++--------- xarray/core/dataarray.py | 4 ++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 0f35f7e4c9a..52110d5bfee 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -2,7 +2,7 @@ import warnings from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Any, Hashable, Iterator, List, Mapping, Sequence import numpy as np import pandas as pd @@ -14,18 +14,27 @@ from .variable import Variable, calculate_dimensions if TYPE_CHECKING: + from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset + from .types import T_DataArray # Used as the key corresponding to a DataArray's variable when converting # arbitrary DataArray objects to datasets _THIS_ARRAY = ReprObject("") +# TODO: Remove when min python version >= 3.9: +GenericAlias = type(List[int]) -class Coordinates(Mapping[Hashable, "DataArray"]): - __slots__ = () - def __getitem__(self, key: Hashable) -> DataArray: +class Coordinates(Mapping[Hashable, "T_DataArray"]): + _data: DataWithCoords + __slots__ = ("_data",) + + # TODO: Remove when min python version >= 3.9: + __class_getitem__ = classmethod(GenericAlias) + + def __getitem__(self, key: Hashable) -> T_DataArray: raise NotImplementedError() def __setitem__(self, key: Hashable, value: Any) -> None: @@ -238,6 +247,8 @@ class DatasetCoordinates(Coordinates): objects. """ + _data: Dataset + __slots__ = ("_data",) def __init__(self, dataset: Dataset): @@ -278,7 +289,7 @@ def variables(self) -> Mapping[Hashable, Variable]: def __getitem__(self, key: Hashable) -> DataArray: if key in self._data.data_vars: raise KeyError(key) - return cast("DataArray", self._data[key]) + return self._data[key] def to_dataset(self) -> Dataset: """Convert these coordinates into a new Dataset""" @@ -334,16 +345,18 @@ def _ipython_key_completions_(self): ] -class DataArrayCoordinates(Coordinates): +class DataArrayCoordinates(Coordinates["T_DataArray"]): """Dictionary like container for DataArray coordinates. Essentially a dict with keys given by the array's dimensions and the values given by corresponding DataArray objects. """ + _data: T_DataArray + __slots__ = ("_data",) - def __init__(self, dataarray: DataArray): + def __init__(self, dataarray: T_DataArray) -> None: self._data = dataarray @property @@ -366,7 +379,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: def _names(self) -> set[Hashable]: return set(self._data._coords) - def __getitem__(self, key: Hashable) -> DataArray: + def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( @@ -452,7 +465,7 @@ def drop_coords( def assert_coordinate_consistent( - obj: DataArray | Dataset, coords: Mapping[Any, Variable] + obj: T_DataArray | Dataset, coords: Mapping[Any, Variable] ) -> None: """Make sure the dimension coordinate of obj is consistent with coords. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b4954df014f..ff55028ff82 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3993,8 +3993,8 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: """ d = self.variable.to_dict(data=data) d.update({"coords": {}, "name": self.name}) - for k in self.coords: - d["coords"][k] = self.coords[k].variable.to_dict(data=data) + for k, coord in self.coords.items(): + d["coords"][k] = coord.variable.to_dict(data=data) if encoding: d["encoding"] = dict(self.encoding) return d