From 65e926706ab9da8d0e16aca244dad8abb5ca7168 Mon Sep 17 00:00:00 2001
From: Sara Hahner <44293258+sahahner@users.noreply.github.com>
Date: Tue, 3 Dec 2024 17:22:24 +0100
Subject: [PATCH] fix/remapper-without-imputer (#178)

---
 CHANGELOG.md                             | 1 +
 src/anemoi/training/train/forecaster.py | 4 +++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 927647f9..2453a374 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ Keep it human-readable, your future self will thank you!
 ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD)
 
 ### Fixed
+- Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178)
 
 ### Added
 - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 717e88c3..e3691ad1 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -229,12 +229,14 @@ def training_weights_for_imputed_variables(
         """Update the loss weights mask for imputed variables."""
         if "loss_weights_mask" in self.loss.scalar:
             loss_weights_mask = torch.ones((1, 1), device=batch.device)
+            found_loss_mask_training = False
             # iterate over all pre-processors and check if they have a loss_mask_training attribute
             for pre_processor in self.model.pre_processors.processors.values():
                 if hasattr(pre_processor, "loss_mask_training"):
                     loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training
+                    found_loss_mask_training = True
                 # if transform_loss_mask function exists for preprocessor apply it
-                if hasattr(pre_processor, "transform_loss_mask"):
+                if hasattr(pre_processor, "transform_loss_mask") and found_loss_mask_training:
                     loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask)
             # update scaler with loss_weights_mask retrieved from preprocessors
             self.loss.update_scalar(scalar=loss_weights_mask.cpu(), name="loss_weights_mask")
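
The patch gates `transform_loss_mask` behind a `found_loss_mask_training` flag so that a remapper no longer rewrites the default loss-weights mask when no imputer has produced a `loss_mask_training`. A minimal standalone sketch of that control flow is below; the `Imputer` and `Remapper` classes here are hypothetical stand-ins, not the actual anemoi-training pre-processors.

```python
# Sketch only: hypothetical stand-in pre-processors to illustrate the fixed logic.
import torch


class Imputer:
    """Stand-in pre-processor that provides a NaN loss mask."""

    def __init__(self, n_vars: int) -> None:
        self.loss_mask_training = torch.ones((1, n_vars))


class Remapper:
    """Stand-in pre-processor that only remaps an existing mask."""

    def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor:
        return mask  # identity here; a real remapper would reorder columns


def build_loss_weights_mask(pre_processors: list) -> torch.Tensor:
    loss_weights_mask = torch.ones((1, 1))
    found_loss_mask_training = False
    for pre_processor in pre_processors:
        if hasattr(pre_processor, "loss_mask_training"):
            loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training
            found_loss_mask_training = True
        # only remap the mask if an imputer actually produced one (the fix in #178)
        if hasattr(pre_processor, "transform_loss_mask") and found_loss_mask_training:
            loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask)
    return loss_weights_mask


# Remapper without imputer: the mask stays at its default of ones.
print(build_loss_weights_mask([Remapper()]))            # tensor([[1.]])
# Imputer plus remapper: the mask is built and then remapped as before.
print(build_loss_weights_mask([Imputer(3), Remapper()]))  # tensor([[1., 1., 1.]])
```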