From c50100f34f4946ae0a69529a6699a4372bd8118b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 6 Oct 2024 17:43:49 +0900 Subject: [PATCH 01/12] Reimplement DataTree aggregations They now allow for dimensions that are missing on particular nodes, and use Xarray's standard generate_aggregations machinery, like aggregations for DataArray and Dataset. Fixes https://github.com/pydata/xarray/issues/8949, https://github.com/pydata/xarray/issues/8963 --- xarray/core/_aggregations.py | 1308 ++++++++++++++++++++++++++ xarray/core/dataset.py | 14 +- xarray/core/datatree.py | 100 +- xarray/core/utils.py | 20 + xarray/namedarray/_aggregations.py | 103 +- xarray/tests/test_datatree.py | 41 +- xarray/util/generate_aggregations.py | 44 +- 7 files changed, 1536 insertions(+), 94 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index b557ad44a32..6b1029791ea 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -19,6 +19,1314 @@ flox_available = module_available("flox") +class DataTreeAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.count() + + Group: / + Dimensions: () + Data variables: + foo int64 8B 5 + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." 
or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.all() + + Group: / + Dimensions: () + Data variables: + foo bool 1B False + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.any() + + Group: / + Dimensions: () + Data variables: + foo bool 1B True + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``max`` along some dimension(s). 
+ + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.max() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 3.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.max(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... 
xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.min() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.min(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.mean() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.6 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.mean(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." 
or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.prod() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.prod(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.prod(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. 
If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.sum() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.sum(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.sum(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.std() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.02 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.std(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.std(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.14 + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 + labels (time) >> dt.var() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.04 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.var(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.var(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.3 + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.median() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 2.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.median(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). 
By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumsum() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumsum(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumprod() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumprod(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + class DatasetAggregations: __slots__ = () diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d57a6957553..f3c24be74b6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -109,6 +109,7 @@ OrderedSet, _default, decode_numpy_dict_values, + dim_arg_to_dims_set, drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, @@ -6986,18 +6987,7 @@ def reduce( " Please use 'dim' instead." 
) - if dim is None or dim is ...: - dims = set(self.dims) - elif isinstance(dim, str) or not isinstance(dim, Iterable): - dims = {dim} - else: - dims = set(dim) - - missing_dimensions = tuple(d for d in dims if d not in self.dims) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}" - ) + dims = dim_arg_to_dims_set(dim, self.dims) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f8da26c28b5..04f761fc3f9 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -11,9 +11,10 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Self, Union, overload from xarray.core import utils +from xarray.core._aggregations import DataTreeAggregations from xarray.core.alignment import align from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import Coordinates, DataTreeCoordinates @@ -37,6 +38,8 @@ FilteredMapping, Frozen, _default, + dim_arg_to_dims_set, + drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, ) @@ -54,7 +57,13 @@ from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleMapping, CoercibleValue - from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes + from xarray.core.types import ( + Dims, + ErrorOptions, + ErrorOptionsWithWarn, + NetcdfWriteModes, + ZarrWriteModes, + ) # """ # DEVELOPERS' NOTE @@ -398,6 +407,7 @@ def map( # type: ignore[override] class DataTree( NamedNode["DataTree"], + DataTreeAggregations, TreeAttrAccessMixin, Mapping[str, "DataArray | DataTree"], ): @@ -1607,3 +1617,89 @@ def to_zarr( compute=compute, **kwargs, ) + + def _get_all_dims(self) -> set: + all_dims = set() + for node in self.subtree: + all_dims.update(node._node_dims) + return all_dims + + def _selective_indexing( + self, + func: Callable[[Dataset, Mapping[Any, Any]]], + indexers: Mapping[Any, Any], + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + indexers = drop_dims_from_indexers(indexers, self._get_all_dims(), missing_dims) + result = {} + for node in self.subtree: + node_indexers = {k: v for k, v in indexers.items() if k in node.dims} + node_result = func(node.dataset, node_indexers) + for k in node_indexers: + if k not in node.coords and k in node_result.coords: + del node_result.coords[k] + result[node.path] = node_result + return type(self).from_dict(result, name=self.name) # type: ignore + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + drop: bool = False, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Positional indexing.""" + + def apply_indexers(dataset, node_indexers): + return dataset.isel(node_indexers, drop=drop) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") + return self._selective_indexing( + apply_indexers, indexers, missing_dims=missing_dims + ) + + def sel( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, + drop: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """Label-based indexing.""" + + def apply_indexers(dataset, node_indexers): + # TODO: reimplement in terms of map_index_queries(), to avoid + # redundant index look-ups on child nodes + return dataset.sel( + 
node_indexers, method=method, tolerance=tolerance, drop=drop + ) + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") + return self._selective_indexing(apply_indexers, indexers) + + def reduce( + self, + func: Callable, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + keepdims: bool = False, + numeric_only: bool = False, + **kwargs: Any, + ) -> Self: + """Reduce this tree by applying `func` along some dimension(s).""" + dims = dim_arg_to_dims_set(dim, self._get_all_dims()) + result = {} + for node in self.subtree: + reduce_dims = [d for d in node._node_dims if d in dims] + node_result = node.dataset.reduce( + func, + reduce_dims, + keep_attrs=keep_attrs, + keepdims=keepdims, + numeric_only=numeric_only, + **kwargs, + ) + result[node.path] = node_result + return type(self).from_dict(result, name=self.name) # type: ignore diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e5168342e1e..817586f0a0d 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -830,6 +830,26 @@ def drop_dims_from_indexers( ) +def dim_arg_to_dims_set(dim: Dims, all_dims: Collection) -> set: + """Convert a `dim` argument from Dataset/DataTree into a set of dimensions.""" + + if dim is None or dim is ...: + dims = set(all_dims) + elif isinstance(dim, str) or not isinstance(dim, Iterable): + # TODO: consider dropping `not isinstance(dim, Iterable)`, which is not + # allowed per the type signature + dims = {dim} + else: + dims = set(dim) + + missing_dimensions = tuple(d for d in dims if d not in all_dims) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} not found in data dimensions {tuple(all_dims)}" + ) + return dims + + @overload def parse_dims( dim: Dims, diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py index 48001422386..ff6bb06456f 100644 --- a/xarray/namedarray/_aggregations.py +++ b/xarray/namedarray/_aggregations.py @@ -61,17 +61,14 @@ def count( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.count() Size: 8B - array(5) + np.int64(5) """ return self.reduce( duck_array_ops.count, @@ -116,8 +113,7 @@ def all( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... ) >>> na Size: 6B @@ -125,7 +121,7 @@ def all( >>> na.all() Size: 1B - array(False) + np.False_ """ return self.reduce( duck_array_ops.array_all, @@ -170,8 +166,7 @@ def any( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... ) >>> na Size: 6B @@ -179,7 +174,7 @@ def any( >>> na.any() Size: 1B - array(True) + np.True_ """ return self.reduce( duck_array_ops.array_any, @@ -230,23 +225,20 @@ def max( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.max() Size: 8B - array(3.) + np.float64(3.0) Use ``skipna`` to control whether NaNs are ignored. 
>>> na.max(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.max, @@ -298,23 +290,20 @@ def min( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.min() Size: 8B - array(0.) + np.float64(0.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.min(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.min, @@ -370,23 +359,20 @@ def mean( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.mean() Size: 8B - array(1.6) + np.float64(1.6) Use ``skipna`` to control whether NaNs are ignored. >>> na.mean(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.mean, @@ -449,23 +435,20 @@ def prod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.prod() Size: 8B - array(0.) + np.float64(0.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.prod(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``min_count`` for finer control over when NaNs are ignored. @@ -535,23 +518,20 @@ def sum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.sum() Size: 8B - array(8.) + np.float64(8.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.sum(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``min_count`` for finer control over when NaNs are ignored. @@ -618,29 +598,26 @@ def std( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.std() Size: 8B - array(1.0198039) + np.float64(1.019803902718557) Use ``skipna`` to control whether NaNs are ignored. >>> na.std(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``ddof=1`` for an unbiased estimate. >>> na.std(skipna=True, ddof=1) Size: 8B - array(1.14017543) + np.float64(1.140175425099138) """ return self.reduce( duck_array_ops.std, @@ -701,29 +678,26 @@ def var( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.var() Size: 8B - array(1.04) + np.float64(1.04) Use ``skipna`` to control whether NaNs are ignored. >>> na.var(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``ddof=1`` for an unbiased estimate. 
>>> na.var(skipna=True, ddof=1) Size: 8B - array(1.3) + np.float64(1.3) """ return self.reduce( duck_array_ops.var, @@ -780,23 +754,20 @@ def median( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.median() Size: 8B - array(2.) + np.float64(2.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.median(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.median, @@ -857,10 +828,7 @@ def cumsum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -934,10 +902,7 @@ def cumprod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3a3afb0647a..9eb819be48d 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1582,7 +1582,9 @@ def test_dataset_method(self): result = dt.isel(x=1) assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") + +class TestAggregations: + def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1592,7 +1594,6 @@ def test_reduce_method(self): result = dt.any() assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1602,7 +1603,6 @@ def test_nan_reduce_method(self): result = dt.mean() assert_equal(result, expected) - @pytest.mark.xfail(reason="cum methods not implemented yet") def test_cum_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1617,6 +1617,41 @@ def test_cum_method(self): result = dt.cumsum() assert_equal(result, expected) + def test_dim_argument(self): + dt = DataTree.from_dict( + { + "/a": xr.Dataset({"A": ("x", [1, 2])}), + "/b": xr.Dataset({"B": ("y", [1, 2])}), + } + ) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": 1.5}), + } + ) + actual = dt.mean() + assert_equal(expected, actual) + + actual = dt.mean(dim=...) 
+ assert_equal(expected, actual) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": ("y", [1.0, 2.0])}), + } + ) + actual = dt.mean("x") + assert_equal(expected, actual) + + with pytest.raises( + ValueError, + match=re.escape("Dimensions ('invalid',) not found in data dimensions"), + ): + dt.mean("invalid") + class TestOps: @pytest.mark.xfail(reason="arithmetic not implemented yet") diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index d2fc4f6d4e2..089ef558581 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -263,7 +263,7 @@ class DataStructure: create_example: str example_var_name: str numeric_only: bool = False - see_also_modules: tuple[str] = tuple + see_also_modules: tuple[str, ...] = tuple class Method: @@ -287,13 +287,13 @@ def __init__( self.additional_notes = additional_notes if bool_reduce: self.array_method = f"array_{name}" - self.np_example_array = """ - ... np.array([True, True, True, True, True, False], dtype=bool)""" + self.np_example_array = ( + """np.array([True, True, True, True, True, False], dtype=bool)""" + ) else: self.array_method = name - self.np_example_array = """ - ... np.array([1, 2, 3, 0, 2, np.nan])""" + self.np_example_array = """np.array([1, 2, 3, 0, 2, np.nan])""" @dataclass @@ -541,10 +541,27 @@ def generate_code(self, method, has_keep_attrs): ) +DATATREE_OBJECT = DataStructure( + name="DataTree", + create_example=""" + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", {example_array})), + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... )""", + example_var_name="dt", + numeric_only=True, + see_also_modules=("Dataset", "DataArray"), +) DATASET_OBJECT = DataStructure( name="Dataset", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -559,7 +576,8 @@ def generate_code(self, method, has_keep_attrs): DATAARRAY_OBJECT = DataStructure( name="DataArray", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -570,6 +588,15 @@ def generate_code(self, method, has_keep_attrs): numeric_only=False, see_also_modules=("Dataset",), ) +DATATREE_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATATREE_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) DATASET_GENERATOR = GenericAggregationGenerator( cls="", datastructure=DATASET_OBJECT, @@ -634,7 +661,7 @@ def generate_code(self, method, has_keep_attrs): create_example=""" >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x",{example_array}, + ... "x", {example_array} ... 
)""", example_var_name="na", numeric_only=False, @@ -670,6 +697,7 @@ def write_methods(filepath, generators, preamble): write_methods( filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", generators=[ + DATATREE_GENERATOR, DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR, From 0feb0292eafc0071320875867d2943ab62e3b168 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:01:30 +0900 Subject: [PATCH 02/12] add API docs on DataTree aggregations --- doc/api.rst | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 63fb59bc5e0..49f1aa2d03d 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -827,30 +827,30 @@ Compare one ``DataTree`` object to another. .. DataTree.polyfit .. DataTree.curvefit -.. Aggregation -.. ----------- +Aggregation +----------- -.. Aggregate data in all nodes in the subtree simultaneously. +Aggregate data in all nodes in the subtree simultaneously. -.. .. autosummary:: -.. :toctree: generated/ +.. autosummary:: + :toctree: generated/ -.. DataTree.all -.. DataTree.any -.. DataTree.argmax -.. DataTree.argmin -.. DataTree.idxmax -.. DataTree.idxmin -.. DataTree.max -.. DataTree.min -.. DataTree.mean -.. DataTree.median -.. DataTree.prod -.. DataTree.sum -.. DataTree.std -.. DataTree.var -.. DataTree.cumsum -.. DataTree.cumprod + DataTree.all + DataTree.any + DataTree.argmax + DataTree.argmin + DataTree.idxmax + DataTree.idxmin + DataTree.max + DataTree.min + DataTree.mean + DataTree.median + DataTree.prod + DataTree.sum + DataTree.std + DataTree.var + DataTree.cumsum + DataTree.cumprod .. ndarray methods .. --------------- From c9ea92a73b6bf1fc6c5648bb4f41f41983077f8f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:13:04 +0900 Subject: [PATCH 03/12] remove incorrectly added sel methods --- xarray/core/datatree.py | 56 ----------------------------------------- 1 file changed, 56 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 04f761fc3f9..bb24c88a907 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -39,7 +39,6 @@ Frozen, _default, dim_arg_to_dims_set, - drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, ) @@ -60,7 +59,6 @@ from xarray.core.types import ( Dims, ErrorOptions, - ErrorOptionsWithWarn, NetcdfWriteModes, ZarrWriteModes, ) @@ -1624,60 +1622,6 @@ def _get_all_dims(self) -> set: all_dims.update(node._node_dims) return all_dims - def _selective_indexing( - self, - func: Callable[[Dataset, Mapping[Any, Any]]], - indexers: Mapping[Any, Any], - missing_dims: ErrorOptionsWithWarn = "raise", - ) -> Self: - indexers = drop_dims_from_indexers(indexers, self._get_all_dims(), missing_dims) - result = {} - for node in self.subtree: - node_indexers = {k: v for k, v in indexers.items() if k in node.dims} - node_result = func(node.dataset, node_indexers) - for k in node_indexers: - if k not in node.coords and k in node_result.coords: - del node_result.coords[k] - result[node.path] = node_result - return type(self).from_dict(result, name=self.name) # type: ignore - - def isel( - self, - indexers: Mapping[Any, Any] | None = None, - drop: bool = False, - missing_dims: ErrorOptionsWithWarn = "raise", - **indexers_kwargs: Any, - ) -> Self: - """Positional indexing.""" - - def apply_indexers(dataset, node_indexers): - return dataset.isel(node_indexers, drop=drop) - - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") - return 
self._selective_indexing( - apply_indexers, indexers, missing_dims=missing_dims - ) - - def sel( - self, - indexers: Mapping[Any, Any] | None = None, - method: str | None = None, - tolerance: int | float | Iterable[int | float] | None = None, - drop: bool = False, - **indexers_kwargs: Any, - ) -> Self: - """Label-based indexing.""" - - def apply_indexers(dataset, node_indexers): - # TODO: reimplement in terms of map_index_queries(), to avoid - # redundant index look-ups on child nodes - return dataset.sel( - node_indexers, method=method, tolerance=tolerance, drop=drop - ) - - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - return self._selective_indexing(apply_indexers, indexers) - def reduce( self, func: Callable, From 7528d86a8f89cae7b62314e2a8a4d2d05bd4e1d3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:14:52 +0900 Subject: [PATCH 04/12] fix docstring reprs --- xarray/namedarray/_aggregations.py | 42 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py index ff6bb06456f..139cea83b5b 100644 --- a/xarray/namedarray/_aggregations.py +++ b/xarray/namedarray/_aggregations.py @@ -68,7 +68,7 @@ def count( >>> na.count() Size: 8B - np.int64(5) + array(5) """ return self.reduce( duck_array_ops.count, @@ -121,7 +121,7 @@ def all( >>> na.all() Size: 1B - np.False_ + array(False) """ return self.reduce( duck_array_ops.array_all, @@ -174,7 +174,7 @@ def any( >>> na.any() Size: 1B - np.True_ + array(True) """ return self.reduce( duck_array_ops.array_any, @@ -232,13 +232,13 @@ def max( >>> na.max() Size: 8B - np.float64(3.0) + array(3.) Use ``skipna`` to control whether NaNs are ignored. >>> na.max(skipna=False) Size: 8B - np.float64(nan) + array(nan) """ return self.reduce( duck_array_ops.max, @@ -297,13 +297,13 @@ def min( >>> na.min() Size: 8B - np.float64(0.0) + array(0.) Use ``skipna`` to control whether NaNs are ignored. >>> na.min(skipna=False) Size: 8B - np.float64(nan) + array(nan) """ return self.reduce( duck_array_ops.min, @@ -366,13 +366,13 @@ def mean( >>> na.mean() Size: 8B - np.float64(1.6) + array(1.6) Use ``skipna`` to control whether NaNs are ignored. >>> na.mean(skipna=False) Size: 8B - np.float64(nan) + array(nan) """ return self.reduce( duck_array_ops.mean, @@ -442,13 +442,13 @@ def prod( >>> na.prod() Size: 8B - np.float64(0.0) + array(0.) Use ``skipna`` to control whether NaNs are ignored. >>> na.prod(skipna=False) Size: 8B - np.float64(nan) + array(nan) Specify ``min_count`` for finer control over when NaNs are ignored. @@ -525,13 +525,13 @@ def sum( >>> na.sum() Size: 8B - np.float64(8.0) + array(8.) Use ``skipna`` to control whether NaNs are ignored. >>> na.sum(skipna=False) Size: 8B - np.float64(nan) + array(nan) Specify ``min_count`` for finer control over when NaNs are ignored. @@ -605,19 +605,19 @@ def std( >>> na.std() Size: 8B - np.float64(1.019803902718557) + array(1.0198039) Use ``skipna`` to control whether NaNs are ignored. >>> na.std(skipna=False) Size: 8B - np.float64(nan) + array(nan) Specify ``ddof=1`` for an unbiased estimate. >>> na.std(skipna=True, ddof=1) Size: 8B - np.float64(1.140175425099138) + array(1.14017543) """ return self.reduce( duck_array_ops.std, @@ -685,19 +685,19 @@ def var( >>> na.var() Size: 8B - np.float64(1.04) + array(1.04) Use ``skipna`` to control whether NaNs are ignored. >>> na.var(skipna=False) Size: 8B - np.float64(nan) + array(nan) Specify ``ddof=1`` for an unbiased estimate. 
>>> na.var(skipna=True, ddof=1) Size: 8B - np.float64(1.3) + array(1.3) """ return self.reduce( duck_array_ops.var, @@ -761,13 +761,13 @@ def median( >>> na.median() Size: 8B - np.float64(2.0) + array(2.) Use ``skipna`` to control whether NaNs are ignored. >>> na.median(skipna=False) Size: 8B - np.float64(nan) + array(nan) """ return self.reduce( duck_array_ops.median, From 11359e5f9eb3513f56e40ec0993a042ff98ed8fc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:27:48 +0900 Subject: [PATCH 05/12] mypy fix --- xarray/core/datatree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bb24c88a907..a8148624bf2 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1095,7 +1095,7 @@ def from_dict( d: Mapping[str, Dataset | DataTree | None], /, name: str | None = None, - ) -> DataTree: + ) -> Self: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1646,4 +1646,4 @@ def reduce( **kwargs, ) result[node.path] = node_result - return type(self).from_dict(result, name=self.name) # type: ignore + return type(self).from_dict(result, name=self.name) From d3795855940ea210d147d423b34b60fed138c13d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 21:31:58 +0900 Subject: [PATCH 06/12] fix self import --- xarray/core/datatree.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index a8148624bf2..9dbb7f04094 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -11,7 +11,7 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Self, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload from xarray.core import utils from xarray.core._aggregations import DataTreeAggregations @@ -33,6 +33,7 @@ from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS from xarray.core.treenode import NamedNode, NodePath +from xarray.core.types import Self from xarray.core.utils import ( Default, FilteredMapping, From 49ef260ec8d8cd9bc9dbc233d7bf01a1e20b1978 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 7 Oct 2024 22:11:15 +0900 Subject: [PATCH 07/12] remove unimplemented agg methods --- doc/api.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 49f1aa2d03d..646a1fb0fa1 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -837,10 +837,6 @@ Aggregate data in all nodes in the subtree simultaneously. 
DataTree.all DataTree.any - DataTree.argmax - DataTree.argmin - DataTree.idxmax - DataTree.idxmin DataTree.max DataTree.min DataTree.mean From a9a940e8d6eef9b6495551ec0058d164c303723e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 10 Oct 2024 17:53:50 +0900 Subject: [PATCH 08/12] replace dim_arg_to_dims_set with parse_dims --- xarray/core/dataset.py | 4 ++-- xarray/core/datatree.py | 4 ++-- xarray/core/utils.py | 20 -------------------- xarray/tests/test_datatree.py | 2 +- 4 files changed, 5 insertions(+), 25 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f3c24be74b6..ff80248cf48 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -109,7 +109,6 @@ OrderedSet, _default, decode_numpy_dict_values, - dim_arg_to_dims_set, drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, @@ -119,6 +118,7 @@ is_duck_dask_array, is_scalar, maybe_wrap_array, + parse_dims, ) from xarray.core.variable import ( IndexVariable, @@ -6987,7 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - dims = dim_arg_to_dims_set(dim, self.dims) + dims = set(parse_dims(dim, tuple(self.dims))) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 9dbb7f04094..e1d352e3ff9 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -39,9 +39,9 @@ FilteredMapping, Frozen, _default, - dim_arg_to_dims_set, either_dict_or_kwargs, maybe_wrap_array, + parse_dims, ) from xarray.core.variable import Variable @@ -1634,7 +1634,7 @@ def reduce( **kwargs: Any, ) -> Self: """Reduce this tree by applying `func` along some dimension(s).""" - dims = dim_arg_to_dims_set(dim, self._get_all_dims()) + dims = set(parse_dims(dim, tuple(self._get_all_dims()))) result = {} for node in self.subtree: reduce_dims = [d for d in node._node_dims if d in dims] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 817586f0a0d..e5168342e1e 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -830,26 +830,6 @@ def drop_dims_from_indexers( ) -def dim_arg_to_dims_set(dim: Dims, all_dims: Collection) -> set: - """Convert a `dim` argument from Dataset/DataTree into a set of dimensions.""" - - if dim is None or dim is ...: - dims = set(all_dims) - elif isinstance(dim, str) or not isinstance(dim, Iterable): - # TODO: consider dropping `not isinstance(dim, Iterable)`, which is not - # allowed per the type signature - dims = {dim} - else: - dims = set(dim) - - missing_dimensions = tuple(d for d in dims if d not in all_dims) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} not found in data dimensions {tuple(all_dims)}" - ) - return dims - - @overload def parse_dims( dim: Dims, diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9eb819be48d..1f6b41eafa7 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1648,7 +1648,7 @@ def test_dim_argument(self): with pytest.raises( ValueError, - match=re.escape("Dimensions ('invalid',) not found in data dimensions"), + match=re.escape("Dimension(s) 'invalid' do not exist."), ): dt.mean("invalid") From 830f79753a6082b68f4f77af8923aea21ef17c89 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 10 Oct 2024 18:25:33 +0900 Subject: [PATCH 09/12] add parse_dims_as_set --- xarray/core/computation.py | 8 +++--- xarray/core/dataset.py | 4 +-- xarray/core/datatree.py | 4 +-- xarray/core/utils.py | 52 ++++++++++++++++++++++++++++++++++---- 
xarray/tests/test_utils.py | 10 ++++---- 5 files changed, 59 insertions(+), 19 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 91a184d55cd..a5b2e72e5c5 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -31,7 +31,7 @@ from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import Dims, T_DataArray -from xarray.core.utils import is_dict_like, is_scalar, parse_dims +from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array @@ -1846,11 +1846,9 @@ def dot( dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dim = tuple(d for d, c in dim_counts.items() if c > 1) + dot_dims = {d for d, c in dim_counts.items() if c > 1} else: - dim = parse_dims(dim, all_dims=tuple(all_dims)) - - dot_dims: set[Hashable] = set(dim) + dot_dims = parse_dims_as_set(dim, all_dims=set(all_dims)) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0f29ee13176..afb78a72711 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -118,7 +118,7 @@ is_duck_dask_array, is_scalar, maybe_wrap_array, - parse_dims, + parse_dims_as_set, ) from xarray.core.variable import ( IndexVariable, @@ -6987,7 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - dims = set(parse_dims(dim, tuple(self.dims))) + dims = parse_dims_as_set(dim, self._dims.keys()) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 52159cf344d..980937bf711 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -43,7 +43,7 @@ drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, - parse_dims, + parse_dims_as_set, ) from xarray.core.variable import Variable @@ -1631,7 +1631,7 @@ def reduce( **kwargs: Any, ) -> Self: """Reduce this tree by applying `func` along some dimension(s).""" - dims = set(parse_dims(dim, tuple(self._get_all_dims()))) + dims = parse_dims_as_set(dim, self._get_all_dims()) result = {} for node in self.subtree: reduce_dims = [d for d in node._node_dims if d in dims] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e5168342e1e..3ed26e02f02 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -58,6 +58,7 @@ Mapping, MutableMapping, MutableSet, + Set, Sequence, ValuesView, ) @@ -831,7 +832,7 @@ def drop_dims_from_indexers( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -841,7 +842,7 @@ def parse_dims( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -850,7 +851,7 @@ def parse_dims( ) -> tuple[Hashable, ...] | None | EllipsisType: ... -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -891,6 +892,47 @@ def parse_dims( return tuple(dim) +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: Set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[True] = True, +) -> Set[Hashable]: ... + + +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: Set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[False], +) -> Set[Hashable] | None | EllipsisType: ... 
+ + +def parse_dims_as_set( + dim: Dims, + all_dims: Set[Hashable], + *, + check_exists: bool = True, + replace_none: bool = True, +) -> Set[Hashable] | None | EllipsisType: + """Like parse_dims_as_tuple, but returning a set instead of a tuple.""" + # TODO: Consider removing parse_dims_as_tuple? + if dim is None or dim is ...: + if replace_none: + return all_dims + return dim + if isinstance(dim, str): + dim = {dim} + dim = set(dim) + if check_exists: + _check_dims(dim, all_dims) + return dim + + @overload def parse_ordered_dims( dim: Dims, @@ -958,7 +1000,7 @@ def parse_ordered_dims( return dims[:idx] + other_dims + dims[idx + 1 :] else: # mypy cannot resolve that the sequence cannot contain "..." - return parse_dims( # type: ignore[call-overload] + return parse_dims_as_tuple( # type: ignore[call-overload] dim=dim, all_dims=all_dims, check_exists=check_exists, @@ -966,7 +1008,7 @@ def parse_ordered_dims( ) -def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: +def _check_dims(dim: Set[Hashable], all_dims: Set[Hashable]) -> None: wrong_dims = (dim - all_dims) - {...} if wrong_dims: wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 9ef4a688302..f62fbb63cb5 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -283,16 +283,16 @@ def test_infix_dims_errors(supplied, all_): pytest.param(..., ..., id="ellipsis"), ], ) -def test_parse_dims(dim, expected) -> None: +def test_parse_dims_as_tuple(dim, expected) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=False) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=False) assert actual == expected def test_parse_dims_set() -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables dim = {"a", 1} - actual = utils.parse_dims(dim, all_dims) + actual = utils.parse_dims_as_tuple(dim, all_dims) assert set(actual) == dim @@ -301,7 +301,7 @@ def test_parse_dims_set() -> None: ) def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=True) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=True) assert actual == all_dims @@ -316,7 +316,7 @@ def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: def test_parse_dims_raises(dim) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables with pytest.raises(ValueError, match="'x'"): - utils.parse_dims(dim, all_dims, check_exists=True) + utils.parse_dims_as_tuple(dim, all_dims, check_exists=True) @pytest.mark.parametrize( From c25b7c97775fc72ba2fcaf7771d0832aafea9a33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:27:21 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/datatree.py | 1 - xarray/core/utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 980937bf711..5fdc5f29682 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -34,7 +34,6 @@ from xarray.core.options import OPTIONS as XR_OPTS from xarray.core.treenode import NamedNode, NodePath from xarray.core.types import Self -from xarray.core.types import Self 
from xarray.core.utils import ( Default, FilteredMapping, diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 3ed26e02f02..01aeec3cd7c 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -58,8 +58,8 @@ Mapping, MutableMapping, MutableSet, - Set, Sequence, + Set, ValuesView, ) from enum import Enum From edf554175eea6c48f1faeafd2d9ce2851db1fc26 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Sat, 12 Oct 2024 23:49:08 -0400 Subject: [PATCH 11/12] fix mypy errors --- xarray/core/computation.py | 1 + xarray/core/dataset.py | 2 +- xarray/core/utils.py | 12 ++++++------ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a5b2e72e5c5..e2a6676252a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1841,6 +1841,7 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + dot_dims: set[Hashable] if dim is None: # find dimensions that occur more than once dim_counts: Counter = Counter() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index afb78a72711..e0cd92bab6e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6987,7 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - dims = parse_dims_as_set(dim, self._dims.keys()) + dims = parse_dims_as_set(dim, set(self._dims.keys())) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 01aeec3cd7c..e2781366265 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -895,30 +895,30 @@ def parse_dims_as_tuple( @overload def parse_dims_as_set( dim: Dims, - all_dims: Set[Hashable], + all_dims: set[Hashable], *, check_exists: bool = True, replace_none: Literal[True] = True, -) -> Set[Hashable]: ... +) -> set[Hashable]: ... @overload def parse_dims_as_set( dim: Dims, - all_dims: Set[Hashable], + all_dims: set[Hashable], *, check_exists: bool = True, replace_none: Literal[False], -) -> Set[Hashable] | None | EllipsisType: ... +) -> set[Hashable] | None | EllipsisType: ... def parse_dims_as_set( dim: Dims, - all_dims: Set[Hashable], + all_dims: set[Hashable], *, check_exists: bool = True, replace_none: bool = True, -) -> Set[Hashable] | None | EllipsisType: +) -> set[Hashable] | None | EllipsisType: """Like parse_dims_as_tuple, but returning a set instead of a tuple.""" # TODO: Consider removing parse_dims_as_tuple? 
if dim is None or dim is ...: From 4bab2beb26b001df4bf1a07fabc92ad54bf78057 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Sat, 12 Oct 2024 23:59:05 -0400 Subject: [PATCH 12/12] change tests to match slightly different error now thrown --- xarray/tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1178498de19..eafc11b630c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5615,7 +5615,7 @@ def test_reduce_bad_dim(self) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): data.mean(dim="bad_dim") @@ -5644,7 +5644,7 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): getattr(data, func)(dim="bad_dim")
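
For reference, a minimal sketch of the subtree-wide aggregation behavior this series
enables. The tree layout, variable names, and values below are illustrative assumptions,
not taken from the patches themselves:

    import numpy as np
    import xarray as xr

    dt = xr.DataTree.from_dict(
        {
            "/": xr.Dataset({"a": ("x", np.array([1.0, 2.0, np.nan]))}),
            "/child": xr.Dataset({"b": ("y", np.array([4.0, 5.0]))}),
        }
    )

    # "x" exists only on the root node, while "/child" carries only "y".
    # With the reimplemented reduce(), each node is reduced only over the
    # requested dimensions it actually has, so this aggregates "a" and
    # should leave "b" untouched instead of raising.
    dt.mean(dim="x")

    # dim=None (or dim=...) still reduces every node over all of its own
    # dimensions, as in the dt.count() / dt.all() docstring examples.
    dt.mean()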
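Similarly, a short sketch of the parse_dims_as_set helper introduced in patches 08
through 11. The sample dimension names here are assumptions, but the behavior and the
error message follow the implementation and the updated tests above:

    from xarray.core.utils import parse_dims_as_set

    all_dims = {"x", "y"}

    parse_dims_as_set("x", all_dims)         # {"x"}: a str becomes a one-element set
    parse_dims_as_set(["x", "y"], all_dims)  # {"x", "y"}
    parse_dims_as_set(None, all_dims)        # {"x", "y"}: replace_none=True by default
    parse_dims_as_set(..., all_dims)         # {"x", "y"}: ... is treated like None
    parse_dims_as_set(None, all_dims, replace_none=False)  # returns None unchanged

    # With check_exists=True (the default), unknown names raise the error
    # the updated tests assert against:
    parse_dims_as_set("z", all_dims)
    # raises ValueError matching "Dimension(s) 'z' do not exist."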