From 11f89ecdd41226cf93da8d1e720d2710849cd23e Mon Sep 17 00:00:00 2001 From: Etienne Schalk <45271239+etienneschalk@users.noreply.github.com> Date: Wed, 13 Mar 2024 16:36:34 +0100 Subject: [PATCH] Do not attempt to broadcast when global option ``arithmetic_broadcast=False`` (#8784) * increase plot size * added old tests * Keep relevant test * what's new * PR comment * remove unnecessary (?) check * unnecessary line removal * removal of variable reassignment to avoid type issue * Update xarray/core/variable.py Co-authored-by: Deepak Cherian * Update xarray/core/variable.py Co-authored-by: Deepak Cherian * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests * what's new * Update doc/whats-new.rst --------- Co-authored-by: Deepak Cherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 3 +++ xarray/core/options.py | 3 +++ xarray/core/variable.py | 10 +++++++++ xarray/tests/__init__.py | 7 +++++++ xarray/tests/test_dataarray.py | 38 ++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eb5dcbda36f..bc603ab61d3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v2024.03.0 (unreleased) New Features ~~~~~~~~~~~~ +- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` + (:issue:`6806`, :pull:`8784`). + By `Etienne Schalk `_ and `Deepak Cherian `_. - Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) By `Anderson Banihirwe `_. diff --git a/xarray/core/options.py b/xarray/core/options.py index 18e3484e9c4..f5614104357 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -34,6 +34,7 @@ ] class T_Options(TypedDict): + arithmetic_broadcast: bool arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] cmap_divergent: str | Colormap cmap_sequential: str | Colormap @@ -59,6 +60,7 @@ class T_Options(TypedDict): OPTIONS: T_Options = { + "arithmetic_broadcast": True, "arithmetic_join": "inner", "cmap_divergent": "RdBu_r", "cmap_sequential": "viridis", @@ -92,6 +94,7 @@ def _positive_integer(value: int) -> bool: _VALIDATORS = { + "arithmetic_broadcast": lambda value: isinstance(value, bool), "arithmetic_join": _JOIN_OPTIONS.__contains__, "display_max_rows": _positive_integer, "display_values_threshold": _positive_integer, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 315c46369bd..cac4d3049e2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2871,6 +2871,16 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]: def _broadcast_compat_data(self, other): + if not OPTIONS["arithmetic_broadcast"]: + if (isinstance(other, Variable) and self.dims != other.dims) or ( + is_duck_array(other) and self.ndim != other.ndim + ): + raise ValueError( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ) + if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]): # `other` satisfies the necessary Variable API for broadcast_variables new_self, new_other = _broadcast_compat_variables(self, other) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index df0899509cb..2e6e638f5b1 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -89,6 +89,13 @@ def _importorskip( has_pynio, requires_pynio = _importorskip("Nio") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The current Dask DataFrame implementation is deprecated.", + category=DeprecationWarning, + ) + has_dask_expr, requires_dask_expr = _importorskip("dask_expr") has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2829fd7d49c..5ab20b2c29c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -51,6 +51,7 @@ requires_bottleneck, requires_cupy, requires_dask, + requires_dask_expr, requires_iris, requires_numexpr, requires_pint, @@ -3203,6 +3204,42 @@ def test_align_str_dtype(self) -> None: assert_identical(expected_b, actual_b) assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast_on_vs_off_global_option_different_dims(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1], dims="x2") + + with xr.set_options(arithmetic_broadcast=True): + expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2")) + actual_xda = xda_1 / xda_2 + assert_identical(actual_xda, expected_xda) + + with xr.set_options(arithmetic_broadcast=False): + with pytest.raises( + ValueError, + match=re.escape( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ), + ): + xda_1 / xda_2 + + @pytest.mark.parametrize("arithmetic_broadcast", [True, False]) + def test_broadcast_on_vs_off_global_option_same_dims( + self, arithmetic_broadcast: bool + ) -> None: + # Ensure that no error is raised when arithmetic broadcasting is disabled, + # when broadcasting is not needed. The two DataArrays have the same + # dimensions of the same size. + xda_1 = xr.DataArray([1], dims="x") + xda_2 = xr.DataArray([1], dims="x") + expected_xda = xr.DataArray([2.0], dims=("x",)) + + with xr.set_options(arithmetic_broadcast=arithmetic_broadcast): + assert_identical(xda_1 + xda_2, expected_xda) + assert_identical(xda_1 + np.array([1.0]), expected_xda) + assert_identical(np.array([1.0]) + xda_1, expected_xda) + def test_broadcast_arrays(self) -> None: x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") @@ -3381,6 +3418,7 @@ def test_to_dataframe_0length(self) -> None: assert len(actual) == 0 assert_array_equal(actual.index.names, list("ABC")) + @requires_dask_expr @requires_dask def test_to_dask_dataframe(self) -> None: arr_np = np.arange(3 * 4).reshape(3, 4)