Skip to content

Commit

Permalink
Display coords' units for slice plots (#5847)
Browse files Browse the repository at this point in the history
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
  • Loading branch information
caenrigen and Illviljan committed Oct 30, 2021
1 parent 867646f commit 3c60814
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 23 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ New Features
`Nathan Lis <https://github.com/wxman22>`_.
- Histogram plots are set with a title displaying the scalar coords if any, similarly to the other plots (:issue:`5791`, :pull:`5792`).
By `Maxime Liquet <https://github.com/maximlt>`_.
- Slice plots display the coords units in the same way as x/y/colorbar labels (:pull:`5847`).
By `Victor Negîrneac <https://github.com/caenrigen>`_.
- Added a new :py:attr:`Dataset.chunksizes`, :py:attr:`DataArray.chunksizes`, and :py:attr:`Variable.chunksizes`
property, which will always return a mapping from dimension names to chunking pattern along that dimension,
regardless of whether the object is a Dataset, DataArray, or Variable. (:issue:`5846`, :pull:`5900`)
Expand Down
7 changes: 6 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pandas as pd

from ..plot.plot import _PlotMethods
from ..plot.utils import _get_units_from_attrs
from . import (
computation,
dtypes,
Expand Down Expand Up @@ -3134,7 +3135,11 @@ def _title_for_slice(self, truncate: int = 50) -> str:
for dim, coord in self.coords.items():
if coord.size == 1:
one_dims.append(
"{dim} = {v}".format(dim=dim, v=format_item(coord.values))
"{dim} = {v}{unit}".format(
dim=dim,
v=format_item(coord.values),
unit=_get_units_from_attrs(coord),
)
)

title = ", ".join(one_dims)
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Internal utilties; not for external use
"""
"""Internal utilities; not for external use"""
import contextlib
import functools
import io
Expand Down
30 changes: 16 additions & 14 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,21 @@ def _maybe_gca(**kwargs):
return plt.axes(**kwargs)


def _get_units_from_attrs(da):
"""Extracts and formats the unit/units from a attributes."""
pint_array_type = DuckArrayModule("pint").type
units = " [{}]"
if isinstance(da.data, pint_array_type):
units = units.format(str(da.data.units))
elif da.attrs.get("units"):
units = units.format(da.attrs["units"])
elif da.attrs.get("unit"):
units = units.format(da.attrs["unit"])
else:
units = ""
return units


def label_from_attrs(da, extra=""):
"""Makes informative labels if variable metadata (attrs) follows
CF conventions."""
Expand All @@ -480,20 +495,7 @@ def label_from_attrs(da, extra=""):
else:
name = ""

def _get_units_from_attrs(da):
if da.attrs.get("units"):
units = " [{}]".format(da.attrs["units"])
elif da.attrs.get("unit"):
units = " [{}]".format(da.attrs["unit"])
else:
units = ""
return units

pint_array_type = DuckArrayModule("pint").type
if isinstance(da.data, pint_array_type):
units = " [{}]".format(str(da.data.units))
else:
units = _get_units_from_attrs(da)
units = _get_units_from_attrs(da)

# Treat `name` differently if it's a latex sequence
if name.startswith("$") and (name.count("$") % 2 == 0):
Expand Down
70 changes: 64 additions & 6 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -5600,19 +5600,77 @@ def test_duck_array_ops(self):

@requires_matplotlib
class TestPlots(PlotTestCase):
def test_units_in_line_plot_labels(self):
@pytest.mark.parametrize(
"coord_unit, coord_attrs",
[
(1, {"units": "meter"}),
pytest.param(
unit_registry.m,
{},
marks=pytest.mark.xfail(reason="indexes don't support units"),
),
],
)
def test_units_in_line_plot_labels(self, coord_unit, coord_attrs):
arr = np.linspace(1, 10, 3) * unit_registry.Pa
# TODO make coord a Quantity once unit-aware indexes supported
x_coord = xr.DataArray(
np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
)
coord_arr = np.linspace(1, 3, 3) * coord_unit
x_coord = xr.DataArray(coord_arr, dims="x", attrs=coord_attrs)
da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")

da.plot.line()

ax = plt.gca()
assert ax.get_ylabel() == "pressure [pascal]"
assert ax.get_xlabel() == "x [meters]"
assert ax.get_xlabel() == "x [meter]"

@pytest.mark.parametrize(
"coord_unit, coord_attrs",
[
(1, {"units": "meter"}),
pytest.param(
unit_registry.m,
{},
marks=pytest.mark.xfail(reason="indexes don't support units"),
),
],
)
def test_units_in_slice_line_plot_labels_sel(self, coord_unit, coord_attrs):
arr = xr.DataArray(
name="var_a",
data=np.array([[1, 2], [3, 4]]),
coords=dict(
a=("a", np.array([5, 6]) * coord_unit, coord_attrs),
b=("b", np.array([7, 8]) * coord_unit, coord_attrs),
),
dims=("a", "b"),
)
arr.sel(a=5).plot(marker="o")

assert plt.gca().get_title() == "a = 5 [meter]"

@pytest.mark.parametrize(
"coord_unit, coord_attrs",
[
(1, {"units": "meter"}),
pytest.param(
unit_registry.m,
{},
marks=pytest.mark.xfail(reason="pint.errors.UnitStrippedWarning"),
),
],
)
def test_units_in_slice_line_plot_labels_isel(self, coord_unit, coord_attrs):
arr = xr.DataArray(
name="var_a",
data=np.array([[1, 2], [3, 4]]),
coords=dict(
a=("x", np.array([5, 6]) * coord_unit, coord_attrs),
b=("y", np.array([7, 8])),
),
dims=("x", "y"),
)
arr.isel(x=0).plot(marker="o")
assert plt.gca().get_title() == "a = 5 [meter]"

def test_units_in_2d_plot_colorbar_label(self):
arr = np.ones((2, 3)) * unit_registry.Pa
Expand Down

0 comments on commit 3c60814

Please sign in to comment.