From b8aaa5311291f773240cc9412a4f1a519b96191a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:57:42 -0700 Subject: [PATCH 01/11] Add a `.drop_attrs` method (#8258) * Add a `.drop_attrs` method Part of #3891 * Add tests * Add explicit coords test * Use `._replace` for half the method * . * Add a `deep` kwarg (default `True`?) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * api * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas * Update xarray/core/dataset.py Co-authored-by: Michael Niklas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- doc/api.rst | 2 ++ doc/whats-new.rst | 4 +++ xarray/core/dataarray.py | 17 ++++++++++++ xarray/core/dataset.py | 42 +++++++++++++++++++++++++++++ xarray/tests/test_dataarray.py | 5 ++++ xarray/tests/test_dataset.py | 48 ++++++++++++++++++++++++++++++++++ 6 files changed, 118 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index a8f8ea7dd1c..4cf8f374d37 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,6 +111,7 @@ Dataset contents Dataset.drop_duplicates Dataset.drop_dims Dataset.drop_encoding + Dataset.drop_attrs Dataset.set_coords Dataset.reset_coords Dataset.convert_calendar @@ -306,6 +307,7 @@ DataArray contents DataArray.drop_indexes DataArray.drop_duplicates DataArray.drop_encoding + DataArray.drop_attrs DataArray.reset_coords DataArray.copy DataArray.convert_calendar diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e8369dc2f40..6a8e898c93c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,10 @@ New Features By `Martin Raspaud `_. - Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`). By `Justus Magin `_. +- Add :py:meth:`DataArray.drop_attrs` & :py:meth:`Dataset.drop_attrs` methods, + to return an object without ``attrs``. A ``deep`` parameter controls whether + variables' ``attrs`` are also dropped. + By `Maximilian Roos `_. (:pull:`8288`) Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b67f8089eb2..47dc9d13ffc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -7456,3 +7456,20 @@ def to_dask_dataframe( # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the DataArray. + + Parameters + ---------- + deep : bool, default True + Removes attributes from coordinates. + + Returns + ------- + DataArray + """ + return ( + self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset) + ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 50cfc7b0c29..3930b12ef3d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10680,3 +10680,45 @@ def resample( restore_coord_dims=restore_coord_dims, **indexer_kwargs, ) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the Dataset and its variables. + + Parameters + ---------- + deep : bool, default True + Removes attributes from all variables. + + Returns + ------- + Dataset + """ + # Remove attributes from the dataset + self = self._replace(attrs={}) + + if not deep: + return self + + # Remove attributes from each variable in the dataset + for var in self.variables: + # variables don't have a `._replace` method, so we copy and then remove + # attrs. If we added a `._replace` method, we could use that instead. + if var not in self.indexes: + self[var] = self[var].copy() + self[var].attrs = {} + + new_idx_variables = {} + # Not sure this is the most elegant way of doing this, but it works. + # (Should we have a more general "map over all variables, including + # indexes" approach?) + for idx, idx_vars in self.xindexes.group_by_index(): + # copy each coordinate variable of an index and drop their attrs + temp_idx_variables = {k: v.copy() for k, v in idx_vars.items()} + for v in temp_idx_variables.values(): + v.attrs = {} + # re-wrap the index object in new coordinate variables + new_idx_variables.update(idx.create_variables(temp_idx_variables)) + self = self.assign(new_idx_variables) + + return self diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 659c7c168a5..44ef486e5d6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2980,6 +2980,11 @@ def test_assign_attrs(self) -> None: assert_identical(new_actual, expected) assert actual.attrs == {"a": 1, "b": 2} + def test_drop_attrs(self) -> None: + # Mostly tested in test_dataset.py, but adding a very small test here + da = DataArray([], attrs=dict(a=1, b=2)) + assert da.drop_attrs().attrs == {} + @pytest.mark.parametrize( "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] ) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f6829861776..fd511af0dfb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4450,6 +4450,54 @@ def test_assign_attrs(self) -> None: assert_identical(new_actual, expected) assert actual.attrs == dict(a=1, b=2) + def test_drop_attrs(self) -> None: + # Simple example + ds = Dataset().assign_attrs(a=1, b=2) + original = ds.copy() + expected = Dataset() + result = ds.drop_attrs() + assert_identical(result, expected) + + # Doesn't change original + assert_identical(ds, original) + + # Example with variables and coords with attrs, and a multiindex. (arguably + # should have used a canonical dataset with all the features we're should + # support...) + var = Variable("x", [1, 2, 3], attrs=dict(x=1, y=2)) + idx = IndexVariable("y", [1, 2, 3], attrs=dict(c=1, d=2)) + mx = xr.Coordinates.from_pandas_multiindex( + pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z" + ) + ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2) + assert ds.attrs != {} + assert ds["var1"].attrs != {} + assert ds["y"].attrs != {} + assert ds.coords["y"].attrs != {} + + original = ds.copy(deep=True) + result = ds.drop_attrs() + + assert result.attrs == {} + assert result["var1"].attrs == {} + assert result["y"].attrs == {} + assert list(result.data_vars) == list(ds.data_vars) + assert list(result.coords) == list(ds.coords) + + # Doesn't change original + assert_identical(ds, original) + # Specifically test that the attrs on the coords are still there. (The index + # can't currently contain `attrs`, so we can't test those.) + assert ds.coords["y"].attrs != {} + + # Test for deep=False + result_shallow = ds.drop_attrs(deep=False) + assert result_shallow.attrs == {} + assert result_shallow["var1"].attrs != {} + assert result_shallow["y"].attrs != {} + assert list(result.data_vars) == list(ds.data_vars) + assert list(result.coords) == list(ds.coords) + def test_assign_multiindex_level(self) -> None: data = create_test_multiindex() with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): From 42fd51030aa00d81b9b14ab6bdd92cb25a528c04 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 21:56:52 +0200 Subject: [PATCH 02/11] Update _typing.py --- xarray/namedarray/_typing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index b715973814f..1a169a28ff9 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -211,9 +211,7 @@ def __array_namespace__(self) -> ModuleType: ... # NamedArray can most likely use both __array_function__ and __array_namespace__: _arrayfunction_or_api = (_arrayfunction, _arrayapi) -duckarray = Union[ - _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co] -] +duckarray = Union[_arrayfunction[_ShapeType, _DType], _arrayapi[_ShapeType, _DType]] # Corresponds to np.typing.NDArray: DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]] From 78c8d4665516376f8f18d24ba630712a4f252e79 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 21:59:41 +0200 Subject: [PATCH 03/11] Revert "Update _typing.py" This reverts commit 42fd51030aa00d81b9b14ab6bdd92cb25a528c04. --- xarray/namedarray/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 1a169a28ff9..b715973814f 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -211,7 +211,9 @@ def __array_namespace__(self) -> ModuleType: ... # NamedArray can most likely use both __array_function__ and __array_namespace__: _arrayfunction_or_api = (_arrayfunction, _arrayapi) -duckarray = Union[_arrayfunction[_ShapeType, _DType], _arrayapi[_ShapeType, _DType]] +duckarray = Union[ + _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co] +] # Corresponds to np.typing.NDArray: DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]] From f85da8abb017011ab2a0f42cd635f3b859bf228d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:20:40 +0200 Subject: [PATCH 04/11] Test main push --- xarray/namedarray/_typing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index b715973814f..8aa34b6b5af 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -317,3 +317,5 @@ def todense(self) -> np.ndarray[Any, _DType_co]: ... ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] + +# test From 5edd249a94dc592459280b48b16eaede449131c4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:23:43 +0200 Subject: [PATCH 05/11] Revert "Test main push" This reverts commit f85da8abb017011ab2a0f42cd635f3b859bf228d. --- xarray/namedarray/_typing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 8aa34b6b5af..b715973814f 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -317,5 +317,3 @@ def todense(self) -> np.ndarray[Any, _DType_co]: ... ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] - -# test From 0eac740514ba5e08bf5208b430ccc0b23b17c18f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 12 Jul 2024 17:52:32 -0700 Subject: [PATCH 06/11] Allow mypy to run in vscode (#9239) Seems to require adding an exclude for `build` and an `__init__.py` file in the `properties` directory... --- properties/__init__.py | 0 pyproject.toml | 1 + 2 files changed, 1 insertion(+) create mode 100644 properties/__init__.py diff --git a/properties/__init__.py b/properties/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pyproject.toml b/pyproject.toml index 2ada0c1c171..4704751b445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] enable_error_code = "redundant-self" exclude = [ + 'build', 'xarray/util/generate_.*\.py', 'xarray/datatree_/doc/.*\.py', ] From d8b76448e3d20556ed0107dab8f702cd7c9d70f6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 13 Jul 2024 22:35:29 +0200 Subject: [PATCH 07/11] Fix typing for test_plot.py (#9234) * Fix typing for test_plot.py * Update test_plot.py * make sure we actually get ndarrays here, I get it locally at least * Add a minimal test and ignore in real test * Update test_plot.py * Update test_plot.py --- xarray/tests/test_plot.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index fa08e9975ab..578e6bcc18e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -158,9 +158,10 @@ def setup(self) -> Generator: plt.close("all") def pass_in_axis(self, plotmethod, subplot_kw=None) -> None: - fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw) - plotmethod(ax=axs[0]) - assert axs[0].has_data() + fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw, squeeze=False) + ax = axs[0, 0] + plotmethod(ax=ax) + assert ax.has_data() @pytest.mark.slow def imshow_called(self, plotmethod) -> bool: @@ -240,9 +241,9 @@ def test_1d_x_y_kw(self) -> None: xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]] - f, ax = plt.subplots(3, 1) + f, axs = plt.subplots(3, 1, squeeze=False) for aa, (x, y) in enumerate(xy): - da.plot(x=x, y=y, ax=ax.flat[aa]) + da.plot(x=x, y=y, ax=axs.flat[aa]) with pytest.raises(ValueError, match=r"Cannot specify both"): da.plot(x="z", y="z") @@ -1566,7 +1567,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyLabel" in alltxt assert "testvar" not in alltxt # change cbar ax - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, cbar_ax=cax, add_colorbar=True, cbar_kwargs={"label": "MyBar"} ) @@ -1576,7 +1579,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyBar" in alltxt assert "testvar" not in alltxt # note that there are two ways to achieve this - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax} ) @@ -3371,16 +3376,16 @@ def test_plot1d_default_rcparams() -> None: # see overlapping markers: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax) - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual: np.ndarray = mpl.colors.to_rgba_array("w") + expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) # Facetgrids should have the default value as well: fg = ds.plot.scatter(x="A", y="B", col="x", marker="o") ax = fg.axs.ravel()[0] - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual = mpl.colors.to_rgba_array("w") + expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) # scatter should not emit any warnings when using unfilled markers: with assert_no_warnings(): @@ -3390,9 +3395,9 @@ def test_plot1d_default_rcparams() -> None: # Prioritize edgecolor argument over default plot1d values: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k") - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k") - ) + actual = mpl.colors.to_rgba_array("k") + expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) @requires_matplotlib From 9c26ca743a9cf1262e4f95b5097b91ba0adb6103 Mon Sep 17 00:00:00 2001 From: ChrisCleaner <61554538+ChrisCleaner@users.noreply.github.com> Date: Mon, 15 Jul 2024 20:14:09 +0200 Subject: [PATCH 08/11] Added a space to the documentation (#9247) --- doc/user-guide/terminology.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst index 55937310827..140f3dd1faa 100644 --- a/doc/user-guide/terminology.rst +++ b/doc/user-guide/terminology.rst @@ -221,7 +221,7 @@ complete examples, please consult the relevant documentation.* combined_ds lazy - Lazily-evaluated operations do not load data into memory until necessary.Instead of doing calculations + Lazily-evaluated operations do not load data into memory until necessary. Instead of doing calculations right away, xarray lets you plan what calculations you want to do, like finding the average temperature in a dataset.This planning is called "lazy evaluation." Later, when you're ready to see the final result, you tell xarray, "Okay, go ahead and do those calculations now!" From 076c0c2740ea778186817790bd3d626740c2fe04 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Mon, 15 Jul 2024 21:19:55 +0200 Subject: [PATCH 09/11] test push From 7477fd10115ba89b86752ee967c858f6f0fe7f9b Mon Sep 17 00:00:00 2001 From: Mathijs Verhaegh Date: Tue, 16 Jul 2024 08:25:41 +0200 Subject: [PATCH 10/11] Per-variable specification of boolean parameters in open_dataset (#9218) * allow per-variable choice of mask_and_scale in open_dataset * simplify docstring datatype * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * dict -> Mapping in type annotation Co-authored-by: Michael Niklas * use typevar for _item_or_default annotation Otherwise you lose all typing when you use that because it returns Any. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement feature for 4 additional parameters * fix default value inconsistency * add what's new + None annotation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * _item_or_default return type T | None * remove deault default value _item_or_default * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring dtype naming --------- Co-authored-by: Mathijs Verhaegh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 34 ++++++++++++++++++++++------------ xarray/conventions.py | 34 ++++++++++++++++++++++------------ 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6a8e898c93c..4d4291e050e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,9 @@ v2024.06.1 (unreleased) New Features ~~~~~~~~~~~~ +- Allow per-variable specification of ``mask_and_scale``, ``decode_times``, ``decode_timedelta`` + ``use_cftime`` and ``concat_characters`` params in :py:func:`~xarray.open_dataset` (:pull:`9218`). + By `Mathijs Verhaegh `_. - Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). By `Martin Raspaud `_. - Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`). diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 521bdf65e6a..ece60a2b161 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -398,11 +398,11 @@ def open_dataset( chunks: T_Chunks = None, cache: bool | None = None, decode_cf: bool | None = None, - mask_and_scale: bool | None = None, - decode_times: bool | None = None, - decode_timedelta: bool | None = None, - use_cftime: bool | None = None, - concat_characters: bool | None = None, + mask_and_scale: bool | Mapping[str, bool] | None = None, + decode_times: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, @@ -451,25 +451,31 @@ def open_dataset( decode_cf : bool, optional Whether to decode these variables, assuming they were saved according to CF conventions. - mask_and_scale : bool, optional + mask_and_scale : bool or dict-like, optional If True, replace array values equal to `_FillValue` with NA and scale values according to the formula `original_values * scale_factor + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. This keyword may not be supported by all the backends. - decode_times : bool, optional + be replaced by NA. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_times : bool or dict-like, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. This keyword may not be supported by all the backends. - decode_timedelta : bool, optional + decode_timedelta : bool or dict-like, optional If True, decode variables and coordinates with time units in {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} into timedelta objects. If False, leave them encoded as numbers. If None (default), assume the same value of decode_time. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. This keyword may not be supported by all the backends. - use_cftime: bool, optional + use_cftime: bool or dict-like, optional Only relevant if encoded dates come from a standard calendar (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to @@ -478,12 +484,16 @@ def open_dataset( ``cftime.datetime`` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible - raise an error. This keyword may not be supported by all the backends. - concat_characters : bool, optional + raise an error. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + concat_characters : bool or dict-like, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and removed) if they have no corresponding variable and if they are only used as the last dimension of character arrays. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. This keyword may not be supported by all the backends. decode_coords : bool or {"coordinates", "all"}, optional Controls which variables are set as coordinate variables: diff --git a/xarray/conventions.py b/xarray/conventions.py index 6eff45c5b2d..ff1256883ba 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -2,7 +2,7 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np import pandas as pd @@ -384,16 +384,26 @@ def _update_bounds_encoding(variables: T_Variables) -> None: bounds_encoding.setdefault("calendar", encoding["calendar"]) +T = TypeVar("T") + + +def _item_or_default(obj: Mapping[Any, T] | T, key: Hashable, default: T) -> T: + """ + Return item by key if obj is mapping and key is present, else return default value. + """ + return obj.get(key, default) if isinstance(obj, Mapping) else obj + + def decode_cf_variables( variables: T_Variables, attributes: T_Attrs, - concat_characters: bool = True, - mask_and_scale: bool = True, - decode_times: bool = True, + concat_characters: bool | Mapping[str, bool] = True, + mask_and_scale: bool | Mapping[str, bool] = True, + decode_times: bool | Mapping[str, bool] = True, decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, - use_cftime: bool | None = None, - decode_timedelta: bool | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, ) -> tuple[T_Variables, T_Attrs, set[Hashable]]: """ Decode several CF encoded variables. @@ -431,7 +441,7 @@ def stackable(dim: Hashable) -> bool: if k in drop_variables: continue stack_char_dim = ( - concat_characters + _item_or_default(concat_characters, k, True) and v.dtype == "S1" and v.ndim > 0 and stackable(v.dims[-1]) @@ -440,12 +450,12 @@ def stackable(dim: Hashable) -> bool: new_vars[k] = decode_cf_variable( k, v, - concat_characters=concat_characters, - mask_and_scale=mask_and_scale, - decode_times=decode_times, + concat_characters=_item_or_default(concat_characters, k, True), + mask_and_scale=_item_or_default(mask_and_scale, k, True), + decode_times=_item_or_default(decode_times, k, True), stack_char_dim=stack_char_dim, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + use_cftime=_item_or_default(use_cftime, k, None), + decode_timedelta=_item_or_default(decode_timedelta, k, None), ) except Exception as e: raise type(e)(f"Failed to decode variable {k!r}: {e}") from e From 71fce9b0be5a00004a4ba0ff30c3e661f1790cdb Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Wed, 17 Jul 2024 06:00:53 +0200 Subject: [PATCH 11/11] Enable pandas type checking (#9213) * remove pandas from ignore missing imports * add Any to dim arg of concat as placeholder * allow sequence of np.ndarrays as coords in dataArray constructor * fix several typing issues in tests * fix more types * more fixes * more typing... * we are getting there? * who might have guessed it... more typing * continue fixing typing issues * fix some typed_ops * fix last non-typed-ops errors * update typed ops * remove useless DaskArray type in scalar or array type * fix missing import in type_checking * fix import * improve cftime offsets typing * fix classvars * fix some checks * fix a broken test * improve typing of test_concat * fix broken concat * add whats-new --------- Co-authored-by: Maximilian Roos --- doc/whats-new.rst | 2 + pyproject.toml | 1 - xarray/coding/cftime_offsets.py | 155 ++++++++++++---------- xarray/coding/cftimeindex.py | 63 +++++---- xarray/coding/times.py | 37 +++--- xarray/core/_typed_ops.py | 134 +++++++++---------- xarray/core/concat.py | 49 ++++--- xarray/core/dataarray.py | 26 +++- xarray/core/dataset.py | 21 ++- xarray/core/extension_array.py | 12 +- xarray/core/groupers.py | 46 ++++--- xarray/core/indexes.py | 71 +++++----- xarray/core/indexing.py | 13 +- xarray/core/missing.py | 4 +- xarray/core/resample_cftime.py | 31 +++-- xarray/core/types.py | 12 +- xarray/core/utils.py | 45 +++---- xarray/core/variable.py | 4 +- xarray/namedarray/daskmanager.py | 8 +- xarray/tests/test_backends.py | 8 +- xarray/tests/test_cftime_offsets.py | 2 +- xarray/tests/test_cftimeindex.py | 40 +++--- xarray/tests/test_cftimeindex_resample.py | 13 +- xarray/tests/test_coding_times.py | 35 ++--- xarray/tests/test_concat.py | 20 +-- xarray/tests/test_conventions.py | 2 +- xarray/tests/test_dataarray.py | 93 +++++++------ xarray/tests/test_dataset.py | 44 +++--- xarray/tests/test_formatting.py | 4 +- xarray/tests/test_groupby.py | 10 +- xarray/tests/test_indexes.py | 25 ++-- xarray/tests/test_plot.py | 8 +- xarray/tests/test_rolling.py | 22 +-- xarray/tests/test_variable.py | 53 ++++---- xarray/util/generate_ops.py | 23 ++-- 35 files changed, 632 insertions(+), 504 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4d4291e050e..74c7104117a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -82,6 +82,8 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Enable typing checks of pandas (:pull:`9213`). + By `Michael Niklas `_. .. _whats-new.2024.06.0: diff --git a/pyproject.toml b/pyproject.toml index 4704751b445..3eafcda7670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ module = [ "netCDF4.*", "netcdftime.*", "opt_einsum.*", - "pandas.*", "pint.*", "pooch.*", "pyarrow.*", diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index c2712569782..9dbc60ef0f3 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -43,9 +43,10 @@ from __future__ import annotations import re +from collections.abc import Mapping from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, Literal import numpy as np import pandas as pd @@ -74,7 +75,10 @@ if TYPE_CHECKING: - from xarray.core.types import InclusiveOptions, SideOptions + from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias + + +DayOption: TypeAlias = Literal["start", "end"] def _nanosecond_precision_timestamp(*args, **kwargs): @@ -109,9 +113,10 @@ def get_date_type(calendar, use_cftime=True): class BaseCFTimeOffset: _freq: ClassVar[str | None] = None - _day_option: ClassVar[str | None] = None + _day_option: ClassVar[DayOption | None] = None + n: int - def __init__(self, n: int = 1): + def __init__(self, n: int = 1) -> None: if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. " @@ -119,13 +124,15 @@ def __init__(self, n: int = 1): ) self.n = n - def rule_code(self): + def rule_code(self) -> str | None: return self._freq - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, BaseCFTimeOffset): + return NotImplemented return self.n == other.n and self.rule_code() == other.rule_code() - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other def __add__(self, other): @@ -142,12 +149,12 @@ def __sub__(self, other): else: return NotImplemented - def __mul__(self, other): + def __mul__(self, other: int) -> Self: if not isinstance(other, int): return NotImplemented return type(self)(n=other * self.n) - def __neg__(self): + def __neg__(self) -> Self: return self * -1 def __rmul__(self, other): @@ -161,10 +168,10 @@ def __rsub__(self, other): raise TypeError("Cannot subtract cftime offsets of differing types") return -self + other - def __apply__(self): + def __apply__(self, other): return NotImplemented - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" test_date = (self + date) - self @@ -197,22 +204,21 @@ def _get_offset_day(self, other): class Tick(BaseCFTimeOffset): # analogous https://github.com/pandas-dev/pandas/blob/ccb25ab1d24c4fb9691270706a59c8d319750870/pandas/_libs/tslibs/offsets.pyx#L806 - def _next_higher_resolution(self): + def _next_higher_resolution(self) -> Tick: self_type = type(self) - if self_type not in [Day, Hour, Minute, Second, Millisecond]: - raise ValueError("Could not convert to integer offset at any resolution") - if type(self) is Day: + if self_type is Day: return Hour(self.n * 24) - if type(self) is Hour: + if self_type is Hour: return Minute(self.n * 60) - if type(self) is Minute: + if self_type is Minute: return Second(self.n * 60) - if type(self) is Second: + if self_type is Second: return Millisecond(self.n * 1000) - if type(self) is Millisecond: + if self_type is Millisecond: return Microsecond(self.n * 1000) + raise ValueError("Could not convert to integer offset at any resolution") - def __mul__(self, other): + def __mul__(self, other: int | float) -> Tick: if not isinstance(other, (int, float)): return NotImplemented if isinstance(other, float): @@ -227,12 +233,12 @@ def __mul__(self, other): return new_self * other return type(self)(n=other * self.n) - def as_timedelta(self): + def as_timedelta(self) -> timedelta: """All Tick subclasses must implement an as_timedelta method.""" raise NotImplementedError -def _get_day_of_month(other, day_option): +def _get_day_of_month(other, day_option: DayOption) -> int: """Find the day in `other`'s month that satisfies a BaseCFTimeOffset's onOffset policy, as described by the `day_option` argument. @@ -251,14 +257,13 @@ def _get_day_of_month(other, day_option): if day_option == "start": return 1 - elif day_option == "end": + if day_option == "end": return _days_in_month(other) - elif day_option is None: + if day_option is None: # Note: unlike `_shift_month`, _get_day_of_month does not # allow day_option = None raise NotImplementedError() - else: - raise ValueError(day_option) + raise ValueError(day_option) def _days_in_month(date): @@ -293,7 +298,7 @@ def _adjust_n_years(other, n, month, reference_day): return n -def _shift_month(date, months, day_option="start"): +def _shift_month(date, months, day_option: DayOption = "start"): """Shift the date to a month start or end a given number of months away.""" if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") @@ -316,7 +321,9 @@ def _shift_month(date, months, day_option="start"): return date.replace(year=year, month=month, day=day) -def roll_qtrday(other, n, month, day_option, modby=3): +def roll_qtrday( + other, n: int, month: int, day_option: DayOption, modby: int = 3 +) -> int: """Possibly increment or decrement the number of periods to shift based on rollforward/rollbackward conventions. @@ -357,7 +364,7 @@ def roll_qtrday(other, n, month, day_option, modby=3): return n -def _validate_month(month, default_month): +def _validate_month(month: int | None, default_month: int) -> int: result_month = default_month if month is None else month if not isinstance(result_month, int): raise TypeError( @@ -381,7 +388,7 @@ def __apply__(self, other): n = _adjust_n_months(other.day, self.n, 1) return _shift_month(other, n, "start") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 @@ -394,7 +401,7 @@ def __apply__(self, other): n = _adjust_n_months(other.day, self.n, _days_in_month(other)) return _shift_month(other, n, "end") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == _days_in_month(date) @@ -419,10 +426,10 @@ def onOffset(self, date): class QuarterOffset(BaseCFTimeOffset): """Quarter representation copied off of pandas/tseries/offsets.py""" - _freq: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -439,29 +446,28 @@ def __apply__(self, other): months = qtrs * 3 - months_since return _shift_month(other, months, self._day_option) - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" mod_month = (date.month - self.month) % 3 return mod_month == 0 and date.day == self._get_offset_day(date) - def __sub__(self, other): + def __sub__(self, other: Self) -> Self: if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract cftime.datetime from offset.") - elif type(other) == type(self) and other.month == self.month: + if type(other) == type(self) and other.month == self.month: return type(self)(self.n - other.n, month=self.month) - else: - return NotImplemented + return NotImplemented def __mul__(self, other): if isinstance(other, float): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" def __str__(self): @@ -519,11 +525,10 @@ def rollback(self, date): class YearOffset(BaseCFTimeOffset): - _freq: ClassVar[str] - _day_option: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -549,10 +554,10 @@ def __mul__(self, other): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" - def __str__(self): + def __str__(self) -> str: return f"<{type(self).__name__}: n={self.n}, month={self.month}>" @@ -561,7 +566,7 @@ class YearBegin(YearOffset): _day_option = "start" _default_month = 1 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 and date.month == self.month @@ -586,7 +591,7 @@ class YearEnd(YearOffset): _day_option = "end" _default_month = 12 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == _days_in_month(date) and date.month == self.month @@ -609,7 +614,7 @@ def rollback(self, date): class Day(Tick): _freq = "D" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(days=self.n) def __apply__(self, other): @@ -619,7 +624,7 @@ def __apply__(self, other): class Hour(Tick): _freq = "h" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(hours=self.n) def __apply__(self, other): @@ -629,7 +634,7 @@ def __apply__(self, other): class Minute(Tick): _freq = "min" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(minutes=self.n) def __apply__(self, other): @@ -639,7 +644,7 @@ def __apply__(self, other): class Second(Tick): _freq = "s" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(seconds=self.n) def __apply__(self, other): @@ -649,7 +654,7 @@ def __apply__(self, other): class Millisecond(Tick): _freq = "ms" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(milliseconds=self.n) def __apply__(self, other): @@ -659,30 +664,32 @@ def __apply__(self, other): class Microsecond(Tick): _freq = "us" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(microseconds=self.n) def __apply__(self, other): return other + self.as_timedelta() -def _generate_anchored_offsets(base_freq, offset): - offsets = {} +def _generate_anchored_offsets( + base_freq: str, offset: type[YearOffset | QuarterOffset] +) -> dict[str, type[BaseCFTimeOffset]]: + offsets: dict[str, type[BaseCFTimeOffset]] = {} for month, abbreviation in _MONTH_ABBREVIATIONS.items(): anchored_freq = f"{base_freq}-{abbreviation}" - offsets[anchored_freq] = partial(offset, month=month) + offsets[anchored_freq] = partial(offset, month=month) # type: ignore[assignment] return offsets -_FREQUENCIES = { +_FREQUENCIES: Mapping[str, type[BaseCFTimeOffset]] = { "A": YearEnd, "AS": YearBegin, "Y": YearEnd, "YE": YearEnd, "YS": YearBegin, - "Q": partial(QuarterEnd, month=12), - "QE": partial(QuarterEnd, month=12), - "QS": partial(QuarterBegin, month=1), + "Q": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QE": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QS": partial(QuarterBegin, month=1), # type: ignore[dict-item] "M": MonthEnd, "ME": MonthEnd, "MS": MonthBegin, @@ -717,7 +724,9 @@ def _generate_anchored_offsets(base_freq, offset): CFTIME_TICKS = (Day, Hour, Minute, Second) -def _generate_anchored_deprecated_frequencies(deprecated, recommended): +def _generate_anchored_deprecated_frequencies( + deprecated: str, recommended: str +) -> dict[str, str]: pairs = {} for abbreviation in _MONTH_ABBREVIATIONS.values(): anchored_deprecated = f"{deprecated}-{abbreviation}" @@ -726,7 +735,7 @@ def _generate_anchored_deprecated_frequencies(deprecated, recommended): return pairs -_DEPRECATED_FREQUENICES = { +_DEPRECATED_FREQUENICES: dict[str, str] = { "A": "YE", "Y": "YE", "AS": "YS", @@ -759,16 +768,16 @@ def _emit_freq_deprecation_warning(deprecated_freq): emit_user_level_warning(message, FutureWarning) -def to_offset(freq, warn=True): +def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset: """Convert a frequency string to the appropriate subclass of BaseCFTimeOffset.""" if isinstance(freq, BaseCFTimeOffset): return freq - else: - try: - freq_data = re.match(_PATTERN, freq).groupdict() - except AttributeError: - raise ValueError("Invalid frequency string provided") + + match = re.match(_PATTERN, freq) + if match is None: + raise ValueError("Invalid frequency string provided") + freq_data = match.groupdict() freq = freq_data["freq"] if warn and freq in _DEPRECATED_FREQUENICES: @@ -909,7 +918,9 @@ def _translate_closed_to_inclusive(closed): return inclusive -def _infer_inclusive(closed, inclusive): +def _infer_inclusive( + closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None +) -> InclusiveOptions: """Follows code added in pandas #43504.""" if closed is not no_default and inclusive is not None: raise ValueError( @@ -917,9 +928,9 @@ def _infer_inclusive(closed, inclusive): "passed if argument `inclusive` is not None." ) if closed is not no_default: - inclusive = _translate_closed_to_inclusive(closed) - elif inclusive is None: - inclusive = "both" + return _translate_closed_to_inclusive(closed) + if inclusive is None: + return "both" return inclusive @@ -933,7 +944,7 @@ def cftime_range( closed: NoDefault | SideOptions = no_default, inclusive: None | InclusiveOptions = None, calendar="standard", -): +) -> CFTimeIndex: """Return a fixed frequency CFTimeIndex. Parameters diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 6898809e3b0..cd902257902 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -45,6 +45,7 @@ import re import warnings from datetime import timedelta +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -64,6 +65,10 @@ except ImportError: cftime = None +if TYPE_CHECKING: + from xarray.coding.cftime_offsets import BaseCFTimeOffset + from xarray.core.types import Self + # constants for cftimeindex.repr CFTIME_REPR_LENGTH = 19 @@ -495,7 +500,7 @@ def get_value(self, series, key): else: return series.iloc[self.get_loc(key)] - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: """Adapted from pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" try: @@ -503,16 +508,20 @@ def __contains__(self, key): return ( is_scalar(result) or type(result) == slice - or (isinstance(result, np.ndarray) and result.size) + or (isinstance(result, np.ndarray) and result.size > 0) ) except (KeyError, TypeError, ValueError): return False - def contains(self, key): + def contains(self, key: Any) -> bool: """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift(self, n: int | float, freq: str | timedelta): + def shift( # type: ignore[override] # freq is typed Any, we are more precise + self, + periods: int | float, + freq: str | timedelta | BaseCFTimeOffset | None = None, + ) -> Self: """Shift the CFTimeIndex a multiple of the given frequency. See the documentation for :py:func:`~xarray.cftime_range` for a @@ -520,9 +529,9 @@ def shift(self, n: int | float, freq: str | timedelta): Parameters ---------- - n : int, float if freq of days or below + periods : int, float if freq of days or below Periods to shift by - freq : str or datetime.timedelta + freq : str, datetime.timedelta or BaseCFTimeOffset A frequency string or datetime.timedelta object to shift by Returns @@ -546,33 +555,42 @@ def shift(self, n: int | float, freq: str | timedelta): CFTimeIndex([2000-02-01 12:00:00], dtype='object', length=1, calendar='standard', freq=None) """ - if isinstance(freq, timedelta): - return self + n * freq - elif isinstance(freq, str): - from xarray.coding.cftime_offsets import to_offset + from xarray.coding.cftime_offsets import BaseCFTimeOffset - return self + n * to_offset(freq) - else: + if freq is None: + # None type is required to be compatible with base pd.Index class raise TypeError( - f"'freq' must be of type str or datetime.timedelta, got {freq}." + f"`freq` argument cannot be None for {type(self).__name__}.shift" ) - def __add__(self, other): + if isinstance(freq, timedelta): + return self + periods * freq + + if isinstance(freq, (str, BaseCFTimeOffset)): + from xarray.coding.cftime_offsets import to_offset + + return self + periods * to_offset(freq) + + raise TypeError( + f"'freq' must be of type str or datetime.timedelta, got {type(freq)}." + ) + + def __add__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(np.array(self) + other) + return type(self)(np.array(self) + other) - def __radd__(self, other): + def __radd__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(other + np.array(self)) + return type(self)(other + np.array(self)) def __sub__(self, other): if _contains_datetime_timedeltas(other): - return CFTimeIndex(np.array(self) - other) - elif isinstance(other, pd.TimedeltaIndex): - return CFTimeIndex(np.array(self) - other.to_pytimedelta()) - elif _contains_cftime_datetimes(np.array(other)): + return type(self)(np.array(self) - other) + if isinstance(other, pd.TimedeltaIndex): + return type(self)(np.array(self) - other.to_pytimedelta()) + if _contains_cftime_datetimes(np.array(other)): try: return pd.TimedeltaIndex(np.array(self) - np.array(other)) except OUT_OF_BOUNDS_TIMEDELTA_ERRORS: @@ -580,8 +598,7 @@ def __sub__(self, other): "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." ) - else: - return NotImplemented + return NotImplemented def __rsub__(self, other): try: diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 50a2ba93c09..badb9259b06 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -5,7 +5,7 @@ from collections.abc import Hashable from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, Callable, Union +from typing import Callable, Literal, Union, cast import numpy as np import pandas as pd @@ -36,10 +36,9 @@ except ImportError: cftime = None -if TYPE_CHECKING: - from xarray.core.types import CFCalendar, T_DuckArray +from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray - T_Name = Union[Hashable, None] +T_Name = Union[Hashable, None] # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} @@ -111,22 +110,25 @@ def _is_numpy_compatible_time_range(times): return True -def _netcdf_to_numpy_timeunit(units: str) -> str: +def _netcdf_to_numpy_timeunit(units: str) -> NPDatetimeUnitOptions: units = units.lower() if not units.endswith("s"): units = f"{units}s" - return { - "nanoseconds": "ns", - "microseconds": "us", - "milliseconds": "ms", - "seconds": "s", - "minutes": "m", - "hours": "h", - "days": "D", - }[units] + return cast( + NPDatetimeUnitOptions, + { + "nanoseconds": "ns", + "microseconds": "us", + "milliseconds": "ms", + "seconds": "s", + "minutes": "m", + "hours": "h", + "days": "D", + }[units], + ) -def _numpy_to_netcdf_timeunit(units: str) -> str: +def _numpy_to_netcdf_timeunit(units: NPDatetimeUnitOptions) -> str: return { "ns": "nanoseconds", "us": "microseconds", @@ -252,12 +254,12 @@ def _decode_datetime_with_pandas( "pandas." ) - time_units, ref_date = _unpack_netcdf_time_units(units) + time_units, ref_date_str = _unpack_netcdf_time_units(units) time_units = _netcdf_to_numpy_timeunit(time_units) try: # TODO: the strict enforcement of nanosecond precision Timestamps can be # relaxed when addressing GitHub issue #7493. - ref_date = nanosecond_precision_timestamp(ref_date) + ref_date = nanosecond_precision_timestamp(ref_date_str) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp # strings, in which case we fall back to using cftime @@ -471,6 +473,7 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: # TODO: the strict enforcement of nanosecond precision datetime values can # be relaxed when addressing GitHub issue #7493. new = np.empty(times.shape, dtype="M8[ns]") + dt: pd.Timestamp | Literal["NaT"] for i, t in np.ndenumerate(times): try: # Use pandas.Timestamp in place of datetime.datetime, because diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index c1748e322c2..61aa1846bd0 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -11,15 +11,15 @@ from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByCompatible, Self, - T_DataArray, T_Xarray, VarCompatible, ) if TYPE_CHECKING: + from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.types import T_DataArray as T_DA class DatasetOpsMixin: @@ -455,165 +455,165 @@ def _binary_op( raise NotImplementedError @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... + def __add__(self, other: T_DA) -> T_DA: ... @overload def __add__(self, other: VarCompatible) -> Self: ... - def __add__(self, other: VarCompatible) -> Self | T_DataArray: + def __add__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.add) @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... + def __sub__(self, other: T_DA) -> T_DA: ... @overload def __sub__(self, other: VarCompatible) -> Self: ... - def __sub__(self, other: VarCompatible) -> Self | T_DataArray: + def __sub__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.sub) @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... + def __mul__(self, other: T_DA) -> T_DA: ... @overload def __mul__(self, other: VarCompatible) -> Self: ... - def __mul__(self, other: VarCompatible) -> Self | T_DataArray: + def __mul__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mul) @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... + def __pow__(self, other: T_DA) -> T_DA: ... @overload def __pow__(self, other: VarCompatible) -> Self: ... - def __pow__(self, other: VarCompatible) -> Self | T_DataArray: + def __pow__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.pow) @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + def __truediv__(self, other: T_DA) -> T_DA: ... @overload def __truediv__(self, other: VarCompatible) -> Self: ... - def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: + def __truediv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.truediv) @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + def __floordiv__(self, other: T_DA) -> T_DA: ... @overload def __floordiv__(self, other: VarCompatible) -> Self: ... - def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: + def __floordiv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.floordiv) @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... + def __mod__(self, other: T_DA) -> T_DA: ... @overload def __mod__(self, other: VarCompatible) -> Self: ... - def __mod__(self, other: VarCompatible) -> Self | T_DataArray: + def __mod__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mod) @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... + def __and__(self, other: T_DA) -> T_DA: ... @overload def __and__(self, other: VarCompatible) -> Self: ... - def __and__(self, other: VarCompatible) -> Self | T_DataArray: + def __and__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.and_) @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... + def __xor__(self, other: T_DA) -> T_DA: ... @overload def __xor__(self, other: VarCompatible) -> Self: ... - def __xor__(self, other: VarCompatible) -> Self | T_DataArray: + def __xor__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.xor) @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... + def __or__(self, other: T_DA) -> T_DA: ... @overload def __or__(self, other: VarCompatible) -> Self: ... - def __or__(self, other: VarCompatible) -> Self | T_DataArray: + def __or__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.or_) @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + def __lshift__(self, other: T_DA) -> T_DA: ... @overload def __lshift__(self, other: VarCompatible) -> Self: ... - def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __lshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lshift) @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + def __rshift__(self, other: T_DA) -> T_DA: ... @overload def __rshift__(self, other: VarCompatible) -> Self: ... - def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __rshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.rshift) @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... + def __lt__(self, other: T_DA) -> T_DA: ... @overload def __lt__(self, other: VarCompatible) -> Self: ... - def __lt__(self, other: VarCompatible) -> Self | T_DataArray: + def __lt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lt) @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... + def __le__(self, other: T_DA) -> T_DA: ... @overload def __le__(self, other: VarCompatible) -> Self: ... - def __le__(self, other: VarCompatible) -> Self | T_DataArray: + def __le__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.le) @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... + def __gt__(self, other: T_DA) -> T_DA: ... @overload def __gt__(self, other: VarCompatible) -> Self: ... - def __gt__(self, other: VarCompatible) -> Self | T_DataArray: + def __gt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.gt) @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... + def __ge__(self, other: T_DA) -> T_DA: ... @overload def __ge__(self, other: VarCompatible) -> Self: ... - def __ge__(self, other: VarCompatible) -> Self | T_DataArray: + def __ge__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.ge) @overload # type:ignore[override] - def __eq__(self, other: T_DataArray) -> T_DataArray: ... + def __eq__(self, other: T_DA) -> T_DA: ... @overload def __eq__(self, other: VarCompatible) -> Self: ... - def __eq__(self, other: VarCompatible) -> Self | T_DataArray: + def __eq__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_eq) @overload # type:ignore[override] - def __ne__(self, other: T_DataArray) -> T_DataArray: ... + def __ne__(self, other: T_DA) -> T_DA: ... @overload def __ne__(self, other: VarCompatible) -> Self: ... - def __ne__(self, other: VarCompatible) -> Self | T_DataArray: + def __ne__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, @@ -770,96 +770,96 @@ class DatasetGroupByOpsMixin: __slots__ = () def _binary_op( - self, other: GroupByCompatible, f: Callable, reflexive: bool = False + self, other: Dataset | DataArray, f: Callable, reflexive: bool = False ) -> Dataset: raise NotImplementedError - def __add__(self, other: GroupByCompatible) -> Dataset: + def __add__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other: GroupByCompatible) -> Dataset: + def __sub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other: GroupByCompatible) -> Dataset: + def __mul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other: GroupByCompatible) -> Dataset: + def __pow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other: GroupByCompatible) -> Dataset: + def __truediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other: GroupByCompatible) -> Dataset: + def __floordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other: GroupByCompatible) -> Dataset: + def __mod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other: GroupByCompatible) -> Dataset: + def __and__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other: GroupByCompatible) -> Dataset: + def __xor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other: GroupByCompatible) -> Dataset: + def __or__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other: GroupByCompatible) -> Dataset: + def __lshift__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other: GroupByCompatible) -> Dataset: + def __rshift__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other: GroupByCompatible) -> Dataset: + def __lt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other: GroupByCompatible) -> Dataset: + def __le__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other: GroupByCompatible) -> Dataset: + def __gt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other: GroupByCompatible) -> Dataset: + def __ge__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __eq__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __ne__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, # and it should be declared as follows: __hash__: None # type:ignore[assignment] - def __radd__(self, other: GroupByCompatible) -> Dataset: + def __radd__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other: GroupByCompatible) -> Dataset: + def __rsub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other: GroupByCompatible) -> Dataset: + def __rmul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other: GroupByCompatible) -> Dataset: + def __rpow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other: GroupByCompatible) -> Dataset: + def __rtruediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: + def __rfloordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other: GroupByCompatible) -> Dataset: + def __rmod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other: GroupByCompatible) -> Dataset: + def __rand__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other: GroupByCompatible) -> Dataset: + def __rxor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other: GroupByCompatible) -> Dataset: + def __ror__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/concat.py b/xarray/core/concat.py index b1cca586992..15292bdb34b 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -32,10 +32,11 @@ T_DataVars = Union[ConcatOptions, Iterable[Hashable]] +# TODO: replace dim: Any by 1D array_likes @overload def concat( objs: Iterable[T_Dataset], - dim: Hashable | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -50,7 +51,7 @@ def concat( @overload def concat( objs: Iterable[T_DataArray], - dim: Hashable | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -303,7 +304,7 @@ def _calc_concat_dim_index( dim: Hashable | None - if isinstance(dim_or_data, str): + if utils.hashable(dim_or_data): dim = dim_or_data index = None else: @@ -474,7 +475,7 @@ def _parse_datasets( def _dataset_concat( - datasets: list[T_Dataset], + datasets: Iterable[T_Dataset], dim: str | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars, coords: str | list[str], @@ -505,12 +506,14 @@ def _dataset_concat( else: dim_var = None - dim, index = _calc_concat_dim_index(dim) + dim_name, index = _calc_concat_dim_index(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] datasets = list( - align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) + align( + *datasets, join=join, copy=False, exclude=[dim_name], fill_value=fill_value + ) ) dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets( @@ -524,19 +527,21 @@ def _dataset_concat( f"{both_data_and_coords!r} is a coordinate in some datasets but not others." ) # we don't want the concat dimension in the result dataset yet - dim_coords.pop(dim, None) - dims_sizes.pop(dim, None) + dim_coords.pop(dim_name, None) + dims_sizes.pop(dim_name, None) # case where concat dimension is a coordinate or data_var but not a dimension - if (dim in coord_names or dim in data_names) and dim not in dim_names: + if ( + dim_name in coord_names or dim_name in data_names + ) and dim_name not in dim_names: datasets = [ - ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim) + ds.expand_dims(dim_name, create_index_for_new_dim=create_index_for_new_dim) for ds in datasets ] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( - datasets, dim, dim_names, data_vars, coords, compat + datasets, dim_name, dim_names, data_vars, coords, compat ) # determine which variables to merge, and then merge them according to compat @@ -576,8 +581,8 @@ def ensure_common_dims(vars, concat_dim_lengths): # dimensions and the same shape for all of them except along the # concat dimension common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) - if dim not in common_dims: - common_dims = (dim,) + common_dims + if dim_name not in common_dims: + common_dims = (dim_name,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): if var.dims != common_dims: common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims) @@ -593,12 +598,12 @@ def get_indexes(name): for ds in datasets: if name in ds._indexes: yield ds._indexes[name] - elif name == dim: + elif name == dim_name: var = ds._variables[name] if not var.dims: - data = var.set_dims(dim).values + data = var.set_dims(dim_name).values if create_index_for_new_dim: - yield PandasIndex(data, dim, coord_dtype=var.dtype) + yield PandasIndex(data, dim_name, coord_dtype=var.dtype) # create concatenation index, needed for later reindexing file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths)) @@ -644,7 +649,7 @@ def get_indexes(name): f"{name!r} must have either an index or no index in all datasets, " f"found {len(indexes)}/{len(datasets)} datasets with an index." ) - combined_idx = indexes[0].concat(indexes, dim, positions) + combined_idx = indexes[0].concat(indexes, dim_name, positions) if name in datasets[0]._indexes: idx_vars = datasets[0].xindexes.get_all_coords(name) else: @@ -660,14 +665,14 @@ def get_indexes(name): result_vars[k] = v else: combined_var = concat_vars( - vars, dim, positions, combine_attrs=combine_attrs + vars, dim_name, positions, combine_attrs=combine_attrs ) # reindex if variable is not present in all datasets if len(variable_index) < concat_index_size: combined_var = reindex_variables( variables={name: combined_var}, dim_pos_indexers={ - dim: pd.Index(variable_index).get_indexer(concat_index) + dim_name: pd.Index(variable_index).get_indexer(concat_index) }, fill_value=fill_value, )[name] @@ -693,12 +698,12 @@ def get_indexes(name): if index is not None: if dim_var is not None: - index_vars = index.create_variables({dim: dim_var}) + index_vars = index.create_variables({dim_name: dim_var}) else: index_vars = index.create_variables() - coord_vars[dim] = index_vars[dim] - result_indexes[dim] = index + coord_vars[dim_name] = index_vars[dim_name] + result_indexes[dim_name] = index coords_obj = Coordinates(coord_vars, indexes=result_indexes) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 47dc9d13ffc..09f5664aa06 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -135,7 +135,11 @@ def _check_coords_dims(shape, coords, dim): def _infer_coords_and_dims( shape: tuple[int, ...], - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ), dims: str | Iterable[Hashable] | None, ) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: """All the logic for creating a new DataArray""" @@ -199,7 +203,11 @@ def _infer_coords_and_dims( def _check_data_shape( data: Any, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ), dims: str | Iterable[Hashable] | None, ) -> Any: if data is dtypes.NA: @@ -413,7 +421,11 @@ class DataArray( def __init__( self, data: Any = dtypes.NA, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ) = None, dims: str | Iterable[Hashable] | None = None, name: Hashable | None = None, attrs: Mapping | None = None, @@ -965,7 +977,7 @@ def indexes(self) -> Indexes: return self.xindexes.to_pandas_indexes() @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of :py:class:`~xarray.indexes.Index` objects used for label based indexing. """ @@ -3004,7 +3016,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") - level_number = idx._get_level_number(level) + level_number = idx._get_level_number(level) # type: ignore[attr-defined] variables = idx.levels[level_number] variable_dim = idx.names[level_number] @@ -3838,7 +3850,7 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame: "pandas objects. Requires 2 or fewer dimensions." ) indexes = [self.get_index(dim) for dim in self.dims] - return constructor(self.values, *indexes) + return constructor(self.values, *indexes) # type: ignore[operator] def to_dataframe( self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None @@ -6841,7 +6853,7 @@ def groupby_bins( _validate_groupby_squeeze(squeeze) grouper = BinGrouper( - bins=bins, + bins=bins, # type: ignore[arg-type] # TODO: fix this arg or BinGrouper cut_kwargs={ "right": right, "labels": labels, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3930b12ef3d..1793abf02d8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4156,15 +4156,15 @@ def interp_like( kwargs = {} # pick only dimension coordinates with a single index - coords = {} + coords: dict[Hashable, Variable] = {} other_indexes = other.xindexes for dim in self.dims: other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") if len(other_dim_coords) == 1: coords[dim] = other_dim_coords[dim] - numeric_coords: dict[Hashable, pd.Index] = {} - object_coords: dict[Hashable, pd.Index] = {} + numeric_coords: dict[Hashable, Variable] = {} + object_coords: dict[Hashable, Variable] = {} for k, v in coords.items(): if v.dtype.kind in "uifcMm": numeric_coords[k] = v @@ -6539,7 +6539,13 @@ def interpolate_na( limit: int | None = None, use_coordinate: bool | Hashable = True, max_gap: ( - int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta + int + | float + | str + | pd.Timedelta + | np.timedelta64 + | datetime.timedelta + | None ) = None, **kwargs: Any, ) -> Self: @@ -6573,7 +6579,8 @@ def interpolate_na( or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta \ + or None, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -9715,7 +9722,7 @@ def eval( c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ - return pd.eval( + return pd.eval( # type: ignore[return-value] statement, resolvers=[self], target=self, @@ -10394,7 +10401,7 @@ def groupby_bins( _validate_groupby_squeeze(squeeze) grouper = BinGrouper( - bins=bins, + bins=bins, # type: ignore[arg-type] #TODO: fix this here or in BinGrouper? cut_kwargs={ "right": right, "labels": labels, diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index c8b4fa88409..b0361ef0f0f 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Callable, Generic +from typing import Callable, Generic, cast import numpy as np import pandas as pd @@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) + return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] @implements(np.where) @@ -57,9 +57,9 @@ def __extension_duck_array__where( and isinstance(y, pd.Categorical) and x.dtype != y.dtype ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) - return pd.Series(x).where(condition, pd.Series(y)).array + x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] + y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] + return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) class PandasExtensionArray(Generic[T_ExtensionArray]): @@ -116,7 +116,7 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)(type(self.array)([item])) + return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): diff --git a/xarray/core/groupers.py b/xarray/core/groupers.py index 075afd9f62f..f76bd22a2f6 100644 --- a/xarray/core/groupers.py +++ b/xarray/core/groupers.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import pandas as pd @@ -25,6 +25,9 @@ from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable +if TYPE_CHECKING: + pass + __all__ = [ "EncodedGroups", "Grouper", @@ -160,6 +163,7 @@ def _factorize_dummy(self) -> EncodedGroups: # equivalent to: group_indices = group_indices.reshape(-1, 1) group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] size_range = np.arange(size) + full_index: pd.Index if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) unique_coord = self.group @@ -275,7 +279,7 @@ def _init_properties(self, group: T_Group) -> None: if isinstance(group_as_index, CFTimeIndex): from xarray.core.resample_cftime import CFTimeGrouper - index_grouper = CFTimeGrouper( + self.index_grouper = CFTimeGrouper( freq=self.freq, closed=self.closed, label=self.label, @@ -284,7 +288,7 @@ def _init_properties(self, group: T_Group) -> None: loffset=self.loffset, ) else: - index_grouper = pd.Grouper( + self.index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 freq=_new_to_legacy_freq(self.freq), closed=self.closed, @@ -292,7 +296,6 @@ def _init_properties(self, group: T_Group) -> None: origin=self.origin, offset=offset, ) - self.index_grouper = index_grouper self.group_as_index = group_as_index def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: @@ -305,22 +308,25 @@ def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: return full_index, first_items, codes def first_items(self) -> tuple[pd.Series, np.ndarray]: - from xarray import CFTimeIndex + from xarray.coding.cftimeindex import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper - if isinstance(self.group_as_index, CFTimeIndex): - return self.index_grouper.first_items(self.group_as_index) - else: - s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) - grouped = s.groupby(self.index_grouper) - first_items = grouped.first() - counts = grouped.count() - # This way we generate codes for the final output index: full_index. - # So for _flox_reduce we avoid one reindex and copy by avoiding - # _maybe_restore_empty_groups - codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) - return first_items, codes + if isinstance(self.index_grouper, CFTimeGrouper): + return self.index_grouper.first_items( + cast(CFTimeIndex, self.group_as_index) + ) + + s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) + grouped = s.groupby(self.index_grouper) + first_items = grouped.first() + counts = grouped.count() + # This way we generate codes for the final output index: full_index. + # So for _flox_reduce we avoid one reindex and copy by avoiding + # _maybe_restore_empty_groups + codes = np.repeat(np.arange(len(first_items)), counts) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) + return first_items, codes def factorize(self, group) -> EncodedGroups: self._init_properties(group) @@ -369,7 +375,7 @@ def _apply_loffset( ) if isinstance(loffset, str): - loffset = pd.tseries.frequencies.to_offset(loffset) + loffset = pd.tseries.frequencies.to_offset(loffset) # type: ignore[assignment] needs_offset = ( isinstance(loffset, (pd.DateOffset, datetime.timedelta)) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f25c0ecf936..9d8a68edbf3 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -451,7 +451,7 @@ def safe_cast_to_index(array: Any) -> pd.Index: elif isinstance(array, PandasIndexingAdapter): index = array.array else: - kwargs: dict[str, str] = {} + kwargs: dict[str, Any] = {} if hasattr(array, "dtype"): if array.dtype.kind == "O": kwargs["dtype"] = "object" @@ -551,7 +551,7 @@ def as_scalar(value: np.ndarray): return value[()] if value.dtype.kind in "mM" else value.item() -def get_indexer_nd(index, labels, method=None, tolerance=None): +def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.ndarray: """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels """ @@ -740,7 +740,7 @@ def isel( # scalar indexer: drop index return None - return self._replace(self.index[indxr]) + return self._replace(self.index[indxr]) # type: ignore[index] def sel( self, labels: dict[Any, Any], method=None, tolerance=None @@ -898,36 +898,45 @@ def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal" ) -def remove_unused_levels_categories(index: pd.Index) -> pd.Index: +T_PDIndex = TypeVar("T_PDIndex", bound=pd.Index) + + +def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ if isinstance(index, pd.MultiIndex): - index = index.remove_unused_levels() + new_index = cast(pd.MultiIndex, index.remove_unused_levels()) # if it contains CategoricalIndex, we need to remove unused categories # manually. See https://github.com/pandas-dev/pandas/issues/30846 - if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels): levels = [] - for i, level in enumerate(index.levels): + for i, level in enumerate(new_index.levels): if isinstance(level, pd.CategoricalIndex): - level = level[index.codes[i]].remove_unused_categories() + level = level[new_index.codes[i]].remove_unused_categories() else: - level = level[index.codes[i]] + level = level[new_index.codes[i]] levels.append(level) # TODO: calling from_array() reorders MultiIndex levels. It would # be best to avoid this, if possible, e.g., by using # MultiIndex.remove_unused_levels() (which does not reorder) on the # part of the MultiIndex that is not categorical, or by fixing this # upstream in pandas. - index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() + new_index = pd.MultiIndex.from_arrays(levels, names=new_index.names) + return cast(T_PDIndex, new_index) + + if isinstance(index, pd.CategoricalIndex): + return index.remove_unused_categories() # type: ignore[attr-defined] + return index class PandasMultiIndex(PandasIndex): """Wrap a pandas.MultiIndex as an xarray compatible index.""" + index: pd.MultiIndex + dim: Hashable + coord_dtype: Any level_coords_dtype: dict[str, Any] __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") @@ -1063,8 +1072,8 @@ def from_variables_maybe_expand( The index and its corresponding coordinates may be created along a new dimension. """ names: list[Hashable] = [] - codes: list[list[int]] = [] - levels: list[list[int]] = [] + codes: list[Iterable[int]] = [] + levels: list[Iterable[Any]] = [] level_variables: dict[Any, Variable] = {} _check_dim_compat({**current_variables, **variables}) @@ -1134,7 +1143,7 @@ def reorder_levels( its corresponding coordinates. """ - index = self.index.reorder_levels(level_variables.keys()) + index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys())) level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} return self._replace(index, level_coords_dtype=level_coords_dtype) @@ -1147,13 +1156,13 @@ def create_variables( variables = {} index_vars: IndexVars = {} - for name in (self.dim,) + self.index.names: + for name in (self.dim,) + tuple(self.index.names): if name == self.dim: level = None dtype = None else: level = name - dtype = self.level_coords_dtype[name] + dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok? var = variables.get(name, None) if var is not None: @@ -1163,7 +1172,7 @@ def create_variables( attrs = {} encoding = {} - data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok? index_vars[name] = IndexVariable( self.dim, data, @@ -1186,6 +1195,8 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: new_index = None scalar_coord_values = {} + indexer: int | slice | np.ndarray | Variable | DataArray + # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): label_values = {} @@ -1212,7 +1223,7 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: ) scalar_coord_values.update(label_values) # GH2619. Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: + if indexer.dtype.kind == "b" and indexer.sum() == 0: # type: ignore[union-attr] raise KeyError(f"{labels} not found") # assume one label value given for the multi-index "array" (dimension) @@ -1600,9 +1611,7 @@ def group_by_index( """Returns a list of unique indexes and their corresponding coordinates.""" index_coords = [] - - for i in self._id_index: - index = self._id_index[i] + for i, index in self._id_index.items(): coords = {k: self._variables[k] for k in self._id_coord_names[i]} index_coords.append((index, coords)) @@ -1640,26 +1649,28 @@ def copy_indexes( in this dict. """ - new_indexes = {} - new_index_vars = {} + new_indexes: dict[Hashable, T_PandasOrXarrayIndex] = {} + new_index_vars: dict[Hashable, Variable] = {} - idx: T_PandasOrXarrayIndex + xr_idx: Index + new_idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True dim = next(iter(coords.values())).dims[0] if isinstance(idx, pd.MultiIndex): - idx = PandasMultiIndex(idx, dim) + xr_idx = PandasMultiIndex(idx, dim) else: - idx = PandasIndex(idx, dim) + xr_idx = PandasIndex(idx, dim) else: convert_new_idx = False + xr_idx = idx - new_idx = idx._copy(deep=deep, memo=memo) - idx_vars = idx.create_variables(coords) + new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment] + idx_vars = xr_idx.create_variables(coords) if convert_new_idx: - new_idx = cast(PandasIndex, new_idx).index + new_idx = new_idx.index # type: ignore[attr-defined] new_indexes.update({k: new_idx for k in coords}) new_index_vars.update(idx_vars) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 06e7efdbb48..19937270268 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -34,6 +34,7 @@ from numpy.typing import DTypeLike from xarray.core.indexes import Index + from xarray.core.types import Self from xarray.core.variable import Variable from xarray.namedarray._typing import _Shape, duckarray from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -1656,6 +1657,9 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_dtype") + array: pd.Index + _dtype: np.dtype + def __init__(self, array: pd.Index, dtype: DTypeLike = None): from xarray.core.indexes import safe_cast_to_index @@ -1792,7 +1796,7 @@ def transpose(self, order) -> pd.Index: def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" - def copy(self, deep: bool = True) -> PandasIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # Not the same as just writing `self.array.copy(deep=deep)`, as # shallow copies of the underlying numpy.ndarrays become deep ones # upon pickling @@ -1810,11 +1814,14 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): This allows creating one instance for each multi-index level while preserving indexing efficiency (memoized + might reuse another instance with the same multi-index). - """ __slots__ = ("array", "_dtype", "level", "adapter") + array: pd.MultiIndex + _dtype: np.dtype + level: str | None + def __init__( self, array: pd.MultiIndex, @@ -1910,7 +1917,7 @@ def _repr_html_(self) -> str: array_repr = short_array_repr(self._get_array_subset()) return f"
{escape(array_repr)}
" - def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 45abc70c0d3..bfbad72649a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -315,7 +315,9 @@ def interp_na( use_coordinate: bool | str = True, method: InterpOptions = "linear", limit: int | None = None, - max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None, + max_gap: ( + int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta | None + ) = None, keep_attrs: bool | None = None, **kwargs, ): diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 216bd8fca6b..a048e85b4d4 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -66,6 +66,13 @@ class CFTimeGrouper: single method, the only one required for resampling in xarray. It cannot be used in a call to groupby like a pandas.Grouper object can.""" + freq: BaseCFTimeOffset + closed: SideOptions + label: SideOptions + loffset: str | datetime.timedelta | BaseCFTimeOffset | None + origin: str | CFTimeDatetime + offset: datetime.timedelta | None + def __init__( self, freq: str | BaseCFTimeOffset, @@ -73,11 +80,8 @@ def __init__( label: SideOptions | None = None, loffset: str | datetime.timedelta | BaseCFTimeOffset | None = None, origin: str | CFTimeDatetime = "start_day", - offset: str | datetime.timedelta | None = None, + offset: str | datetime.timedelta | BaseCFTimeOffset | None = None, ): - self.offset: datetime.timedelta | None - self.closed: SideOptions - self.label: SideOptions self.freq = to_offset(freq) self.loffset = loffset self.origin = origin @@ -120,10 +124,10 @@ def __init__( if offset is not None: try: self.offset = _convert_offset_to_timedelta(offset) - except (ValueError, AttributeError) as error: + except (ValueError, TypeError) as error: raise ValueError( f"offset must be a datetime.timedelta object or an offset string " - f"that can be converted to a timedelta. Got {offset} instead." + f"that can be converted to a timedelta. Got {type(offset)} instead." ) from error else: self.offset = None @@ -250,12 +254,12 @@ def _get_time_bins( def _adjust_bin_edges( - datetime_bins: np.ndarray, + datetime_bins: CFTimeIndex, freq: BaseCFTimeOffset, closed: SideOptions, index: CFTimeIndex, - labels: np.ndarray, -): + labels: CFTimeIndex, +) -> tuple[CFTimeIndex, CFTimeIndex]: """This is required for determining the bin edges resampling with month end, quarter end, and year end frequencies. @@ -499,10 +503,11 @@ def _convert_offset_to_timedelta( ) -> datetime.timedelta: if isinstance(offset, datetime.timedelta): return offset - elif isinstance(offset, (str, Tick)): - return to_offset(offset).as_timedelta() - else: - raise ValueError + if isinstance(offset, (str, Tick)): + timedelta_cftime_offset = to_offset(offset) + if isinstance(timedelta_cftime_offset, Tick): + return timedelta_cftime_offset.as_timedelta() + raise TypeError(f"Expected timedelta, str or Tick, got {type(offset)}") def _ceil_via_cftimeindex(date: CFTimeDatetime, freq: str | BaseCFTimeOffset): diff --git a/xarray/core/types.py b/xarray/core/types.py index 588d15dc35d..786ab5973b1 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -28,10 +28,11 @@ else: Self: Any = None -if TYPE_CHECKING: - from numpy._typing import _SupportsDType - from numpy.typing import ArrayLike +from numpy._typing import _SupportsDType +from numpy.typing import ArrayLike + +if TYPE_CHECKING: from xarray.backends.common import BackendEntrypoint from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords @@ -45,7 +46,7 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray # type: ignore + DaskArray = np.ndarray try: from cubed import Array as CubedArray @@ -177,7 +178,7 @@ def copy( T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) -ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +ScalarOrArray = Union["ArrayLike", np.generic] VarCompatible = Union["Variable", "ScalarOrArray"] DaCompatible = Union["DataArray", "VarCompatible"] DsCompatible = Union["Dataset", "DaCompatible"] @@ -219,6 +220,7 @@ def copy( DatetimeUnitOptions = Literal[ "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None ] +NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"] QueryEngineOptions = Literal["python", "numexpr", None] QueryParserOptions = Literal["pandas", "python"] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 5cb52cbd25c..c2859632360 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -57,19 +57,12 @@ Mapping, MutableMapping, MutableSet, + Sequence, ValuesView, ) from enum import Enum from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, overload import numpy as np import pandas as pd @@ -117,26 +110,27 @@ def wrapper(*args, **kwargs): return wrapper -def get_valid_numpy_dtype(array: np.ndarray | pd.Index): +def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: """Return a numpy compatible dtype from either a numpy array or a pandas.Index. - Used for wrapping a pandas.Index as an xarray,Variable. + Used for wrapping a pandas.Index as an xarray.Variable. """ if isinstance(array, pd.PeriodIndex): - dtype = np.dtype("O") - elif hasattr(array, "categories"): + return np.dtype("O") + + if hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype if not is_valid_numpy_dtype(dtype): dtype = np.dtype("O") - elif not is_valid_numpy_dtype(array.dtype): - dtype = np.dtype("O") - else: - dtype = array.dtype + return dtype + + if not is_valid_numpy_dtype(array.dtype): + return np.dtype("O") - return dtype + return array.dtype # type: ignore[return-value] def maybe_coerce_to_str(index, original_coords): @@ -183,18 +177,17 @@ def equivalent(first: T, second: T) -> bool: if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) if isinstance(first, list) or isinstance(second, list): - return list_equiv(first, second) - return (first == second) or (pd.isnull(first) and pd.isnull(second)) + return list_equiv(first, second) # type: ignore[arg-type] + return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload] -def list_equiv(first, second): - equiv = True +def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: if len(first) != len(second): return False - else: - for f, s in zip(first, second): - equiv = equiv and equivalent(f, s) - return equiv + for f, s in zip(first, second): + if not equivalent(f, s): + return False + return True def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f0685882595..377dafa6f79 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -298,7 +298,7 @@ def as_compatible_data( # we don't want nested self-described arrays if isinstance(data, (pd.Series, pd.DataFrame)): - data = data.values + data = data.values # type: ignore[assignment] if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) @@ -1504,7 +1504,7 @@ def _unstack_once( # Potentially we could replace `len(other_dims)` with just `-1` other_dims = [d for d in self.dims if d != dim] new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) - new_dims = reordered.dims[: len(other_dims)] + new_dim_names + new_dims = reordered.dims[: len(other_dims)] + tuple(new_dim_names) create_template: Callable if fill_value is dtypes.NA: diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 14744d2de6b..963d12fd865 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -21,13 +21,13 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray[Any, Any] # type: ignore[assignment, misc] + DaskArray = np.ndarray[Any, Any] dask_available = module_available("dask") -class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): array_cls: type[DaskArray] available: bool = dask_available @@ -91,7 +91,7 @@ def array_api(self) -> Any: return da - def reduction( # type: ignore[override] + def reduction( self, arr: T_ChunkedArray, func: Callable[..., Any], @@ -113,7 +113,7 @@ def reduction( # type: ignore[override] keepdims=keepdims, ) # type: ignore[no-untyped-call] - def scan( # type: ignore[override] + def scan( self, func: Callable[..., Any], binop: Callable[..., Any], diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0b90a05262d..152a9ec40e9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -568,7 +568,7 @@ def test_roundtrip_cftime_datetime_data(self) -> None: assert actual.t.encoding["calendar"] == expected_calendar def test_roundtrip_timedelta_data(self) -> None: - time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -5627,7 +5627,9 @@ def test_h5netcdf_entrypoint(tmp_path: Path) -> None: @requires_netCDF4 @pytest.mark.parametrize("str_type", (str, np.str_)) -def test_write_file_from_np_str(str_type, tmpdir) -> None: +def test_write_file_from_np_str( + str_type: type[str] | type[np.str_], tmpdir: str +) -> None: # https://github.com/pydata/xarray/pull/5264 scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]] years = range(2015, 2100 + 1) @@ -5638,7 +5640,7 @@ def test_write_file_from_np_str(str_type, tmpdir) -> None: ) tdf.index.name = "scenario" tdf.columns.name = "year" - tdf = tdf.stack() + tdf = cast(pd.DataFrame, tdf.stack()) tdf.name = "tas" txr = tdf.to_xarray() diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index eabb7d2f4d6..78aa49c7f83 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -511,7 +511,7 @@ def test_Microsecond_multiplied_float_error(): ], ids=_id_func, ) -def test_neg(offset, expected): +def test_neg(offset: BaseCFTimeOffset, expected: BaseCFTimeOffset) -> None: assert -offset == expected diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f6eb15fa373..116487e2bcf 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -679,11 +679,11 @@ def test_indexing_in_series_loc(series, index, scalar_args, range_args): @requires_cftime def test_indexing_in_series_iloc(series, index): - expected = 1 - assert series.iloc[0] == expected + expected1 = 1 + assert series.iloc[0] == expected1 - expected = pd.Series([1, 2], index=index[:2]) - assert series.iloc[:2].equals(expected) + expected2 = pd.Series([1, 2], index=index[:2]) + assert series.iloc[:2].equals(expected2) @requires_cftime @@ -696,27 +696,27 @@ def test_series_dropna(index): @requires_cftime def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): - expected = pd.Series([1], name=index[0]) + expected_s = pd.Series([1], name=index[0]) for arg in scalar_args: - result = df.loc[arg] - assert result.equals(expected) + result_s = df.loc[arg] + assert result_s.equals(expected_s) - expected = pd.DataFrame([1, 2], index=index[:2]) + expected_df = pd.DataFrame([1, 2], index=index[:2]) for arg in range_args: - result = df.loc[arg] - assert result.equals(expected) + result_df = df.loc[arg] + assert result_df.equals(expected_df) @requires_cftime def test_indexing_in_dataframe_iloc(df, index): - expected = pd.Series([1], name=index[0]) - result = df.iloc[0] - assert result.equals(expected) - assert result.equals(expected) + expected_s = pd.Series([1], name=index[0]) + result_s = df.iloc[0] + assert result_s.equals(expected_s) + assert result_s.equals(expected_s) - expected = pd.DataFrame([1, 2], index=index[:2]) - result = df.iloc[:2] - assert result.equals(expected) + expected_df = pd.DataFrame([1, 2], index=index[:2]) + result_df = df.iloc[:2] + assert result_df.equals(expected_df) @requires_cftime @@ -957,17 +957,17 @@ def test_cftimeindex_shift(index, freq) -> None: @requires_cftime -def test_cftimeindex_shift_invalid_n() -> None: +def test_cftimeindex_shift_invalid_periods() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift("a", "D") + index.shift("a", "D") # type: ignore[arg-type] @requires_cftime def test_cftimeindex_shift_invalid_freq() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift(1, 1) + index.shift(1, 1) # type: ignore[arg-type] @requires_cftime diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 98d4377706c..3dda7a5f1eb 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -10,6 +10,7 @@ import xarray as xr from xarray.coding.cftime_offsets import _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core.pdcompat import _convert_base_to_offset from xarray.core.resample_cftime import CFTimeGrouper @@ -204,7 +205,9 @@ def test_calendars(calendar: str) -> None: .mean() ) # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass - da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + new_pd_index = da_cftime.xindexes["time"].to_pandas_index() + assert isinstance(new_pd_index, CFTimeIndex) # shouldn't that be a pd.Index? + da_cftime["time"] = new_pd_index.to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) @@ -248,11 +251,11 @@ def test_base_and_offset_error(): @pytest.mark.parametrize("offset", ["foo", "5MS", 10]) -def test_invalid_offset_error(offset) -> None: +def test_invalid_offset_error(offset: str | int) -> None: cftime_index = xr.cftime_range("2000", periods=5) da_cftime = da(cftime_index) with pytest.raises(ValueError, match="offset must be"): - da_cftime.resample(time="2D", offset=offset) + da_cftime.resample(time="2D", offset=offset) # type: ignore[arg-type] def test_timedelta_offset() -> None: @@ -279,7 +282,9 @@ def test_resample_loffset_cftimeindex(loffset) -> None: result = da_cftimeindex.resample(time="24h", loffset=loffset).mean() expected = da_datetimeindex.resample(time="24h", loffset=loffset).mean() - result["time"] = result.xindexes["time"].to_pandas_index().to_datetimeindex() + index = result.xindexes["time"].to_pandas_index() + assert isinstance(index, CFTimeIndex) + result["time"] = index.to_datetimeindex() xr.testing.assert_identical(result, expected) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 623e4e9f970..ef478af8786 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -3,6 +3,7 @@ import warnings from datetime import timedelta from itertools import product +from typing import Literal import numpy as np import pandas as pd @@ -144,15 +145,15 @@ def test_cf_datetime(num_dates, units, calendar) -> None: # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 assert (abs_diff <= np.timedelta64(1, "s")).all() - encoded, _, _ = encode_cf_datetime(actual, units, calendar) + encoded1, _, _ = encode_cf_datetime(actual, units, calendar) + assert_array_equal(num_dates, np.around(encoded1, 1)) - assert_duckarray_allclose(num_dates, encoded) if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to put # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) - assert_duckarray_allclose(num_dates, encoded) + encoded2, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.around(encoded2, 1)) @requires_cftime @@ -627,10 +628,10 @@ def test_cf_timedelta_2d() -> None: @pytest.mark.parametrize( ["deltas", "expected"], [ - (pd.to_timedelta(["1 day", "2 days"]), "days"), - (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), - (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), - (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 ], ) def test_infer_timedelta_units(deltas, expected) -> None: @@ -1237,7 +1238,7 @@ def test_contains_cftime_lazy() -> None: ) def test_roundtrip_datetime64_nanosecond_precision( timestr: str, - timeunit: str, + timeunit: Literal["ns", "us"], dtype: np.typing.DTypeLike, fill_value: int | float | None, use_encoding: bool, @@ -1433,8 +1434,8 @@ def test_roundtrip_float_times() -> None: def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype) -> None: import dask.array - times = pd.date_range(start="1700", freq=freq, periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.date_range(start="1700", freq=freq, periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) @@ -1484,8 +1485,8 @@ def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None: import dask.array calendar = "standard" - times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) - times = dask.array.from_array(times, chunks=1) + times_idx = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) + times = dask.array.from_array(times_idx, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) @@ -1557,11 +1558,13 @@ def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) @pytest.mark.parametrize( ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] ) -def test_encode_cf_timedelta_via_dask(units, dtype) -> None: +def test_encode_cf_timedelta_via_dask( + units: str | None, dtype: np.dtype | None +) -> None: import dask.array - times = pd.timedelta_range(start="0D", freq="D", periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.timedelta_range(start="0D", freq="D", periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) assert is_duck_dask_array(encoded_times) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0c570de3b52..8b2a7ec5d28 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np import pandas as pd @@ -474,7 +474,7 @@ def data(self, request) -> Dataset: "dim3" ) - def rectify_dim_order(self, data, dataset) -> Dataset: + def rectify_dim_order(self, data: Dataset, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` return Dataset( @@ -487,11 +487,13 @@ def rectify_dim_order(self, data, dataset) -> Dataset: @pytest.mark.parametrize( "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] ) - def test_concat_simple(self, data, dim, coords) -> None: + def test_concat_simple(self, data: Dataset, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) - def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: + def test_concat_merge_variables_present_in_some_datasets( + self, data: Dataset + ) -> None: # coordinates present in some datasets but not others ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) ds2 = Dataset(data_vars={"a": ("y", [0.2])}, coords={"z": 0.2}) @@ -515,7 +517,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: assert_identical(expected, actual) @pytest.mark.parametrize("data", [False], indirect=["data"]) - def test_concat_2(self, data) -> None: + def test_concat_2(self, data: Dataset) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] @@ -524,7 +526,9 @@ def test_concat_2(self, data) -> None: @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) @pytest.mark.parametrize("dim", ["dim1", "dim2"]) - def test_concat_coords_kwarg(self, data, dim, coords) -> None: + def test_concat_coords_kwarg( + self, data: Dataset, dim: str, coords: Literal["all", "minimal", "different"] + ) -> None: data = data.copy(deep=True) # make sure the coords argument behaves as expected data.coords["extra"] = ("dim4", np.arange(3)) @@ -538,7 +542,7 @@ def test_concat_coords_kwarg(self, data, dim, coords) -> None: else: assert_equal(data["extra"], actual["extra"]) - def test_concat(self, data) -> None: + def test_concat(self, data: Dataset) -> None: split_data = [ data.isel(dim1=slice(3)), data.isel(dim1=3), @@ -546,7 +550,7 @@ def test_concat(self, data) -> None: ] assert_identical(data, concat(split_data, "dim1")) - def test_concat_dim_precedence(self, data) -> None: + def test_concat_dim_precedence(self, data: Dataset) -> None: # verify that the dim argument takes precedence over # concatenating dataset variables of the same name dim = (2 * data["dim1"]).rename("dim1") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index fdfea3c3fe8..dc0b270dc51 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -119,7 +119,7 @@ def test_incompatible_attributes(self) -> None: Variable( ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} ), - Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), ] diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 44ef486e5d6..b689bb8c02d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -341,12 +341,15 @@ def test_constructor(self) -> None: assert_identical(expected, actual) # list coords, w dims - coords1 = [["a", "b"], [-1, -2, -3]] + coords1: list[Any] = [["a", "b"], [-1, -2, -3]] actual = DataArray(data, coords1, ["x", "y"]) assert_identical(expected, actual) # pd.Index coords, w dims - coords2 = [pd.Index(["a", "b"], name="A"), pd.Index([-1, -2, -3], name="B")] + coords2: list[pd.Index] = [ + pd.Index(["a", "b"], name="A"), + pd.Index([-1, -2, -3], name="B"), + ] actual = DataArray(data, coords2, ["x", "y"]) assert_identical(expected, actual) @@ -424,7 +427,7 @@ def test_constructor_invalid(self) -> None: DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) def test_constructor_from_self_described(self) -> None: - data = [[-0.1, 21], [0, 2]] + data: list[list[float]] = [[-0.1, 21], [0, 2]] expected = DataArray( data, coords={"x": ["a", "b"], "y": [-1, -2]}, @@ -2488,7 +2491,7 @@ def test_stack_unstack(self) -> None: # test GH3000 a = orig[:0, :1].stack(new_dim=("x", "y")).indexes["new_dim"] b = pd.MultiIndex( - levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], + levels=[pd.Index([], dtype=np.int64), pd.Index([0], dtype=np.int64)], codes=[[], []], names=["x", "y"], ) @@ -3331,28 +3334,28 @@ def test_broadcast_coordinates(self) -> None: def test_to_pandas(self) -> None: # 0d - actual = DataArray(42).to_pandas() + actual_xr = DataArray(42).to_pandas() expected = np.array(42) - assert_array_equal(actual, expected) + assert_array_equal(actual_xr, expected) # 1d values = np.random.randn(3) index = pd.Index(["a", "b", "c"], name="x") da = DataArray(values, coords=[index]) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, index) - assert_array_equal(actual.index.name, "x") + actual_s = da.to_pandas() + assert_array_equal(np.asarray(actual_s.values), values) + assert_array_equal(actual_s.index, index) + assert_array_equal(actual_s.index.name, "x") # 2d values = np.random.randn(3, 2) da = DataArray( values, coords=[("x", ["a", "b", "c"]), ("y", [0, 1])], name="foo" ) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, ["a", "b", "c"]) - assert_array_equal(actual.columns, [0, 1]) + actual_df = da.to_pandas() + assert_array_equal(np.asarray(actual_df.values), values) + assert_array_equal(actual_df.index, ["a", "b", "c"]) + assert_array_equal(actual_df.columns, [0, 1]) # roundtrips for shape in [(3,), (3, 4)]: @@ -3369,24 +3372,24 @@ def test_to_dataframe(self) -> None: arr_np = np.random.randn(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() - actual = arr.to_dataframe()["foo"] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.name, actual.name) - assert_array_equal(expected.index.values, actual.index.values) + expected_s = arr.to_series() + actual_s = arr.to_dataframe()["foo"] + assert_array_equal(np.asarray(expected_s.values), np.asarray(actual_s.values)) + assert_array_equal(np.asarray(expected_s.name), np.asarray(actual_s.name)) + assert_array_equal(expected_s.index.values, actual_s.index.values) - actual = arr.to_dataframe(dim_order=["A", "B"])["foo"] - assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + actual_s = arr.to_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), np.asarray(actual_s.values)) # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] - actual = arr.to_dataframe() - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) - assert_array_equal(expected.index.values, actual.index.values) + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] + actual_df = arr.to_dataframe() + assert_array_equal(np.asarray(expected_df.values), np.asarray(actual_df.values)) + assert_array_equal(expected_df.columns.values, actual_df.columns.values) + assert_array_equal(expected_df.index.values, actual_df.index.values) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dataframe(dim_order=["B", "A", "C"]) @@ -3407,11 +3410,13 @@ def test_to_dataframe_multiindex(self) -> None: arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") actual = arr.to_dataframe() - assert_array_equal(actual["foo"].values, arr_np.flatten()) - assert_array_equal(actual.index.names, list("ABC")) - assert_array_equal(actual.index.levels[0], [1, 2]) - assert_array_equal(actual.index.levels[1], ["a", "b"]) - assert_array_equal(actual.index.levels[2], [5, 6, 7]) + index_pd = actual.index + assert isinstance(index_pd, pd.MultiIndex) + assert_array_equal(np.asarray(actual["foo"].values), arr_np.flatten()) + assert_array_equal(index_pd.names, list("ABC")) + assert_array_equal(index_pd.levels[0], [1, 2]) + assert_array_equal(index_pd.levels[1], ["a", "b"]) + assert_array_equal(index_pd.levels[2], [5, 6, 7]) def test_to_dataframe_0length(self) -> None: # regression test for #3008 @@ -3431,10 +3436,10 @@ def test_to_dataframe_0length(self) -> None: def test_to_dask_dataframe(self) -> None: arr_np = np.arange(3 * 4).reshape(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() + expected_s = arr.to_series() actual = arr.to_dask_dataframe()["foo"] - assert_array_equal(actual.values, expected.values) + assert_array_equal(actual.values, np.asarray(expected_s.values)) actual = arr.to_dask_dataframe(dim_order=["A", "B"])["foo"] assert_array_equal(arr_np.transpose().reshape(-1), actual.values) @@ -3442,13 +3447,15 @@ def test_to_dask_dataframe(self) -> None: # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] actual = arr.to_dask_dataframe()[["C", "foo"]] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) + assert_array_equal(expected_df.values, np.asarray(actual.values)) + assert_array_equal( + expected_df.columns.values, np.asarray(actual.columns.values) + ) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dask_dataframe(dim_order=["B", "A", "C"]) @@ -3464,8 +3471,8 @@ def test_to_pandas_name_matches_coordinate(self) -> None: # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x") series = arr.to_series() - assert_array_equal([1, 2, 3], series.values) - assert_array_equal([0, 1, 2], series.index.values) + assert_array_equal([1, 2, 3], list(series.values)) + assert_array_equal([0, 1, 2], list(series.index.values)) assert "x" == series.name assert "x" == series.index.name @@ -3544,7 +3551,7 @@ def test_nbytes_does_not_load_data(self) -> None: def test_to_and_from_empty_series(self) -> None: # GH697 - expected = pd.Series([], dtype=np.float64) + expected: pd.Series[Any] = pd.Series([], dtype=np.float64) da = DataArray.from_series(expected) assert len(da) == 0 actual = da.to_series() diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fd511af0dfb..4db005ca3fb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -39,6 +39,7 @@ from xarray.core.common import duck_array_ops, full_like from xarray.core.coordinates import Coordinates, DatasetCoordinates from xarray.core.indexes import Index, PandasIndex +from xarray.core.types import ArrayLike from xarray.core.utils import is_scalar from xarray.namedarray.pycompat import array_type, integer_types from xarray.testing import _assert_internal_invariants @@ -580,6 +581,7 @@ def test_constructor_pandas_single(self) -> None: pandas_obj = a.to_pandas() ds_based_on_pandas = Dataset(pandas_obj) # type: ignore # TODO: improve typing of __init__ for dim in ds_based_on_pandas.data_vars: + assert isinstance(dim, int) assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) def test_constructor_compat(self) -> None: @@ -1694,7 +1696,7 @@ def test_sel_categorical_error(self) -> None: with pytest.raises(ValueError): ds.sel(ind="bar", method="nearest") with pytest.raises(ValueError): - ds.sel(ind="bar", tolerance="nearest") + ds.sel(ind="bar", tolerance="nearest") # type: ignore[arg-type] def test_categorical_index(self) -> None: cat = pd.CategoricalIndex( @@ -2044,9 +2046,9 @@ def test_to_pandas(self) -> None: y = np.random.randn(10) t = list("abcdefghij") ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) - actual = ds.to_pandas() - expected = ds.to_dataframe() - assert expected.equals(actual), (expected, actual) + actual_df = ds.to_pandas() + expected_df = ds.to_dataframe() + assert expected_df.equals(actual_df), (expected_df, actual_df) # 2D -> error x2d = np.random.randn(10, 10) @@ -3618,6 +3620,7 @@ def test_reset_index_drop_convert( def test_reorder_levels(self) -> None: ds = create_test_multiindex() mindex = ds["x"].to_index() + assert isinstance(mindex, pd.MultiIndex) midx = mindex.reorder_levels(["level_2", "level_1"]) midx_coords = Coordinates.from_pandas_multiindex(midx, "x") expected = Dataset({}, coords=midx_coords) @@ -3943,7 +3946,9 @@ def test_to_stacked_array_dtype_dims(self) -> None: D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype + mindex = y.xindexes["features"].to_pandas_index() + assert isinstance(mindex, pd.MultiIndex) + assert mindex.levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self) -> None: @@ -4114,9 +4119,9 @@ def test_virtual_variables_default_coords(self) -> None: def test_virtual_variables_time(self) -> None: # access virtual variables data = create_test_data() - assert_array_equal( - data["time.month"].values, data.variables["time"].to_index().month - ) + index = data.variables["time"].to_index() + assert isinstance(index, pd.DatetimeIndex) + assert_array_equal(data["time.month"].values, index.month) assert_array_equal(data["time.season"].values, "DJF") # test virtual variable math assert_array_equal(data["time.dayofyear"] + 1, 2 + np.arange(20)) @@ -4805,20 +4810,20 @@ def test_to_and_from_dataframe(self) -> None: # check pathological cases df = pd.DataFrame([1]) - actual = Dataset.from_dataframe(df) - expected = Dataset({0: ("index", [1])}, {"index": [0]}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset({0: ("index", [1])}, {"index": [0]}) + assert_identical(expected_ds, actual_ds) df = pd.DataFrame() - actual = Dataset.from_dataframe(df) - expected = Dataset(coords={"index": []}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset(coords={"index": []}) + assert_identical(expected_ds, actual_ds) # GH697 df = pd.DataFrame({"A": []}) - actual = Dataset.from_dataframe(df) - expected = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) + assert_identical(expected_ds, actual_ds) # regression test for GH278 # use int64 to ensure consistent results for the pandas .equals method @@ -4857,7 +4862,7 @@ def test_from_dataframe_categorical_index(self) -> None: def test_from_dataframe_categorical_index_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( - np.array([1, 1, 0, 2]), + np.array([1, 1, 0, 2], dtype=np.int64), # type: ignore[arg-type] categories=pd.Index(["foo", "bar", "baz"], dtype="string"), ) ) @@ -4942,7 +4947,7 @@ def test_from_dataframe_unsorted_levels(self) -> None: def test_from_dataframe_non_unique_columns(self) -> None: # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) - df.columns = ["foo", "foo"] + df.columns = ["foo", "foo"] # type: ignore[assignment] with pytest.raises(ValueError, match=r"non-unique columns"): Dataset.from_dataframe(df) @@ -7231,6 +7236,7 @@ def test_cumulative_integrate(dask) -> None: @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) def test_trapezoid_datetime(dask, which_datetime) -> None: rs = np.random.RandomState(42) + coord: ArrayLike if which_datetime == "np": coord = np.array( [ diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 9d0eb81bace..0bd8abc3a70 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -118,9 +118,9 @@ def test_format_items(self) -> None: np.arange(4) * np.timedelta64(500, "ms"), "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000", ), - (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), + (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 ( - pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), + pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00", ), ([1, 2, 3], "1 2 3"), diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f0a0fd14d9d..469e5a3b1f2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -734,7 +734,7 @@ def test_groupby_bins_timeseries() -> None: expected = xr.DataArray( 96 * np.ones((14,)), dims=["time_bins"], - coords={"time_bins": pd.cut(time_bins, time_bins).categories}, + coords={"time_bins": pd.cut(time_bins, time_bins).categories}, # type: ignore[arg-type] ).to_dataset(name="val") assert_identical(actual, expected) @@ -868,7 +868,7 @@ def test_groupby_dataset_errors() -> None: with pytest.raises(ValueError, match=r"length does not match"): data.groupby(data["dim1"][:3]) with pytest.raises(TypeError, match=r"`group` must be"): - data.groupby(data.coords["dim1"].to_index()) + data.groupby(data.coords["dim1"].to_index()) # type: ignore[arg-type] def test_groupby_dataset_reduce() -> None: @@ -1624,7 +1624,7 @@ def test_groupby_bins( bins = [0, 1.5, 5] df = array.to_dataframe() - df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) + df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) # type: ignore[call-overload] expected_df = df.groupby("dim_0_bins", observed=True).sum() # TODO: can't convert df with IntervalIndex to Xarray @@ -1690,7 +1690,7 @@ def test_groupby_bins_empty(self) -> None: array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty bins = [0, 4, 5] - bin_coords = pd.cut(array["x"], bins).categories + bin_coords = pd.cut(array["x"], bins).categories # type: ignore[call-overload] actual = array.groupby_bins("x", bins).sum() expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) assert_identical(expected, actual) @@ -1701,7 +1701,7 @@ def test_groupby_bins_empty(self) -> None: def test_groupby_bins_multidim(self) -> None: array = self.make_groupby_multidim_example_array() bins = [0, 15, 20] - bin_coords = pd.cut(array["lat"].values.flat, bins).categories + bin_coords = pd.cut(array["lat"].values.flat, bins).categories # type: ignore[call-overload] expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) assert_identical(expected, actual) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 5ebdfd5da6e..48e254b037b 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -410,13 +410,15 @@ def test_stack(self) -> None: "y": xr.Variable("y", pd.Index([1, 3, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") - assert index.dim == "z" + assert index_xr.dim == "z" + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) # TODO: change to tuple when pandas 3 is minimum - assert list(index.index.names) == ["x", "y"] + assert list(index_pd.names) == ["x", "y"] np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) with pytest.raises( @@ -433,13 +435,15 @@ def test_stack_non_unique(self) -> None: "y": xr.Variable("y", pd.Index([1, 1, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] ) - np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) - np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + np.testing.assert_array_equal(index_pd.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index_pd.levels[1], [1, 2]) def test_unstack(self) -> None: pd_midx = pd.MultiIndex.from_product( @@ -600,10 +604,7 @@ def indexes( _, variables = indexes_and_vars - if isinstance(x_idx, Index): - index_type = Index - else: - index_type = pd.Index + index_type = Index if isinstance(x_idx, Index) else pd.Index return Indexes(indexes, variables, index_type=index_type) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 578e6bcc18e..a973f6b11f7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -6,7 +6,7 @@ from collections.abc import Generator, Hashable from copy import copy from datetime import date, timedelta -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast import numpy as np import pandas as pd @@ -528,7 +528,7 @@ def test__infer_interval_breaks(self) -> None: [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) ) assert_array_equal( - pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator] _infer_interval_breaks(pd.date_range("20000101", periods=3)), ) @@ -1048,7 +1048,9 @@ def test_list_levels(self) -> None: assert cmap_params["cmap"].N == 5 assert cmap_params["norm"].N == 6 - for wrap_levels in [list, np.array, pd.Index, DataArray]: + for wrap_levels in cast( + list[Callable[[Any], dict[Any, Any]]], [list, np.array, pd.Index, DataArray] + ): cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) assert_array_equal(cmap_params["levels"], orig_levels) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 89f6ebba2c3..79869e63ae7 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -206,9 +206,9 @@ def test_rolling_pandas_compat( index=window, center=center, min_periods=min_periods ).reduce(np.nanmean) - np.testing.assert_allclose(s_rolling.values, da_rolling.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling.values) np.testing.assert_allclose(s_rolling.index, da_rolling["index"]) - np.testing.assert_allclose(s_rolling.values, da_rolling_np.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_np.values) np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"]) @pytest.mark.parametrize("center", (True, False)) @@ -221,12 +221,14 @@ def test_rolling_construct(self, center: bool, window: int) -> None: da_rolling = da.rolling(index=window, center=center, min_periods=1) da_rolling_mean = da_rolling.construct("window").mean("window") - np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_mean.values) np.testing.assert_allclose(s_rolling.index, da_rolling_mean["index"]) # with stride da_rolling_mean = da_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values) + np.testing.assert_allclose( + np.asarray(s_rolling.values[::2]), da_rolling_mean.values + ) np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean["index"]) # with fill_value @@ -649,7 +651,9 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: index=window, center=center, min_periods=min_periods ).mean() - np.testing.assert_allclose(df_rolling["x"].values, ds_rolling["x"].values) + np.testing.assert_allclose( + np.asarray(df_rolling["x"].values), ds_rolling["x"].values + ) np.testing.assert_allclose(df_rolling.index, ds_rolling["index"]) @pytest.mark.parametrize("center", (True, False)) @@ -668,7 +672,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None: ds_rolling = ds.rolling(index=window, center=center) ds_rolling_mean = ds_rolling.construct("window").mean("window") - np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) + np.testing.assert_allclose( + np.asarray(df_rolling["x"].values), ds_rolling_mean["x"].values + ) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) # with fill_value @@ -695,7 +701,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None: ds_rolling = ds.rolling(index=window, center=center) ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") np.testing.assert_allclose( - df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + np.asarray(df_rolling_mean["x"][::2].values), ds_rolling_mean["x"].values ) np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) @@ -704,7 +710,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None: ds2_rolling = ds2.rolling(index=window, center=center) ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") np.testing.assert_allclose( - df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + np.asarray(df_rolling_mean["x"][::2].values), ds2_rolling_mean["x"].values ) # Mixed coordinates, indexes and 2D coordinates diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 081bf09484a..60c173a9e52 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2649,7 +2649,7 @@ def test_tz_datetime(self) -> None: tz = pytz.timezone("America/New_York") times_ns = pd.date_range("2000", periods=1, tz=tz) - times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) + times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) # type: ignore[arg-type] with warnings.catch_warnings(): warnings.simplefilter("ignore") actual: T_DuckArray = as_compatible_data(times_s) @@ -2661,7 +2661,7 @@ def test_tz_datetime(self) -> None: warnings.simplefilter("ignore") actual2: T_DuckArray = as_compatible_data(series) - np.testing.assert_array_equal(actual2, series.values) + np.testing.assert_array_equal(actual2, np.asarray(series.values)) assert actual2.dtype == np.dtype("datetime64[ns]") def test_full_like(self) -> None: @@ -2978,26 +2978,35 @@ def test_datetime_conversion_warning(values, warns) -> None: ) -def test_pandas_two_only_datetime_conversion_warnings() -> None: - # Note these tests rely on pandas features that are only present in pandas - # 2.0.0 and above, and so for now cannot be parametrized. - cases = [ - (pd.date_range("2000", periods=1), "datetime64[s]"), - (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"), - ( - pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), - pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), +tz_ny = pytz.timezone("America/New_York") + + +@pytest.mark.parametrize( + ["data", "dtype"], + [ + pytest.param(pd.date_range("2000", periods=1), "datetime64[s]", id="index-sec"), + pytest.param( + pd.Series(pd.date_range("2000", periods=1)), + "datetime64[s]", + id="series-sec", ), - ( - pd.Series( - pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) - ), - pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), + pytest.param( + pd.date_range("2000", periods=1, tz=tz_ny), + pd.DatetimeTZDtype("s", tz_ny), # type: ignore[arg-type] + id="index-timezone", ), - ] - for data, dtype in cases: - with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): - var = Variable(["time"], data.astype(dtype)) + pytest.param( + pd.Series(pd.date_range("2000", periods=1, tz=tz_ny)), + pd.DatetimeTZDtype("s", tz_ny), # type: ignore[arg-type] + id="series-timezone", + ), + ], +) +def test_pandas_two_only_datetime_conversion_warnings( + data: pd.DatetimeIndex | pd.Series, dtype: str | pd.DatetimeTZDtype +) -> None: + with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): + var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] if var.dtype.kind == "M": assert var.dtype == np.dtype("datetime64[ns]") @@ -3006,9 +3015,7 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. assert isinstance(var._data, PandasIndexingAdapter) - assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("America/New_York") - ) + assert var._data.array.dtype == pd.DatetimeTZDtype("ns", tz_ny) @pytest.mark.parametrize( diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index ee4dd68b3ba..a9f66cdc614 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -88,12 +88,10 @@ def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} return self._binary_op(other, {func})""" template_binop_overload = """ @overload{overload_type_ignore} - def {method}(self, other: {overload_type}) -> {overload_type}: - ... + def {method}(self, other: {overload_type}) -> {overload_type}: ... @overload - def {method}(self, other: {other_type}) -> {return_type}: - ... + def {method}(self, other: {other_type}) -> {return_type}: ... def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} return self._binary_op(other, {func})""" @@ -129,7 +127,7 @@ def {method}(self, *args: Any, **kwargs: Any) -> Self: # The type ignores might not be necessary anymore at some point. # # We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray -# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# In reality this returns NotImplemented, but this is not a valid type in python 3.9. # Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) # TODO: change once python 3.10 is the minimum. # @@ -216,6 +214,10 @@ def unops() -> list[OpsType]: ] +# We use short names T_DA and T_DS to keep below 88 lines so +# ruff does not reformat everything. When reformatting, the +# type-ignores end up in the wrong line :/ + ops_info = {} ops_info["DatasetOpsMixin"] = ( binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() @@ -224,12 +226,12 @@ def unops() -> list[OpsType]: binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() ) ops_info["VariableOpsMixin"] = ( - binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + binops_overload(other_type="VarCompatible", overload_type="T_DA") + inplace(other_type="VarCompatible", type_ignore="misc") + unops() ) ops_info["DatasetGroupByOpsMixin"] = binops( - other_type="GroupByCompatible", return_type="Dataset" + other_type="Dataset | DataArray", return_type="Dataset" ) ops_info["DataArrayGroupByOpsMixin"] = binops( other_type="T_Xarray", return_type="T_Xarray" @@ -237,6 +239,7 @@ def unops() -> list[OpsType]: MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" + # This file was generated using xarray.util.generate_ops. Do not edit manually. from __future__ import annotations @@ -248,15 +251,15 @@ def unops() -> list[OpsType]: from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByCompatible, Self, - T_DataArray, T_Xarray, VarCompatible, ) if TYPE_CHECKING: - from xarray.core.dataset import Dataset''' + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import T_DataArray as T_DA''' CLASS_PREAMBLE = """{newline}