Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
changes to fix issues with LAM
Browse files Browse the repository at this point in the history
  • Loading branch information
anaprietonem committed Nov 26, 2024
1 parent da1d79b commit 7f833ba
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def _plot(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
)
output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
data = data.numpy()

Expand Down Expand Up @@ -999,7 +999,7 @@ def process(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
)
output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
data = data.numpy()
return data, output_tensor
Expand Down
28 changes: 18 additions & 10 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,14 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
datashader=datashader,
)
else:
single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader)
single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader)
combined_data = np.concatenate((input_, truth, pred))
# For 'errors', only persistence and increments need identical colorbar-limits
combined_error = np.concatenate(((pred - input_), (truth - input_)))
from matplotlib.colors import Normalize

norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data))
single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader)
single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader)
single_plot(
fig,
ax[3],
Expand Down Expand Up @@ -615,15 +621,15 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
datashader=datashader,
)
else:
single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader)
single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader)
single_plot(
fig,
ax[4],
lon,
lat,
pred - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()),
title=f"{vname} increment [pred - input]",
datashader=datashader,
)
Expand All @@ -634,7 +640,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
lat,
truth - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()),
title=f"{vname} persist err",
datashader=datashader,
)
Expand Down Expand Up @@ -699,13 +705,13 @@ def single_plot(
else:
df = pd.DataFrame({"val": data, "x": lon, "y": lat})
# Adjust binning to match the resolution of the data
n_pixels = min(int(np.floor(data.shape[0] * 0.004)), 500)
lower_limit = 35
upper_limit = 500
n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit)
psc = dsshow(
df,
dsh.Point("x", "y"),
dsh.mean("val"),
vmin=data.min(),
vmax=data.max(),
cmap=cmap,
plot_width=n_pixels,
plot_height=n_pixels,
Expand All @@ -714,8 +720,10 @@ def single_plot(
ax=ax,
)

ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-np.pi / 2, np.pi / 2))
xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi)
ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2)
ax.set_xlim((xmin - 0.1, xmax + 0.1))
ax.set_ylim((ymin - 0.1, ymax + 0.1))

continents.plot_continents(ax)

Expand Down

0 comments on commit 7f833ba

Please sign in to comment.