diff --git a/CHANGELOG.md b/CHANGELOG.md index 26b71522..4305b870 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed +- Update `n_pixel` used by datashader to better adapt across resolutions #152 + - Fixed bug in power spectra plotting for the n320 resolution. - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 93e2d324..45818b69 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -24,6 +24,7 @@ from matplotlib.collections import PathCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap +from matplotlib.colors import Normalize from matplotlib.colors import TwoSlopeNorm from pyshtools.expand import SHGLQ from pyshtools.expand import SHExpandGLQ @@ -568,8 +569,12 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader) - single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader) + combined_data = np.concatenate((input_, truth, pred)) + # For 'errors', only persistence and increments need identical colorbar-limits + combined_error = np.concatenate(((pred - input_), (truth - input_))) + norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data)) + single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader) + single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader) single_plot( fig, ax[3], @@ -619,7 +624,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader) + single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( fig, ax[4], @@ -627,7 +632,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, pred - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} increment [pred - input]", datashader=datashader, ) @@ -638,7 +643,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, truth - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} persist err", datashader=datashader, ) @@ -703,13 +708,13 @@ def single_plot( else: df = pd.DataFrame({"val": data, "x": lon, "y": lat}) # Adjust binning to match the resolution of the data - n_pixels = int(np.floor(data.shape[0] / 212)) + lower_limit = 25 + upper_limit = 500 + n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit) psc = dsshow( df, dsh.Point("x", "y"), dsh.mean("val"), - vmin=data.min(), - vmax=data.max(), cmap=cmap, plot_width=n_pixels, plot_height=n_pixels, @@ -718,8 +723,10 @@ def single_plot( ax=ax, ) - ax.set_xlim((-np.pi, np.pi)) - ax.set_ylim((-np.pi / 2, np.pi / 2)) + xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi) + ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2) + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + ax.set_ylim((ymin - 0.1, ymax + 0.1)) continents.plot_continents(ax)