diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 08d6d23aeab..13a48237f01 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ New Features `Nathan Lis `_. - Histogram plots are set with a title displaying the scalar coords if any, similarly to the other plots (:issue:`5791`, :pull:`5792`). By `Maxime Liquet `_. +- Slice plots display the coords units in the same way as x/y/colorbar labels (:pull:`5847`). + By `Victor Negîrneac `_. - Added a new :py:attr:`Dataset.chunksizes`, :py:attr:`DataArray.chunksizes`, and :py:attr:`Variable.chunksizes` property, which will always return a mapping from dimension names to chunking pattern along that dimension, regardless of whether the object is a Dataset, DataArray, or Variable. (:issue:`5846`, :pull:`5900`) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c6f534796d1..89f916db7f4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -22,6 +22,7 @@ import pandas as pd from ..plot.plot import _PlotMethods +from ..plot.utils import _get_units_from_attrs from . import ( computation, dtypes, @@ -3134,7 +3135,11 @@ def _title_for_slice(self, truncate: int = 50) -> str: for dim, coord in self.coords.items(): if coord.size == 1: one_dims.append( - "{dim} = {v}".format(dim=dim, v=format_item(coord.values)) + "{dim} = {v}{unit}".format( + dim=dim, + v=format_item(coord.values), + unit=_get_units_from_attrs(coord), + ) ) title = ", ".join(one_dims) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 77d973f613f..ebf6d7e28ed 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -1,5 +1,4 @@ -"""Internal utilties; not for external use -""" +"""Internal utilities; not for external use""" import contextlib import functools import io diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 6fbbe9d4bca..a49302f7f87 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -467,6 +467,21 @@ def _maybe_gca(**kwargs): return plt.axes(**kwargs) +def _get_units_from_attrs(da): + """Extracts and formats the unit/units from a attributes.""" + pint_array_type = DuckArrayModule("pint").type + units = " [{}]" + if isinstance(da.data, pint_array_type): + units = units.format(str(da.data.units)) + elif da.attrs.get("units"): + units = units.format(da.attrs["units"]) + elif da.attrs.get("unit"): + units = units.format(da.attrs["unit"]) + else: + units = "" + return units + + def label_from_attrs(da, extra=""): """Makes informative labels if variable metadata (attrs) follows CF conventions.""" @@ -480,20 +495,7 @@ def label_from_attrs(da, extra=""): else: name = "" - def _get_units_from_attrs(da): - if da.attrs.get("units"): - units = " [{}]".format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = " [{}]".format(da.attrs["unit"]) - else: - units = "" - return units - - pint_array_type = DuckArrayModule("pint").type - if isinstance(da.data, pint_array_type): - units = " [{}]".format(str(da.data.units)) - else: - units = _get_units_from_attrs(da) + units = _get_units_from_attrs(da) # Treat `name` differently if it's a latex sequence if name.startswith("$") and (name.count("$") % 2 == 0): diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 8be20c5f81c..f36143c52c3 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5600,19 +5600,77 @@ def test_duck_array_ops(self): @requires_matplotlib class TestPlots(PlotTestCase): - def test_units_in_line_plot_labels(self): + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ], + ) + def test_units_in_line_plot_labels(self, coord_unit, coord_attrs): arr = np.linspace(1, 10, 3) * unit_registry.Pa - # TODO make coord a Quantity once unit-aware indexes supported - x_coord = xr.DataArray( - np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"} - ) + coord_arr = np.linspace(1, 3, 3) * coord_unit + x_coord = xr.DataArray(coord_arr, dims="x", attrs=coord_attrs) da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure") da.plot.line() ax = plt.gca() assert ax.get_ylabel() == "pressure [pascal]" - assert ax.get_xlabel() == "x [meters]" + assert ax.get_xlabel() == "x [meter]" + + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ], + ) + def test_units_in_slice_line_plot_labels_sel(self, coord_unit, coord_attrs): + arr = xr.DataArray( + name="var_a", + data=np.array([[1, 2], [3, 4]]), + coords=dict( + a=("a", np.array([5, 6]) * coord_unit, coord_attrs), + b=("b", np.array([7, 8]) * coord_unit, coord_attrs), + ), + dims=("a", "b"), + ) + arr.sel(a=5).plot(marker="o") + + assert plt.gca().get_title() == "a = 5 [meter]" + + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="pint.errors.UnitStrippedWarning"), + ), + ], + ) + def test_units_in_slice_line_plot_labels_isel(self, coord_unit, coord_attrs): + arr = xr.DataArray( + name="var_a", + data=np.array([[1, 2], [3, 4]]), + coords=dict( + a=("x", np.array([5, 6]) * coord_unit, coord_attrs), + b=("y", np.array([7, 8])), + ), + dims=("x", "y"), + ) + arr.isel(x=0).plot(marker="o") + assert plt.gca().get_title() == "a = 5 [meter]" def test_units_in_2d_plot_colorbar_label(self): arr = np.ones((2, 3)) * unit_registry.Pa