Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Change number of pixels used by datashader #152

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD)
### Fixed
- Update `n_pixel` used by datashader to better adapt across resolutions #152


- Fixed bug in power spectra plotting for the n320 resolution.
- Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165)
Expand Down
27 changes: 17 additions & 10 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from matplotlib.collections import PathCollection
from matplotlib.colors import BoundaryNorm
from matplotlib.colors import ListedColormap
from matplotlib.colors import Normalize
from matplotlib.colors import TwoSlopeNorm
from pyshtools.expand import SHGLQ
from pyshtools.expand import SHExpandGLQ
Expand Down Expand Up @@ -568,8 +569,12 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
datashader=datashader,
)
else:
single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader)
single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader)
combined_data = np.concatenate((input_, truth, pred))
# For 'errors', only persistence and increments need identical colorbar-limits
combined_error = np.concatenate(((pred - input_), (truth - input_)))
norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data))
single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader)
single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader)
single_plot(
fig,
ax[3],
Expand Down Expand Up @@ -619,15 +624,15 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
datashader=datashader,
)
else:
single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader)
single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader)
single_plot(
fig,
ax[4],
lon,
lat,
pred - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()),
title=f"{vname} increment [pred - input]",
datashader=datashader,
)
Expand All @@ -638,7 +643,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
lat,
truth - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()),
title=f"{vname} persist err",
datashader=datashader,
)
Expand Down Expand Up @@ -703,13 +708,13 @@ def single_plot(
else:
df = pd.DataFrame({"val": data, "x": lon, "y": lat})
# Adjust binning to match the resolution of the data
n_pixels = int(np.floor(data.shape[0] / 212))
lower_limit = 25
upper_limit = 500
n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit)
psc = dsshow(
df,
dsh.Point("x", "y"),
dsh.mean("val"),
vmin=data.min(),
vmax=data.max(),
cmap=cmap,
plot_width=n_pixels,
plot_height=n_pixels,
Expand All @@ -718,8 +723,10 @@ def single_plot(
ax=ax,
)

ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-np.pi / 2, np.pi / 2))
xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi)
ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2)
ax.set_xlim((xmin - 0.1, xmax + 0.1))
ax.set_ylim((ymin - 0.1, ymax + 0.1))

continents.plot_continents(ax)

Expand Down
Loading