From 7f833ba5aa9344d72b08b5a94a261fdde0b286ea Mon Sep 17 00:00:00 2001 From: anaprietonem Date: Tue, 26 Nov 2024 14:52:30 +0000 Subject: [PATCH] changes to fix issues with LAM --- .../training/diagnostics/callbacks/plot.py | 4 +-- src/anemoi/training/diagnostics/plots.py | 28 ++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 08f9d28b..a54bab70 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -938,7 +938,7 @@ def _plot( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, ) - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy() data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) data = data.numpy() @@ -999,7 +999,7 @@ def process( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, ) - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy() data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) data = data.numpy() return data, output_tensor diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 0e80f0ec..99067466 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -564,8 +564,14 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader) - single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader) + combined_data = np.concatenate((input_, truth, pred)) + # For 'errors', only persistence and increments need identical colorbar-limits + combined_error = np.concatenate(((pred - input_), (truth - input_))) + from matplotlib.colors import Normalize + + norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data)) + single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader) + single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader) single_plot( fig, ax[3], @@ -615,7 +621,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader) + single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( fig, ax[4], @@ -623,7 +629,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, pred - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} increment [pred - input]", datashader=datashader, ) @@ -634,7 +640,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, truth - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} persist err", datashader=datashader, ) @@ -699,13 +705,13 @@ def single_plot( else: df = pd.DataFrame({"val": data, "x": lon, "y": lat}) # Adjust binning to match the resolution of the data - n_pixels = min(int(np.floor(data.shape[0] * 0.004)), 500) + lower_limit = 35 + upper_limit = 500 + n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit) psc = dsshow( df, dsh.Point("x", "y"), dsh.mean("val"), - vmin=data.min(), - vmax=data.max(), cmap=cmap, plot_width=n_pixels, plot_height=n_pixels, @@ -714,8 +720,10 @@ def single_plot( ax=ax, ) - ax.set_xlim((-np.pi, np.pi)) - ax.set_ylim((-np.pi / 2, np.pi / 2)) + xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi) + ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2) + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + ax.set_ylim((ymin - 0.1, ymax + 0.1)) continents.plot_continents(ax)