Skip to content

Commit

Permalink
Use .to_numpy() for quantified facetgrids (pydata#5886)
Browse files Browse the repository at this point in the history
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
  • Loading branch information
TomNicholas and Illviljan authored Oct 28, 2021
1 parent c210f8b commit 36f05d7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 13 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ Bug fixes
By `Jimmy Westling <https://github.com/illviljan>`_.
- Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`).
By `Maxime Liquet <https://github.com/maximlt>`_.
- Faceted plots will no longer raise a `pint.UnitStrippedWarning` when a `pint.Quantity` array is plotted,
and will correctly display the units of the data in the colorbar (if there is one) (:pull:`5886`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- With backends, check for path-like objects rather than ``pathlib.Path``
type, use ``os.fspath`` (:pull:`5879`).
By `Mike Taves <https://github.com/mwtoews>`_.
Expand Down
14 changes: 7 additions & 7 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ def __init__(
)

# Set up the lists of names for the row and column facet variables
col_names = list(data[col].values) if col else []
row_names = list(data[row].values) if row else []
col_names = list(data[col].to_numpy()) if col else []
row_names = list(data[row].to_numpy()) if row else []

if single_group:
full = [{single_group: x} for x in data[single_group].values]
full = [{single_group: x} for x in data[single_group].to_numpy()]
empty = [None for x in range(nrow * ncol - len(full))]
name_dicts = full + empty
else:
Expand Down Expand Up @@ -251,7 +251,7 @@ def map_dataarray(self, func, x, y, **kwargs):
raise ValueError("cbar_ax not supported by FacetGrid.")

cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data.values, **kwargs
func, self.data.to_numpy(), **kwargs
)

self._cmap_extend = cmap_params.get("extend")
Expand Down Expand Up @@ -347,7 +347,7 @@ def map_dataset(

if hue and meta_data["hue_style"] == "continuous":
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data[hue].values, **kwargs
func, self.data[hue].to_numpy(), **kwargs
)
kwargs["meta_data"]["cmap_params"] = cmap_params
kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs
Expand Down Expand Up @@ -423,7 +423,7 @@ def _adjust_fig_for_guide(self, guide):
def add_legend(self, **kwargs):
self.figlegend = self.fig.legend(
handles=self._mappables[-1],
labels=list(self._hue_var.values),
labels=list(self._hue_var.to_numpy()),
title=self._hue_label,
loc="center right",
**kwargs,
Expand Down Expand Up @@ -619,7 +619,7 @@ def map(self, func, *args, **kwargs):
if namedict is not None:
data = self.data.loc[namedict]
plt.sca(ax)
innerargs = [data[a].values for a in args]
innerargs = [data[a].to_numpy() for a in args]
maybe_mappable = func(*innerargs, **kwargs)
# TODO: better way to verify that an artist is mappable?
# https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
Expand Down
10 changes: 5 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def newplotfunc(
# Matplotlib does not support normalising RGB data, so do it here.
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
if robust or vmax is not None or vmin is not None:
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust)
vmin, vmax, robust = None, None, False

if subplot_kws is None:
Expand Down Expand Up @@ -1146,10 +1146,6 @@ def newplotfunc(
else:
dims = (yval.dims[0], xval.dims[0])

# better to pass the ndarrays directly to plotting functions
xval = xval.to_numpy()
yval = yval.to_numpy()

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if imshow_rgb:
Expand All @@ -1162,6 +1158,10 @@ def newplotfunc(
if dims != darray.dims:
darray = darray.transpose(*dims, transpose_coords=True)

# better to pass the ndarrays directly to plotting functions
xval = xval.to_numpy()
yval = yval.to_numpy()

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)

Expand Down
26 changes: 25 additions & 1 deletion xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -5614,11 +5614,35 @@ def test_units_in_line_plot_labels(self):
assert ax.get_ylabel() == "pressure [pascal]"
assert ax.get_xlabel() == "x [meters]"

def test_units_in_2d_plot_labels(self):
def test_units_in_2d_plot_colorbar_label(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)

assert cax.get_ylabel() == "pressure [pascal]"

def test_units_facetgrid_plot_labels(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
fgrid = da.plot.line(x="x", col="y")

assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]"

def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self):
arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure")

da.plot.imshow(x="x", y="y", col="w") # no colorbar to check labels of

def test_units_facetgrid_2d_contourf_plot_colorbar_labels(self):
arr = np.ones((2, 3, 4)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y", "z"], name="pressure")

fig, (ax1, ax2, ax3, cax) = plt.subplots(1, 4)
fgrid = da.plot.contourf(x="x", y="y", col="z")

assert fgrid.cbar.ax.get_ylabel() == "pressure [pascal]"

0 comments on commit 36f05d7

Please sign in to comment.