Merge branch 'main' into grouper-public-api
keewis authored Apr 21, 2024
2 parents 2bf596c + 5f24079 commit 04b433a
Showing 27 changed files with 808 additions and 418 deletions.
12 changes: 11 additions & 1 deletion doc/whats-new.rst
@@ -24,7 +24,11 @@ New Features
~~~~~~~~~~~~
- New "random" method for converting to and from 360_day calendars (:pull:`8603`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.

- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset` objects initialized with a
`pd.Categorical`, for example, will retain the object. However, operations the `ExtensionArray` does not support,
such as broadcasting, remain unavailable.
By `Ilan Gold <https://github.com/ilan-gold>`_.

Breaking changes
~~~~~~~~~~~~~~~~
@@ -36,6 +40,12 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull:`8930`)
By `Eni Awowale <https://github.com/eni-awowale>`_, `Julia Signell <https://github.com/jsignell>`_
and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns
<https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
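To make the ExtensionArray entry above concrete, a minimal sketch (hypothetical data; assumes the `Dataset` constructor accepts a 1D extension array directly, as the entry describes):

import pandas as pd
import xarray as xr

# A Dataset built from a 1D pandas extension array keeps the pandas object
# instead of coercing it to a numpy object array.
ds = xr.Dataset({"cat": ("x", pd.Categorical(["a", "b", "a"]))})
print(ds["cat"].dtype)  # expected: category
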
4 changes: 3 additions & 1 deletion properties/test_pandas_roundtrip.py
@@ -17,7 +17,9 @@
from hypothesis import given # isort:skip

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
npst.unsigned_integer_dtypes(endianness="="),
npst.integer_dtypes(endianness="="),
npst.floating_dtypes(endianness="="),
)

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))
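
As context for the `endianness="="` pins above, a small sketch of why byte order matters for dtype-preserving roundtrips: two dtypes that differ only in byte order do not compare equal.

import numpy as np

native = np.dtype("=i4")          # "=" resolves to the machine's byte order
swapped = native.newbyteorder()   # same kind and itemsize, opposite order

print(native == swapped)  # False: a roundtrip test comparing dtypes would fail
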
14 changes: 8 additions & 6 deletions pyproject.toml
@@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz,dev]"]
dev = [
"hypothesis",
"mypy",
"pre-commit",
"pytest",
"pytest-cov",
@@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
[tool.mypy]
enable_error_code = "redundant-self"
exclude = [
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
]
files = "xarray"
show_error_codes = true
@@ -98,8 +99,8 @@ warn_unused_ignores = true

# Ignore mypy errors for modules imported from datatree_.
[[tool.mypy.overrides]]
module = "xarray.datatree_.*"
ignore_errors = true
module = "xarray.datatree_.*"

# Much of the numerical computing stack doesn't have type annotations yet.
[[tool.mypy.overrides]]
@@ -129,6 +130,7 @@ module = [
"opt_einsum.*",
"pandas.*",
"pooch.*",
"pyarrow.*",
"pydap.*",
"pytest.*",
"scipy.*",
@@ -255,6 +257,9 @@ target-version = "py39"
# E402: module level import not at top of file
# E501: line too long - let black worry about that
# E731: do not assign a lambda expression, use a def
extend-safe-fixes = [
"TID252", # absolute imports
]
ignore = [
"E402",
"E501",
@@ -268,9 +273,6 @@ select = [
"I", # isort
"UP", # Pyupgrade
]
extend-safe-fixes = [
"TID252", # absolute imports
]

[tool.ruff.lint.per-file-ignores]
# don't enforce absolute imports
60 changes: 47 additions & 13 deletions xarray/core/dataset.py
@@ -24,6 +24,7 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
@@ -6853,10 +6854,13 @@ def reduce(
if (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
not is_extension_array_dtype(var.dtype)
and (
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
)
):
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
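
A hedged sketch of what the added `is_extension_array_dtype` guard does: `numeric_only` reductions skip extension-array variables rather than attempting a numpy reduction on them (hypothetical data; assumes `mean` aggregates numeric variables only, as for other Dataset reductions):

import pandas as pd
import xarray as xr

ds = xr.Dataset({
    "x": ("t", [1.0, 2.0, 3.0]),
    "cat": ("t", pd.Categorical(["a", "b", "a"])),
})
# "x" is reduced; the categorical variable is dropped instead of coerced.
print(ds.mean())
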
@@ -7169,13 +7173,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
)

def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
columns_in_order = [k for k in self.variables if k not in self.dims]
non_extension_array_columns = [
k
for k in columns_in_order
if not is_extension_array_dtype(self.variables[k].data)
]
extension_array_columns = [
k
for k in columns_in_order
if is_extension_array_dtype(self.variables[k].data)
]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns
for k in non_extension_array_columns
]
index = self.coords.to_index([*ordered_dims])
return pd.DataFrame(dict(zip(columns, data)), index=index)
broadcasted_df = pd.DataFrame(
dict(zip(non_extension_array_columns, data)), index=index
)
for extension_array_column in extension_array_columns:
extension_array = self.variables[extension_array_column].data.array
index = self[self.variables[extension_array_column].dims[0]].data
extension_array_df = pd.DataFrame(
{extension_array_column: extension_array},
index=self[self.variables[extension_array_column].dims[0]].data,
)
extension_array_df.index.name = self.variables[extension_array_column].dims[
0
]
broadcasted_df = broadcasted_df.join(extension_array_df)
return broadcasted_df[columns_in_order]

def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
@@ -7322,11 +7350,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

# Cast to a NumPy array first, in case the Series is a pandas Extension
# array (which doesn't have a valid NumPy dtype)
# TODO: allow users to control how this casting happens, e.g., by
# forwarding arguments to pandas.Series.to_numpy?
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
arrays = []
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v):
arrays.append((k, np.asarray(v)))
else:
extension_arrays.append((k, v))

indexes: dict[Hashable, Index] = {}
index_vars: dict[Hashable, Variable] = {}
@@ -7340,6 +7370,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
xr_idx = PandasIndex(lev, dim)
indexes[dim] = xr_idx
index_vars.update(xr_idx.create_variables())
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
extension_arrays = []
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
@@ -7353,7 +7385,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
else:
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
return obj
for name, extension_array in extension_arrays:
obj[name] = (dims, extension_array)
return obj[dataframe.columns] if len(dataframe.columns) else obj

def to_dask_dataframe(
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
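
And the inverse path enabled by the `from_dataframe` change: extension-array columns bypass `np.asarray` and are attached to the result afterwards, except in the MultiIndex branch, which still coerces them (hypothetical data):

import pandas as pd
import xarray as xr

df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "a"]), "x": [1.0, 2.0, 3.0]})
ds = xr.Dataset.from_dataframe(df)
print(ds["cat"].dtype)  # expected: category, not object
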
16 changes: 8 additions & 8 deletions xarray/core/datatree.py
@@ -18,6 +18,14 @@
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.core.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
@@ -33,14 +41,6 @@
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin
from xarray.datatree_.datatree.formatting import datatree_repr
from xarray.datatree_.datatree.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.datatree_.datatree.mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
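
After this migration, the helpers shown above resolve from `xarray.core` rather than the vendored `xarray.datatree_` package:

from xarray.core.datatree_mapping import (
    TreeIsomorphismError,
    check_isomorphic,
    map_over_subtree,
)
from xarray.core.formatting_html import datatree_repr as datatree_repr_html
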
xarray/datatree_/datatree/mapping.py → xarray/core/datatree_mapping.py
@@ -4,10 +4,9 @@
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import NodePath, TreeNode

@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
return _map_over_subtree

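A hedged usage sketch of the decorator above (hypothetical function; any `DataTree` works as input):

from xarray.core.datatree_mapping import map_over_subtree

@map_over_subtree
def scale(ds, factor=2.0):
    # Receives each node's Dataset and must return a Dataset.
    return ds * factor

# result = scale(tree)  # a new DataTree with func applied at every node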

def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
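
For reference, one plausible shape of the `add_note` helper the refactor above routes through, assuming a PEP 678-style fallback for Python < 3.11 (its body is mostly elided in this hunk):

import sys

def add_note(err: BaseException, msg: str) -> None:
    if sys.version_info >= (3, 11):
        err.add_note(msg)
    else:
        # PEP 678 fallback: store the note where newer tooling looks for it.
        err.__notes__ = getattr(err, "__notes__", []) + [msg]  # type: ignore[attr-defined]
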
19 changes: 15 additions & 4 deletions xarray/core/duck_array_ops.py
@@ -32,6 +32,7 @@
from numpy import concatenate as _concatenate
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
@@ -156,7 +157,7 @@ def isnull(data):
return full_like(data, dtype=bool, fill_value=False)
else:
# at this point, array should have dtype=object
if isinstance(data, np.ndarray):
if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
return pandas_isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
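
A quick sketch of what the added extension-array branch buys: `pd.isnull` understands pandas missing-value semantics that plain numpy checks would not.

import pandas as pd

cat = pd.Categorical(["a", None, "b"])
print(pd.isnull(cat))  # [False  True False]
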
@@ -221,9 +222,19 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
raise ValueError(
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
)
elif array_type_cupy := array_type("cupy") and any( # noqa: F841
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821
):
import cupy as cp

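
A hedged sketch of the rule the new extension-array branch encodes, calling the internal helper directly (illustrative only): inputs that all share one extension dtype pass through unchanged, while mixed extension dtypes raise.

import pandas as pd
from xarray.core.duck_array_ops import as_shared_dtype

same = [pd.array(["a"], dtype="category"), pd.array(["b"], dtype="category")]
print(as_shared_dtype(same) is same)  # True: returned as-is

mixed = [pd.array(["a"], dtype="category"), pd.array([1], dtype="Int64")]
try:
    as_shared_dtype(mixed)
except ValueError as err:
    print(err)  # Cannot cast arrays to shared type, found array types ...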