From 6b45f8dd021fc6c83257938aa8c65d1f7f129cc6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:09:42 +0000 Subject: [PATCH 01/23] [Changelog] Update to 0.3.1 (#172) - Update changelog --------- Co-authored-by: Harrison Cook Co-authored-by: HCookie <48088699+HCookie@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4305b870..a949c1d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,55 +8,78 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) + +## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 + ### Fixed -- Update `n_pixel` used by datashader to better adapt across resolutions #152 +- Update `n_pixel` used by datashader to better adapt across resolutions #152 - Fixed bug in power spectra plotting for the n320 resolution. + - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) + ### Added + - Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) - Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159) ### Changed + ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 ### Changed + - Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111) ### Fixed - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138) + - Refactored callbacks. 
[#60](https://github.com/ecmwf/anemoi-training/pulls/60) + - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115) - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119) - Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) + - Enable longer validation rollout than training - Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91) + - Save entire config in mlflow ### Added - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) + - Include option to use datashader and optimised asyncronohous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102) - - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) + + - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) + - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137) + - Add without subsetting in ScaleTensor - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) + - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) + - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) + - Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65) + - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) + - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) + - Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147) + ### Changed - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) From 7e363fac9e58d21f11693aad5d9f8fa6120f33ce Mon Sep 17 00:00:00 2001 From: Simon Lang Date: Fri, 29 Nov 2024 09:26:37 +0000 Subject: [PATCH 02/23] full shuffle of the dataset (#153) * full shuffle of the dataset * added changelog entry --------- Co-authored-by: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> --- CHANGELOG.md | 3 +++ src/anemoi/training/data/dataset.py | 28 +++++++++------------------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a949c1d5..dc2e9f14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you! 
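The `dataset.py` hunks below replace per-worker shuffling with a single full shuffle of `valid_date_indices`, from which each worker then slices its own contiguous chunk. A minimal sketch of that idea, using hypothetical sizes rather than the project's API:

```python
import numpy as np

# Hypothetical sizes, not the project's API: 12 valid date indices split across 3 workers.
valid_date_indices = np.arange(100, 112, dtype=np.uint32)
n_workers, n_samples_per_worker = 3, 4
base_seed = 42  # all workers use the same seed, so they draw the same permutation

for worker_id in range(n_workers):
    rng = np.random.default_rng(seed=base_seed)
    # Every worker computes the identical full shuffle of *all* indices ...
    full_shuffle = rng.choice(valid_date_indices, size=len(valid_date_indices), replace=False)
    # ... and keeps only its own contiguous slice, so the union over workers is a
    # global shuffle of the whole dataset with no duplicated samples.
    chunk_index_range = np.arange(worker_id * n_samples_per_worker, (worker_id + 1) * n_samples_per_worker)
    print(worker_id, full_shuffle[chunk_index_range])
```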
## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 +### Changed +- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153) + ### Fixed - Update `n_pixel` used by datashader to better adapt across resolutions #152 diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 40065e06..062d2d4d 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -201,6 +201,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: low = shard_start + worker_id * self.n_samples_per_worker high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) LOGGER.debug( "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d", @@ -212,27 +213,17 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)] - - # each worker must have a different seed for its random number generator, - # otherwise all the workers will output exactly the same data - # should we check lightning env variable "PL_SEED_WORKERS" here? - # but we alwyas want to seed these anyways ... - base_seed = get_base_seed() - seed = ( - base_seed * (self.model_comm_group_id + 1) - worker_id - ) # note that test, validation etc. datasets get same seed - torch.manual_seed(seed) - random.seed(seed) - self.rng = np.random.default_rng(seed=seed) + torch.manual_seed(base_seed) + random.seed(base_seed) + self.rng = np.random.default_rng(seed=base_seed) sanity_rnd = self.rng.random(1) LOGGER.debug( ( "Worker %d (%s, pid %d, glob. 
rank %d, model comm group %d, " - "group_rank %d, base_seed %d) using seed %d, sanity rnd %f" + "group_rank %d, base_seed %d), sanity rnd %f" ), worker_id, self.label, @@ -241,7 +232,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: self.model_comm_group_id, self.model_comm_group_rank, base_seed, - seed, sanity_rnd, ) @@ -256,12 +246,12 @@ def __iter__(self) -> torch.Tensor: """ if self.shuffle: shuffled_chunk_indices = self.rng.choice( - self.chunk_index_range, - size=self.n_samples_per_worker, + self.valid_date_indices, + size=len(self.valid_date_indices), replace=False, - ) + )[self.chunk_index_range] else: - shuffled_chunk_indices = self.chunk_index_range + shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range] LOGGER.debug( ( From 460b604039567198c2b0993279a5975ca8522aef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?= <148321314+havardhhaugen@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:46:19 +0100 Subject: [PATCH 03/23] Feature: Option to set / change area weighting outside of graph-creation (#136) * Implementation of aw_rescaling * Pre-commit * Updated implementation based on feedback * Small fixes - training now worked for all cases * Docstrings GraphNodeAttributes, minor fixes * Update changelog * Removed obsolete config options * Docstrings * Unit testing * Updated documentation * area_weights uses AreaWeights from anemoi-graphs * pre-commit * if test to check for scaled_attribute --- CHANGELOG.md | 1 + docs/user-guide/training.rst | 22 +++ src/anemoi/training/config/model/gnn.yaml | 2 - .../config/model/graphtransformer.yaml | 2 - .../training/config/model/transformer.yaml | 2 - .../training/config/training/default.yaml | 5 + src/anemoi/training/losses/nodeweights.py | 141 ++++++++++++++++++ src/anemoi/training/train/forecaster.py | 8 +- tests/train/test_nodeweights.py | 91 +++++++++++ 9 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 src/anemoi/training/losses/nodeweights.py create mode 100644 tests/train/test_nodeweights.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dc2e9f14..7f2d4c19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ Keep it human-readable, your future self will thank you! - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) +- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136] (https://github.com/ecmwf/anemoi-training/pull/136) - Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147) diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst index 588b34d9..e90b1583 100644 --- a/docs/user-guide/training.rst +++ b/docs/user-guide/training.rst @@ -183,6 +183,28 @@ levels nearer to the surface). By default anemoi-training uses a ReLU Pressure Level scaler with a minimum weighting of 0.2 (i.e. no pressure level has a weighting less than 0.2). +The loss is also scaled by assigning a weight to each node on the output +grid. These weights are calculated during graph-creation and stored as +an attribute in the graph object. The configuration option +``config.training.node_loss_weights`` is used to specify the node +attribute used as weights in the loss function. 
By default +anemoi-training uses area weighting, where each node is weighted +according to the size of the geographical area it represents. + +It is also possible to rescale the weight of a subset of nodes after +they are loaded from the graph. For instance, for a stretched grid setup +we can rescale the weight of nodes in the limited area such that their +sum equals 0.25 of the sum of all node weights with the following config +setup + +.. code:: yaml + + node_loss_weights: + _target_: anemoi.training.losses.nodeweights.ReweightedGraphNodeAttribute + target_nodes: data + scaled_attribute: cutout + weight_frac_of_total: 0.25 + *************** Learning rate *************** diff --git a/src/anemoi/training/config/model/gnn.yaml b/src/anemoi/training/config/model/gnn.yaml index 4f4c176c..92a17fd4 100644 --- a/src/anemoi/training/config/model/gnn.yaml +++ b/src/anemoi/training/config/model/gnn.yaml @@ -45,8 +45,6 @@ attributes: - edge_dirs nodes: [] -node_loss_weight: area_weight - # Bounding configuration bounding: #These are applied in order diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 5c2e819a..9c48967b 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -50,8 +50,6 @@ attributes: - edge_dirs nodes: [] -node_loss_weight: area_weight - # Bounding configuration bounding: #These are applied in order diff --git a/src/anemoi/training/config/model/transformer.yaml b/src/anemoi/training/config/model/transformer.yaml index b26c9ecc..cd6a1e7b 100644 --- a/src/anemoi/training/config/model/transformer.yaml +++ b/src/anemoi/training/config/model/transformer.yaml @@ -49,8 +49,6 @@ attributes: - edge_dirs nodes: [] -node_loss_weight: area_weight - # Bounding configuration bounding: #These are applied in order diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index af168ecc..66d0631c 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -118,3 +118,8 @@ pressure_level_scaler: _target_: anemoi.training.data.scaling.ReluPressureLevelScaler minimum: 0.2 slope: 0.001 + +node_loss_weights: + _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute + target_nodes: ${graph.data} + node_attribute: area_weight diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py new file mode 100644 index 00000000..ed4afaf4 --- /dev/null +++ b/src/anemoi/training/losses/nodeweights.py @@ -0,0 +1,141 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +import torch +from anemoi.graphs.nodes.attributes import AreaWeights +from torch_geometric.data import HeteroData + +LOGGER = logging.getLogger(__name__) + + +class GraphNodeAttribute: + """Base class to load and optionally change the weight attribute of nodes in the graph. 
+ + Attributes + ---------- + target: str + name of target nodes, key in HeteroData graph object + node_attribute: str + name of node weight attribute, key in HeteroData graph object + + Methods + ------- + weights(self, graph_data) + Load node weight attribute. Compute area weights if they can not be found in graph + object. + """ + + def __init__(self, target_nodes: str, node_attribute: str): + """Initialize graph node attribute with target nodes and node attribute. + + Parameters + ---------- + target_nodes: str + name of nodes, key in HeteroData graph object + node_attribute: str + name of node weight attribute, key in HeteroData graph object + """ + self.target = target_nodes + self.node_attribute = node_attribute + + def area_weights(self, graph_data: HeteroData) -> torch.Tensor: + """Nodes weighted by the size of the geographical area they represent. + + Parameters + ---------- + graph_data: HeteroData + graph object + + Returns + ------- + torch.Tensor + area weights of the target nodes + """ + return AreaWeights(norm="unit-max", fill_value=0).compute(graph_data, self.target) + + def weights(self, graph_data: HeteroData) -> torch.Tensor: + """Returns weight of type self.node_attribute for nodes self.target. + + Attempts to load from graph_data and calculates area weights for the target + nodes if they do not exist. + + Parameters + ---------- + graph_data: HeteroData + graph object + + Returns + ------- + torch.Tensor + weight of target nodes + """ + if self.node_attribute in graph_data[self.target]: + attr_weight = graph_data[self.target][self.node_attribute].squeeze() + + LOGGER.info("Loading node attribute %s from the graph", self.node_attribute) + else: + attr_weight = self.area_weights(graph_data).squeeze() + + LOGGER.info( + "Node attribute %s not found in graph. Default area weighting will be used", + self.node_attribute, + ) + + return attr_weight + + +class ReweightedGraphNodeAttribute(GraphNodeAttribute): + """Method to reweight a subset of the target nodes defined by scaled_attribute. + + Subset nodes will be scaled such that their weight sum equals weight_frac_of_total of the sum + over all nodes. + """ + + def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float): + """Initialize reweighted graph node attribute. 
+ + Parameters + ---------- + target_nodes: str + name of nodes, key in HeteroData graph object + node_attribute: str + name of node weight attribute, key in HeteroData graph object + scaled_attribute: str + name of node attribute defining the subset of nodes to be scaled, key in HeteroData graph object + weight_frac_of_total: float + sum of weight of subset nodes as a fraction of sum of weight of all nodes after rescaling + + """ + super().__init__(target_nodes=target_nodes, node_attribute=node_attribute) + self.scaled_attribute = scaled_attribute + self.fraction = weight_frac_of_total + + def weights(self, graph_data: HeteroData) -> torch.Tensor: + attr_weight = super().weights(graph_data) + + if self.scaled_attribute in graph_data[self.target]: + mask = graph_data[self.target][self.scaled_attribute].squeeze().bool() + else: + error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph_object" + raise KeyError(error_msg) + + unmasked_sum = torch.sum(attr_weight[~mask]) + weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask) + attr_weight[mask] = weight_per_masked_node + + LOGGER.info( + "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes", + self.scaled_attribute, + self.fraction, + ) + + return attr_weight diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index f92050cf..717e88c3 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -84,7 +84,7 @@ def __init__( self.save_hyperparameters() self.latlons_data = graph_data[config.graph.data].x - self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + self.node_weights = self.get_node_weights(config, graph_data) if config.model.get("output_mask", None) is not None: self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) @@ -315,6 +315,12 @@ def get_variable_scaling( return torch.from_numpy(variable_loss_scaling) + @staticmethod + def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor: + node_weighting = instantiate(config.training.node_loss_weights) + + return node_weighting.weights(graph_data) + def set_model_comm_group( self, model_comm_group: ProcessGroup, diff --git a/tests/train/test_nodeweights.py b/tests/train/test_nodeweights.py new file mode 100644 index 00000000..00b1f41d --- /dev/null +++ b/tests/train/test_nodeweights.py @@ -0,0 +1,91 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import pytest +import torch +from anemoi.graphs.nodes.attributes import AreaWeights +from torch_geometric.data import HeteroData + +from anemoi.training.losses.nodeweights import GraphNodeAttribute +from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute + + +def fake_graph() -> HeteroData: + hdata = HeteroData() + lons = torch.tensor([1.56, 3.12, 4.68, 6.24]) + lats = torch.tensor([-3.12, -1.56, 1.56, 3.12]) + cutout_mask = torch.tensor([False, True, False, False]).unsqueeze(1) + area_weights = torch.ones(cutout_mask.shape) + hdata["data"]["x"] = torch.stack((lats, lons), dim=1) + hdata["data"]["cutout"] = cutout_mask + hdata["data"]["area_weight"] = area_weights + + return hdata + + +def fake_sv_area_weights() -> torch.Tensor: + return AreaWeights(norm="unit-max", fill_value=0).compute(fake_graph(), "data").squeeze() + + +def fake_reweighted_sv_area_weights(frac: float) -> torch.Tensor: + weights = fake_sv_area_weights().unsqueeze(1) + cutout_mask = fake_graph()["data"]["cutout"] + unmasked_sum = torch.sum(weights[~cutout_mask]) + weight_per_masked_node = frac / (1.0 - frac) * unmasked_sum / sum(cutout_mask) + weights[cutout_mask] = weight_per_masked_node + + return weights.squeeze() + + +@pytest.mark.parametrize( + ("target_nodes", "node_attribute", "fake_graph", "expected_weights"), + [ + ("data", "area_weight", fake_graph(), fake_graph()["data"]["area_weight"]), + ("data", "non_existent_attr", fake_graph(), fake_sv_area_weights()), + ], +) +def test_grap_node_attributes( + target_nodes: str, + node_attribute: str, + fake_graph: HeteroData, + expected_weights: torch.Tensor, +) -> None: + weights = GraphNodeAttribute(target_nodes=target_nodes, node_attribute=node_attribute).weights(fake_graph) + assert isinstance(weights, torch.Tensor) + assert torch.allclose(weights, expected_weights) + + +@pytest.mark.parametrize( + ("target_nodes", "node_attribute", "scaled_attribute", "weight_frac_of_total", "fake_graph", "expected_weights"), + [ + ("data", "area_weight", "cutout", 0.0, fake_graph(), torch.tensor([1.0, 0.0, 1.0, 1.0])), + ("data", "area_weight", "cutout", 0.5, fake_graph(), torch.tensor([1.0, 3.0, 1.0, 1.0])), + ("data", "area_weight", "cutout", 0.97, fake_graph(), torch.tensor([1.0, 97.0, 1.0, 1.0])), + ("data", "non_existent_attr", "cutout", 0.0, fake_graph(), fake_reweighted_sv_area_weights(0.0)), + ("data", "non_existent_attr", "cutout", 0.5, fake_graph(), fake_reweighted_sv_area_weights(0.5)), + ("data", "non_existent_attr", "cutout", 0.99, fake_graph(), fake_reweighted_sv_area_weights(0.99)), + ], +) +def test_graph_node_attributes( + target_nodes: str, + node_attribute: str, + scaled_attribute: str, + weight_frac_of_total: float, + fake_graph: HeteroData, + expected_weights: torch.Tensor, +) -> None: + weights = ReweightedGraphNodeAttribute( + target_nodes=target_nodes, + node_attribute=node_attribute, + scaled_attribute=scaled_attribute, + weight_frac_of_total=weight_frac_of_total, + ).weights(graph_data=fake_graph) + assert isinstance(weights, torch.Tensor) + assert torch.allclose(weights, expected_weights) From bb30beb159512e694d7765dadd10a0b730a86a8b Mon Sep 17 00:00:00 2001 From: Jesper Dramsch Date: Tue, 3 Dec 2024 13:30:07 +0100 Subject: [PATCH 04/23] Update sanity checks for training data consistency (#120) * fix: remove resolution check * feat: first implementation of Callback to check variable order in pre-training * feat: add variable order checks for pre-training and current training * tests: implement tests for variable order * docs: 
changelog * tests: make variable for number of fixed callbacks * refactor: remove nested if as per review * fix: remove resolution from config * Fix linting issues --------- Co-authored-by: Harrison Cook --- CHANGELOG.md | 39 +-- src/anemoi/training/config/data/zarr.yaml | 3 +- src/anemoi/training/data/datamodule.py | 9 +- .../diagnostics/callbacks/__init__.py | 9 +- .../training/diagnostics/callbacks/sanity.py | 169 ++++++++++ .../callbacks/test_variable_order.py | 289 ++++++++++++++++++ tests/diagnostics/test_callbacks.py | 10 +- 7 files changed, 486 insertions(+), 42 deletions(-) create mode 100644 src/anemoi/training/diagnostics/callbacks/sanity.py create mode 100644 tests/diagnostics/callbacks/test_variable_order.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f2d4c19..927647f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) +### Fixed + +### Added +- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) + +### Changed + +### Removed +- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120) ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 @@ -16,71 +25,44 @@ Keep it human-readable, your future self will thank you! - Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153) ### Fixed - - Update `n_pixel` used by datashader to better adapt across resolutions #152 - - Fixed bug in power spectra plotting for the n320 resolution. - - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) - ### Added - - Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) - Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159) -### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 -### Changed - -- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111) - ### Fixed - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138) - - Refactored callbacks. 
[#60](https://github.com/ecmwf/anemoi-training/pulls/60) - - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115) - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119) - - Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) - - Enable longer validation rollout than training - - Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91) - - Save entire config in mlflow ### Added - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) - - Include option to use datashader and optimised asyncronohous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102) - - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) - - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137) - - Add without subsetting in ScaleTensor - - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) - - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) - - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) - - Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65) - - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) - Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136] (https://github.com/ecmwf/anemoi-training/pull/136) - - Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147) @@ -89,6 +71,9 @@ Keep it human-readable, your future self will thank you! - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67) - Merged node & edge trainable feature callbacks into one. 
[#135](https://github.com/ecmwf/anemoi-training/pull/135) +- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111) + +### Removed ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 diff --git a/src/anemoi/training/config/data/zarr.yaml b/src/anemoi/training/config/data/zarr.yaml index 3b9a4537..943899da 100644 --- a/src/anemoi/training/config/data/zarr.yaml +++ b/src/anemoi/training/config/data/zarr.yaml @@ -1,5 +1,4 @@ format: zarr -resolution: o96 # Time frequency requested from dataset frequency: 6h # Time step of model (must be multiple of frequency) @@ -82,5 +81,5 @@ processors: # _convert_: all # config: ${data.remapper} -# Values set in the code + # Values set in the code num_features: null # number of features in the forecast state diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index ba9ff0c3..84cbab9d 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -60,11 +60,6 @@ def __init__(self, config: DictConfig) -> None: if not self.config.dataloader.get("pin_memory", True): LOGGER.info("Data loader memory pinning disabled.") - def _check_resolution(self, resolution: str) -> None: - assert ( - self.config.data.resolution.lower() == resolution.lower() - ), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}" - @cached_property def statistics(self) -> dict: return self.ds_train.statistics @@ -153,7 +148,7 @@ def _get_dataset( label: str = "generic", ) -> NativeGridDataset: r = max(rollout, self.rollout) - data = NativeGridDataset( + return NativeGridDataset( data_reader=data_reader, rollout=r, multistep=self.config.training.multistep_input, @@ -161,8 +156,6 @@ def _get_dataset( shuffle=shuffle, label=label, ) - self._check_resolution(data.resolution) - return data def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: assert stage in {"training", "validation", "test"} diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f3597843..65a19ce1 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -23,6 +23,7 @@ from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback +from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback @@ -196,7 +197,13 @@ def get_callbacks(config: DictConfig) -> list[Callback]: trainer_callbacks.extend(_get_config_enabled_callbacks(config)) # Parent UUID callback - trainer_callbacks.append(ParentUUIDCallback(config)) + # Check variable order callback + trainer_callbacks.extend( + ( + ParentUUIDCallback(config), + CheckVariableOrder(), + ), + ) return trainer_callbacks diff --git a/src/anemoi/training/diagnostics/callbacks/sanity.py b/src/anemoi/training/diagnostics/callbacks/sanity.py new file mode 100644 index 00000000..751bf273 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/sanity.py @@ -0,0 +1,169 @@ +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import pytorch_lightning as pl + +LOGGER = logging.getLogger(__name__) + + +class CheckVariableOrder(pl.callbacks.Callback): + """Check the order of the variables in a pre-trained / fine-tuning model.""" + + def __init__(self) -> None: + super().__init__() + self._model_name_to_index = None + + def on_load_checkpoint(self, trainer: pl.Trainer, _: pl.LightningModule, checkpoint: dict) -> None: + """Cache the model mapping from the checkpoint. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + checkpoint : dict + Pytorch Lightning checkpoint + """ + self._model_name_to_index = checkpoint["hyper_parameters"]["data_indices"].name_to_index + data_name_to_index = trainer.datamodule.data_indices.name_to_index + + self._compare_variables(data_name_to_index) + + def on_sanity_check_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Cache the model mapping from the datamodule if not loaded from checkpoint. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + if self._model_name_to_index is None: + self._model_name_to_index = trainer.datamodule.data_indices.name_to_index + + def on_train_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the training data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_train.name_to_index + + self._compare_variables(data_name_to_index) + + def on_validation_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the validation data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_valid.name_to_index + + self._compare_variables(data_name_to_index) + + def on_test_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the test data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_test.name_to_index + + self._compare_variables(data_name_to_index) + + def _compare_variables(self, data_name_to_index: dict[str, int]) -> None: + """Compare the order of the variables in the model from checkpoint and the data. + + Parameters + ---------- + data_name_to_index : dict[str, int] + The dictionary mapping variable names to their indices in the data. + + Raises + ------ + ValueError + If the variable order in the model and data is verifiably different. + """ + if self._model_name_to_index is None: + LOGGER.info("No variable order to compare. 
Skipping variable order check.") + return + + if self._model_name_to_index == data_name_to_index: + LOGGER.info("The order of the variables in the model matches the order in the data.") + LOGGER.debug("%s, %s", self._model_name_to_index, data_name_to_index) + return + + keys1 = set(self._model_name_to_index.keys()) + keys2 = set(data_name_to_index.keys()) + + error_msg = "" + + # Find keys unique to each dictionary + only_in_model = {key: self._model_name_to_index[key] for key in (keys1 - keys2)} + only_in_data = {key: data_name_to_index[key] for key in (keys2 - keys1)} + + # Find common keys + common_keys = keys1 & keys2 + + # Compare values for common keys + different_values = { + k: (self._model_name_to_index[k], data_name_to_index[k]) + for k in common_keys + if self._model_name_to_index[k] != data_name_to_index[k] + } + + LOGGER.warning( + "The variables in the model do not match the variables in the data. " + "If you're fine-tuning or pre-training, you may have to adjust the " + "variable order and naming in your config.", + ) + if only_in_model: + LOGGER.warning("Variables only in model: %s", only_in_model) + if only_in_data: + LOGGER.warning("Variables only in data: %s", only_in_data) + if set(only_in_model.values()) == set(only_in_data.values()): + # This checks if the order is the same, but the naming is different. This is not be treated as an error. + LOGGER.warning( + "The variable naming is different, but the order appears to be the same. Continuing with training.", + ) + else: + # If the renamed variables are not in the same index locations, raise an error. + error_msg += ( + "The variable order in the model and data is different.\n" + "Please adjust the variable order in your config, you may need to " + "use the 'reorder' and 'rename' key in the dataloader config.\n" + "Refer to the Anemoi Datasets documentation for more information.\n" + ) + if different_values: + # If the variables are named the same but in different order, raise an error. + error_msg += ( + f"Detected a different sort order of the same variables: {different_values}.\n" + "Please adjust the variable order in your config, you may need to use the " + f"'reorder' key in the dataloader config. With:\n `reorder: {self._model_name_to_index}`\n" + ) + + if error_msg: + LOGGER.error(error_msg) + raise ValueError(error_msg) diff --git a/tests/diagnostics/callbacks/test_variable_order.py b/tests/diagnostics/callbacks/test_variable_order.py new file mode 100644 index 00000000..6f91cdc2 --- /dev/null +++ b/tests/diagnostics/callbacks/test_variable_order.py @@ -0,0 +1,289 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from typing import Any + +import pytest +from anemoi.models.data_indices.collection import IndexCollection + +from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder +from anemoi.training.train.train import AnemoiTrainer + + +@pytest.fixture +def name_to_index() -> dict: + return {"a": 0, "b": 1, "c": 2} + + +@pytest.fixture +def name_to_index_permute() -> dict: + return {"a": 0, "b": 2, "c": 1} + + +@pytest.fixture +def name_to_index_rename() -> dict: + return {"a": 0, "b": 1, "d": 2} + + +@pytest.fixture +def name_to_index_partial_rename_permute() -> dict: + return {"a": 2, "b": 1, "d": 0} + + +@pytest.fixture +def name_to_index_rename_permute() -> dict: + return {"x": 2, "b": 1, "d": 0} + + +@pytest.fixture +def fake_trainer(mocker: Any, name_to_index: dict) -> AnemoiTrainer: + trainer = mocker.Mock(spec=AnemoiTrainer) + trainer.datamodule.data_indices.name_to_index = name_to_index + return trainer + + +@pytest.fixture +def checkpoint(mocker: Any, name_to_index: dict) -> dict[str, dict[str, IndexCollection]]: + data_index = mocker.Mock(spec=IndexCollection) + data_index.name_to_index = name_to_index + return {"hyper_parameters": {"data_indices": data_index}} + + +@pytest.fixture +def callback() -> CheckVariableOrder: + callback = CheckVariableOrder() + assert callback is not None + assert hasattr(callback, "on_load_checkpoint") + assert hasattr(callback, "on_sanity_check_start") + assert hasattr(callback, "on_train_epoch_start") + assert hasattr(callback, "on_validation_epoch_start") + assert hasattr(callback, "on_test_epoch_start") + + assert callback._model_name_to_index is None + + return callback + + +def test_on_load_checkpoint( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + checkpoint: dict, + name_to_index: dict, +) -> None: + assert callback._model_name_to_index is None + callback.on_load_checkpoint(fake_trainer, None, checkpoint) + assert callback._model_name_to_index == name_to_index + + assert callback._compare_variables(name_to_index) is None + + +def test_on_sanity(fake_trainer: AnemoiTrainer, callback: CheckVariableOrder, name_to_index: dict) -> None: + assert callback._model_name_to_index is None + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + assert callback._compare_variables(name_to_index) is None + + +def test_on_epoch(fake_trainer: AnemoiTrainer, callback: CheckVariableOrder, name_to_index: dict) -> None: + """Test all epoch functions with "working" indices.""" + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index + fake_trainer.datamodule.ds_test.name_to_index = name_to_index + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + assert callback._compare_variables(name_to_index) is None + + +def test_on_epoch_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + 
name_to_index_permute: dict, +) -> None: + """Test all epoch functions with permuted indices. + + Expecting errors in all cases. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_permute + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_train_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_validation_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_test_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback._compare_variables(name_to_index_permute) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + + +def test_on_epoch_rename( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_rename: dict, +) -> None: + """Test all epoch functions with renamed indices. + + Expecting passes in all cases. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_rename + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_rename + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_rename + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + callback._compare_variables(name_to_index_rename) + + +def test_on_epoch_rename_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_rename_permute: dict, +) -> None: + """Test all epoch functions with renamed and permuted indices. + + Expects all passes (but warnings). 
+ """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_rename_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_rename_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_rename_permute + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + callback._compare_variables(name_to_index_rename_permute) + + +def test_on_epoch_partial_rename_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_partial_rename_permute: dict, +) -> None: + """Test all epoch functions with partially renamed and permuted indices. + + Expects all errors. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_partial_rename_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_partial_rename_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_partial_rename_permute + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_train_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_validation_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_test_epoch_start(fake_trainer, None) + + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback._compare_variables(name_to_index_partial_rename_permute) + + +def test_on_epoch_wrong_validation( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_permute: dict, + name_to_index_rename: dict, +) -> None: + """Test all epoch functions with "working" indices, but different validation indices.""" + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_test.name_to_index = 
name_to_index_rename + callback.on_train_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_validation_epoch_start(fake_trainer, None) + assert " {'c': (2, 1), 'b': (1, 2)}" in str( + exc_info.value, + ) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + callback.on_test_epoch_start(fake_trainer, None) + + assert callback._compare_variables(name_to_index) is None diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py index a61b19f1..58ea6440 100644 --- a/tests/diagnostics/test_callbacks.py +++ b/tests/diagnostics/test_callbacks.py @@ -14,6 +14,8 @@ from anemoi.training.diagnostics.callbacks import get_callbacks +NUM_FIXED_CALLBACKS = 2 # ParentUUIDCallback, CheckVariableOrder + default_config = """ diagnostics: callbacks: [] @@ -39,7 +41,7 @@ def test_no_extra_callbacks_set(): # No extra callbacks set config = omegaconf.OmegaConf.create(yaml.safe_load(default_config)) callbacks = get_callbacks(config) - assert len(callbacks) == 1 # ParentUUIDCallback + assert len(callbacks) == NUM_FIXED_CALLBACKS # ParentUUIDCallback, CheckVariableOrder, etc def test_add_config_enabled_callback(): @@ -47,7 +49,7 @@ def test_add_config_enabled_callback(): config = omegaconf.OmegaConf.create(default_config) config.diagnostics.callbacks.append({"log": {"mlflow": {"enabled": True}}}) callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1 def test_add_callback(): @@ -56,7 +58,7 @@ def test_add_callback(): {"_target_": "anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback"}, ) callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1 def test_add_plotting_callback(monkeypatch): @@ -73,4 +75,4 @@ def __init__(self, config: omegaconf.DictConfig): config.diagnostics.plot.enabled = True config.diagnostics.plot.callbacks = [{"_target_": "anemoi.training.diagnostics.callbacks.plot.PlotLoss"}] callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1 From 9d5aa2fab77bc80d2b52bff533efc30edfe7e192 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:03:58 +0000 Subject: [PATCH 05/23] [pre-commit.ci] pre-commit autoupdate (#177) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.2 → v0.8.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.2...v0.8.1) - [github.com/jshwi/docsig: v0.64.0 → v0.65.0](https://github.com/jshwi/docsig/compare/v0.64.0...v0.65.0) * fix: pre-commit docsig * fix: qa --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gert Mertes --- .pre-commit-config.yaml | 5 ++--- src/anemoi/training/diagnostics/callbacks/plot.py | 2 +- src/anemoi/training/diagnostics/mlflow/logger.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbc225df..8f820a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.8.1 hooks: - id: ruff args: @@ -64,7 +64,7 @@ repos: hooks: - id: 
pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig - rev: v0.64.0 + rev: v0.65.0 hooks: - id: docsig args: @@ -74,6 +74,5 @@ repos: - --check-protected # Check protected methods - --check-class # Check class docstrings - --disable=E113 # Disable empty docstrings - - --summary # Print a summary ci: autoupdate_schedule: monthly diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index a54bab70..197a401e 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -816,7 +816,7 @@ def _plot( # reorder parameter_names by position self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] if not isinstance(pl_module.loss, BaseWeightedLoss): - logging.warning( + LOGGER.warning( "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.", RuntimeWarning, ) diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 78a80be9..7b8e5f53 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -207,7 +207,7 @@ def _store_buffered_logs(self) -> None: # split lines and keep \n at the end of each line lines = [e + b"\n" for e in data.split(b"\n") if e] - ansi_csi_re = re.compile(b"\001?\033\\[((?:\\d|;)*)([a-dA-D])\002?") + ansi_csi_re = re.compile(b"\001?\033\\[((?:\\d|;)*)([a-dA-D])\002?") # noqa: RUF039 def _handle_csi(line: bytes) -> bytes: # removes the cursor up and down symbols from the line @@ -252,7 +252,7 @@ def __init__( run_name: str | None = None, tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"), save_dir: str | None = "./mlruns", - log_model: Literal[True, False, "all"] = False, + log_model: Literal["all"] | bool = False, prefix: str = "", resumed: bool | None = False, forked: bool | None = False, From 65e926706ab9da8d0e16aca244dad8abb5ca7168 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:22:24 +0100 Subject: [PATCH 06/23] fix/remapper-without-imputer (#178) --- CHANGELOG.md | 1 + src/anemoi/training/train/forecaster.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 927647f9..2453a374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Keep it human-readable, your future self will thank you! 
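The `forecaster.py` hunk below guards the remapper's `transform_loss_mask` behind a flag, so the mask is only transformed when some pre-processor actually produced one. A rough sketch of that guard, with illustrative stand-in classes rather than the anemoi-models processors:

```python
import torch

# Simplified stand-ins; names are illustrative, not the anemoi-models API.
class FakeImputer:
    loss_mask_training = torch.tensor([[1.0, 0.0, 1.0]])  # zero where values were imputed

class FakeRemapper:
    @staticmethod
    def transform_loss_mask(mask: torch.Tensor) -> torch.Tensor:
        return mask  # a real remapper would re-index columns here

def build_loss_weights_mask(processors: list) -> torch.Tensor:
    mask = torch.ones((1, 1))
    found = False
    for proc in processors:
        if hasattr(proc, "loss_mask_training"):
            mask = mask * proc.loss_mask_training
            found = True
        # The #178 fix: only remap the mask if some processor actually produced one.
        if hasattr(proc, "transform_loss_mask") and found:
            mask = proc.transform_loss_mask(mask)
    return mask

print(build_loss_weights_mask([FakeImputer(), FakeRemapper()]))  # imputer mask, remapped
print(build_loss_weights_mask([FakeRemapper()]))                 # stays all-ones: nothing to remap
```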
## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) ### Fixed +- Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) ### Added - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 717e88c3..e3691ad1 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -229,12 +229,14 @@ def training_weights_for_imputed_variables( """Update the loss weights mask for imputed variables.""" if "loss_weights_mask" in self.loss.scalar: loss_weights_mask = torch.ones((1, 1), device=batch.device) + found_loss_mask_training = False # iterate over all pre-processors and check if they have a loss_mask_training attribute for pre_processor in self.model.pre_processors.processors.values(): if hasattr(pre_processor, "loss_mask_training"): loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training + found_loss_mask_training = True # if transform_loss_mask function exists for preprocessor apply it - if hasattr(pre_processor, "transform_loss_mask"): + if hasattr(pre_processor, "transform_loss_mask") and found_loss_mask_training: loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask) # update scaler with loss_weights_mask retrieved from preprocessors self.loss.update_scalar(scalar=loss_weights_mask.cpu(), name="loss_weights_mask") From 5c4ac3fae6ec7bc5d5b70836a88aba979a59730b Mon Sep 17 00:00:00 2001 From: Cathal O'Brien Date: Wed, 4 Dec 2024 15:08:23 +0100 Subject: [PATCH 07/23] [PROFILER] dont crash if an env var isnt found (#180) * check env vars safely * changelog --- CHANGELOG.md | 1 + src/anemoi/training/train/profiler.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2453a374..7d20fda0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Keep it human-readable, your future self will thank you! 
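The `profiler.py` hunk below replaces `os.environ[...]` lookups with `os.getenv(...)` plus a default, so the profiler no longer crashes when run outside a Slurm allocation. In isolation the pattern looks like this (the fallback strings are an illustrative choice):

```python
import os

# os.environ["SLURM_JOB_ID"] raises KeyError if the variable is unset;
# os.getenv returns the given default instead and the profiler keeps running.
job_id = os.getenv("SLURM_JOB_ID", "")
nodelist = os.getenv("SLURM_JOB_NODELIST", "")
print(f"SLURM JOB ID {job_id or 'not set'} on node(s) {nodelist or 'unknown'}")
```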
## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD) ### Fixed - Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) +- Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) ### Added - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) diff --git a/src/anemoi/training/train/profiler.py b/src/anemoi/training/train/profiler.py index 40dd50e7..817a98f1 100644 --- a/src/anemoi/training/train/profiler.py +++ b/src/anemoi/training/train/profiler.py @@ -57,8 +57,8 @@ def print_title() -> None: @staticmethod def print_metadata() -> None: - console.print(f"[bold blue] SLURM NODE(s) {os.environ['HOST']} [/bold blue]!") - console.print(f"[bold blue] SLURM JOB ID {os.environ['SLURM_JOB_ID']} [/bold blue]!") + console.print(f"[bold blue] SLURM NODE(s) {os.getenv('SLURM_JOB_NODELIST', '')} [/bold blue]!") + console.print(f"[bold blue] SLURM JOB ID {os.getenv('SLURM_JOB_ID', '')} [/bold blue]!") console.print(f"[bold blue] TIMESTAMP {datetime.now(timezone.utc).strftime('%d/%m/%Y %H:%M:%S')} [/bold blue]!") @rank_zero_only From 891405eb72fe90028af436d69d1609ed2faa0771 Mon Sep 17 00:00:00 2001 From: Icedoom <9220778+icedoom888@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:08:48 +0100 Subject: [PATCH 08/23] Feature/transfer-learning (#166) * Introduced resume flag and checkpoint loading for transfer learning, removed metadata saving in checkpoints due to corruption error on big models, fixed logging to work in the transfer leanring setting * Added len of dataset computed dynamically * debugging validation * Small changes * Removed prints * Not working * small changes * Imputer changes * Added sanification of checkpoint, effective batch size, git pre commit * gpc * gpc * New implementation: do not store modified checkpoint, load it directly after changing it * Added logging * Transfer learning working: implemented checkpoint cleaning with large models * Reverted some changes concerning imputer issues * Reverted some changes concerning imputer issues * Cleaned code for final review * Changed changelog and assigned TODO correctly * Changed changelog and assigned TODO correctly * Addressed review: copy checkpoint before removing metadata file * gpc passed * Removed logger in debugging mode * removed dataset lenght due to checkpointing issues * Reintroduced correct config on graphtansformer * gpc passed * Removed patched for issue #57, code expects patched checkpoint already * Removed new path name for patched checkpoint (ignoring fully issue #57) + removed fix for missing config * Adapted changelog * Switched logging to info from debug --- CHANGELOG.md | 4 +++ .../training/config/training/default.yaml | 1 + src/anemoi/training/data/datamodule.py | 11 +++++++ src/anemoi/training/data/dataset.py | 5 ++- src/anemoi/training/train/forecaster.py | 4 +++ src/anemoi/training/train/train.py | 32 ++++++++++++++++--- src/anemoi/training/utils/checkpoint.py | 28 +++++++++++++++- 7 files changed, 78 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d20fda0..e9c0997a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ Keep it human-readable, your future self will thank you! 
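The changelog entry below records the effective batch size as `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`, the same expression later computed in `datamodule.py`. A worked example with hypothetical hardware numbers:

```python
# Hypothetical hardware layout, mirroring the formula added in this patch.
batch_size_per_gpu = 2   # config.dataloader.batch_size["training"]
num_gpus_per_node = 4    # config.hardware.num_gpus_per_node
num_nodes = 2            # config.hardware.num_nodes
num_gpus_per_model = 4   # config.hardware.num_gpus_per_model (one model sharded over 4 GPUs)

# 8 GPUs in total form 2 model-parallel groups; each group consumes one batch of 2,
# so 4 independent samples contribute to every optimizer step.
effective_bs = (batch_size_per_gpu * num_gpus_per_node * num_nodes) // num_gpus_per_model
print(effective_bs)  # 4
```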
- Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) ### Added +- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. +- TRANSFER LEARNING: enabled new functionality. You can now load checkpoints from different models and different training runs. +- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. + Used for experiment reproducibility across different computing configurations. - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) ### Changed diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 66d0631c..c397ff75 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -2,6 +2,7 @@ run_id: null fork_run_id: null load_weights_only: null # only load model weights, do not restore optimiser states etc. +transfer_learning: null # activate to perform transfer learning # run in deterministic mode ; slows down deterministic: False diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 84cbab9d..edcb7728 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -147,7 +147,17 @@ def _get_dataset( rollout: int = 1, label: str = "generic", ) -> NativeGridDataset: + r = max(rollout, self.rollout) + + # Compute effective batch size + effective_bs = ( + self.config.dataloader.batch_size["training"] + * self.config.hardware.num_gpus_per_node + * self.config.hardware.num_nodes + // self.config.hardware.num_gpus_per_model + ) + return NativeGridDataset( data_reader=data_reader, rollout=r, @@ -155,6 +165,7 @@ def _get_dataset( timeincrement=self.timeincrement, shuffle=shuffle, label=label, + effective_bs=effective_bs, ) def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 062d2d4d..69aa154c 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -38,6 +38,7 @@ def __init__( timeincrement: int = 1, shuffle: bool = True, label: str = "generic", + effective_bs: int = 1, ) -> None: """Initialize (part of) the dataset state. 
@@ -55,9 +56,11 @@ def __init__( Shuffle batches, by default True label : str, optional label for the dataset, by default "generic" - + effective_bs : int, default 1 + effective batch size useful to compute the lenght of the dataset """ self.label = label + self.effective_bs = effective_bs self.data = data_reader diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index e3691ad1..a31039b3 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -603,8 +603,10 @@ def on_train_epoch_end(self) -> None: self.rollout = min(self.rollout, self.rollout_max) def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: + with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) + self.log( f"val_{getattr(self.loss, 'name', self.loss.__class__.__name__.lower())}", val_loss, @@ -615,6 +617,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) + for mname, mvalue in metrics.items(): self.log( "val_" + mname, @@ -626,6 +629,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch_size=batch.shape[0], sync_dist=True, ) + return val_loss, y_preds def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]: diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index a18ed4dc..6812dd99 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -34,6 +34,7 @@ from anemoi.training.diagnostics.logger import get_wandb_logger from anemoi.training.distributed.strategy import DDPGroupStrategy from anemoi.training.train.forecaster import GraphForecaster +from anemoi.training.utils.checkpoint import transfer_learning_loading from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.seeding import get_base_seed @@ -62,9 +63,8 @@ def __init__(self, config: DictConfig) -> None: OmegaConf.resolve(config) self.config = config - # Default to not warm-starting from a checkpoint self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id) - self.load_weights_only = config.training.load_weights_only + self.load_weights_only = self.config.training.load_weights_only self.parent_uuid = None self.config.training.run_id = self.run_id @@ -83,6 +83,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" datamodule = AnemoiDatasetsDataModule(self.config) self.config.data.num_features = len(datamodule.ds_train.data.variables) + LOGGER.info("Number of data variables: %s", str(len(datamodule.ds_train.data.variables))) + LOGGER.debug("Variables: %s", str(datamodule.ds_train.data.variables)) return datamodule @cached_property @@ -145,10 +147,21 @@ def model(self) -> GraphForecaster: "metadata": self.metadata, "statistics": self.datamodule.statistics, } + + model = GraphForecaster(**kwargs) + if self.load_weights_only: + # Sanify the checkpoint for transfer learning + if self.config.training.transfer_learning: + LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint) + return transfer_learning_loading(model, self.last_checkpoint) + LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs) - return GraphForecaster(**kwargs) + + return model.load_from_checkpoint(self.last_checkpoint, 
**kwargs, strict=False) + + LOGGER.info("Model initialised from scratch.") + return model @rank_zero_only def _get_mlflow_run_id(self) -> str: @@ -200,6 +213,7 @@ def last_checkpoint(self) -> str | None: fork_id or self.lineage_run, self.config.hardware.files.warm_start or "last.ckpt", ) + # Check if the last checkpoint exists if Path(checkpoint).exists(): LOGGER.info("Resuming training from last checkpoint: %s", checkpoint) @@ -296,11 +310,15 @@ def _log_information(self) -> None: * self.config.hardware.num_gpus_per_node / self.config.hardware.num_gpus_per_model ) + LOGGER.debug( "Total GPU count / model group size: %d - NB: the learning rate will be scaled by this factor!", total_number_of_model_instances, ) - LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate) + LOGGER.debug( + "Effective learning rate: %.3e", + int(total_number_of_model_instances) * self.config.training.lr.rate, + ) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): @@ -352,6 +370,8 @@ def strategy(self) -> DDPGroupStrategy: def train(self) -> None: """Training entry point.""" + LOGGER.debug("Setting up trainer..") + trainer = pl.Trainer( accelerator=self.accelerator, callbacks=self.callbacks, @@ -378,6 +398,8 @@ def train(self) -> None: enable_progress_bar=self.config.diagnostics.enable_progress_bar, ) + LOGGER.debug("Starting training..") + trainer.fit( self.model, datamodule=self.datamodule, diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index ddb5a1c8..a78ef524 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -7,16 +7,19 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - from __future__ import annotations +import logging from pathlib import Path import torch +import torch.nn as nn from anemoi.utils.checkpoints import save_metadata from anemoi.training.train.forecaster import GraphForecaster +LOGGER = logging.getLogger(__name__) + def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]: """Load the lightning checkpoint and extract the pytorch model and its metadata. 
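Editor's note: the effective batch size described in the changelog entry for this patch, and the effective learning rate logged in `_log_information` above, both follow directly from the hardware layout. A minimal arithmetic sketch, using hypothetical hardware numbers rather than values from any real configuration:

```python
# Minimal sketch of the effective-batch-size and effective-learning-rate arithmetic.
# All hardware numbers below are hypothetical.
batch_size_training = 2      # config.dataloader.batch_size["training"]
num_gpus_per_node = 4        # config.hardware.num_gpus_per_node
num_nodes = 16               # config.hardware.num_nodes
num_gpus_per_model = 2       # config.hardware.num_gpus_per_model
base_lr = 0.625e-4           # config.training.lr.rate

# Samples seen per optimiser step across all data-parallel model instances.
effective_bs = (batch_size_training * num_gpus_per_node * num_nodes) // num_gpus_per_model

# Each model instance spans num_gpus_per_model GPUs; the number of instances scales the LR.
model_instances = (num_nodes * num_gpus_per_node) / num_gpus_per_model
effective_lr = model_instances * base_lr

print(effective_bs, effective_lr)  # 64 0.002
```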
@@ -65,3 +68,26 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path: torch.save(model, inference_filepath) save_metadata(inference_filepath, metadata) return inference_filepath + + +def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module: + + # Load the checkpoint + checkpoint = torch.load(ckpt_path, map_location=model.device) + + # Filter out layers with size mismatch + state_dict = checkpoint["state_dict"] + + model_state_dict = model.state_dict() + + for key in state_dict.copy(): + if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: + LOGGER.info("Skipping loading parameter: %s", key) + LOGGER.info("Checkpoint shape: %s", str(state_dict[key].shape)) + LOGGER.info("Model shape: %s", str(model_state_dict[key].shape)) + + del state_dict[key] # Remove the mismatched key + + # Load the filtered st-ate_dict into the model + model.load_state_dict(state_dict, strict=False) + return model From e002b8c76b9feab8d6c797c00f49eade0ba6521c Mon Sep 17 00:00:00 2001 From: fprill <4728053+fprill@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:50:06 +0100 Subject: [PATCH 09/23] fix: allow None as graph save_path setting. (#181) --- src/anemoi/training/train/train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 6812dd99..ed6c9e39 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -120,14 +120,18 @@ def graph_data(self) -> HeteroData: Creates the graph in all workers. """ - graph_filename = Path( - self.config.hardware.paths.graph, - self.config.hardware.files.graph, - ) + if self.config.hardware.files.graph is not None: + graph_filename = Path( + self.config.hardware.paths.graph, + self.config.hardware.files.graph, + ) + + if graph_filename.exists() and not self.config.graph.overwrite: + LOGGER.info("Loading graph data from %s", graph_filename) + return torch.load(graph_filename) - if graph_filename.exists() and not self.config.graph.overwrite: - LOGGER.info("Loading graph data from %s", graph_filename) - return torch.load(graph_filename) + else: + graph_filename = None from anemoi.graphs.create import GraphCreator From 2179a59630656a38fff443812386eee534590560 Mon Sep 17 00:00:00 2001 From: Ewan <131677160+pinnstorm@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:10:46 +0000 Subject: [PATCH 10/23] fix: remove saving of metadata for training ckpt (#190) * remove saving of unused metadata for training ckpt, fixing #57 --- CHANGELOG.md | 1 + src/anemoi/training/diagnostics/callbacks/checkpoint.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9c0997a..f5e3fbc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Keep it human-readable, your future self will thank you! ### Fixed - Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) +- Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) ### Added - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. 
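Editor's note: a self-contained illustration of the shape-filtering idea behind `transfer_learning_loading`, shown on plain `torch.nn` modules rather than the project's classes. Checkpoint tensors whose shapes no longer match the target model are dropped before calling `load_state_dict(strict=False)`, so a resized output head is simply left at its fresh initialisation.

```python
from torch import nn

# Stand-in "pretrained" and "new" models; only the output layer changes shape.
pretrained = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
new_model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 6))

state_dict = pretrained.state_dict()
target_shapes = {k: v.shape for k, v in new_model.state_dict().items()}
for key in list(state_dict):
    if key in target_shapes and state_dict[key].shape != target_shapes[key]:
        del state_dict[key]  # mismatched tensors (the resized head) are skipped

missing, unexpected = new_model.load_state_dict(state_dict, strict=False)
print(missing)  # ['1.weight', '1.bias'] remain at their fresh initialisation
```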
diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py index cb95f5a4..2aba5246 100644 --- a/src/anemoi/training/diagnostics/callbacks/checkpoint.py +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -173,9 +173,6 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s if trainer.is_global_zero: from weakref import proxy - # save metadata for the training checkpoint in the same format as inference - save_metadata(lightning_checkpoint_filepath, metadata) - # notify loggers for logger in trainer.loggers: logger.after_save_checkpoint(proxy(self)) From da26cb7c859727ce2e0a81ae53dc6543265ff5c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oph=C3=A9lia=20Miralles?= Date: Tue, 10 Dec 2024 15:10:51 +0100 Subject: [PATCH 11/23] Fixes to callback plots (#182) * Lower bound delta lat in power spectrum plot and align input color map for precip plots --- CHANGELOG.md | 1 + .../config/diagnostics/plot/detailed.yaml | 1 + .../training/diagnostics/callbacks/plot.py | 3 + src/anemoi/training/diagnostics/plots.py | 67 +++++++++++++++++-- 4 files changed, 67 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5e3fbc3..3c4e0466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you! - Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) - Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) +- Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic). ### Added - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index d1ac8b0f..f7d4e78e 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -44,6 +44,7 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum # every_n_batches: 100 # Override for batch frequency + # min_delta: 0.01 # Minimum distance between two consecutive points sample_idx: ${diagnostics.plot.sample_idx} parameters: - z_500 diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 197a401e..3adcb55a 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -1018,6 +1018,7 @@ def __init__( config: OmegaConf, sample_idx: int, parameters: list[str], + min_delta: float | None = None, every_n_batches: int | None = None, ) -> None: """Initialise the PlotSpectrum callback. 
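Editor's note: for the plotting fix referenced in the changelog entry above, a small numeric sketch of the `min_delta` safeguard applied in the plots.py changes that follow. Clamping the smallest latitude spacing keeps the regular interpolation grid used for the power spectrum from exploding; the latitude values below are made up for illustration.

```python
import numpy as np

pc_lat = np.array([0.0, 0.00005, 30.0, 60.0])  # hypothetical projected latitudes
min_delta = 0.01                               # as in the config comment above

delta_lat = np.abs(np.diff(pc_lat))
min_delta_lat = max(np.min(delta_lat[delta_lat != 0]), min_delta)  # clamp tiny spacings

n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat))
n_pix_lon = (n_pix_lat - 1) * 2 + 1            # 2*lmax + 1
print(n_pix_lat, n_pix_lon)                    # 6000 11999 with these numbers
```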
@@ -1036,6 +1037,7 @@ def __init__( super().__init__(config, every_n_batches=every_n_batches) self.sample_idx = sample_idx self.parameters = parameters + self.min_delta = min_delta @rank_zero_only def _plot( @@ -1070,6 +1072,7 @@ def _plot( data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + min_delta=self.min_delta, ) self._output_figure( diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 45818b69..cd99ce37 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -138,6 +138,7 @@ def plot_power_spectrum( x: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray, + min_delta: float | None = None, ) -> Figure: """Plots power spectrum. @@ -156,6 +157,8 @@ def plot_power_spectrum( Expected data of shape (lat*lon, nvar*level) y_pred : np.ndarray Predicted data of shape (lat*lon, nvar*level) + min_delta: float, optional + Minimum distance between lat/lon points, if None defaulted to 1km Returns ------- @@ -163,6 +166,7 @@ def plot_power_spectrum( The figure object handle. """ + min_delta = min_delta or 0.0003 n_plots_x, n_plots_y = len(parameters), 1 figsize = (n_plots_y * 4, n_plots_x * 3) @@ -177,9 +181,17 @@ def plot_power_spectrum( # Calculate delta_lat on the projected grid delta_lat = abs(np.diff(pc_lat)) non_zero_delta_lat = delta_lat[delta_lat != 0] + min_delta_lat = np.min(abs(non_zero_delta_lat)) + + if min_delta_lat < min_delta: + LOGGER.warning( + "Min. distance between lat/lon points is < specified minimum distance. Defaulting to min_delta=%s.", + min_delta, + ) + min_delta_lat = min_delta # Define a regular grid for interpolation - n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) + n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat)) n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1 regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon) regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat) @@ -313,14 +325,14 @@ def plot_histogram( # enforce the same binning for both histograms bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt)) bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt)) - hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max]) - hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max]) + hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max]) + hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max]) else: # enforce the same binning for both histograms bin_min = min(np.nanmin(yt), np.nanmin(yp)) bin_max = max(np.nanmax(yt), np.nanmax(yp)) - hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max]) - hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max]) + hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max]) + hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max]) # Visualization trick for tp if variable_name in precip_and_related_fields: @@ -623,6 +635,51 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", datashader=datashader, ) + elif vname in precip_and_related_fields: + # Create a custom colormap for precipitation 
+ nws_precip_colors = cmap_precip + precip_colormap = ListedColormap(nws_precip_colors) + + # Defining the actual precipitation accumulation levels in mm + cummulation_lvls = clevels + norm = BoundaryNorm(cummulation_lvls, len(cummulation_lvls) + 1) + + # converting to mm from m + input_ *= 1000.0 + truth *= 1000.0 + pred *= 1000.0 + single_plot( + fig, + ax[0], + lon=lon, + lat=lat, + data=input_, + cmap=precip_colormap, + title=f"{vname} input", + datashader=datashader, + ) + single_plot( + fig, + ax[4], + lon=lon, + lat=lat, + data=pred - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} increment [pred - input]", + datashader=datashader, + ) + single_plot( + fig, + ax[5], + lon=lon, + lat=lat, + data=truth - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} persist err", + datashader=datashader, + ) else: single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( From 2df18a706b6b283f96bd9535988fc54c793de83c Mon Sep 17 00:00:00 2001 From: Icedoom <9220778+icedoom888@users.noreply.github.com> Date: Tue, 10 Dec 2024 23:01:54 +0100 Subject: [PATCH 12/23] [bugfix] loading only the weights of the checkpoint --- src/anemoi/training/train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index ed6c9e39..feb8dd7c 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -162,7 +162,7 @@ def model(self) -> GraphForecaster: LOGGER.info("Restoring only model weights from %s", self.last_checkpoint) - return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) + return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False) LOGGER.info("Model initialised from scratch.") return model From 7e4a5f7e8329da75166e829cff57206263c901e7 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Fri, 13 Dec 2024 09:36:26 +0100 Subject: [PATCH 13/23] Support masking of unconnected nodes (LAM) (#171) Co-authored-by: Harrison Cook --- CHANGELOG.md | 4 + .../config/dataloader/native_grid.yaml | 6 ++ .../training/config/graph/limited_area.yaml | 24 ++++-- src/anemoi/training/data/datamodule.py | 19 +++- src/anemoi/training/data/dataset.py | 27 +++--- src/anemoi/training/data/grid_indices.py | 86 +++++++++++++++++++ src/anemoi/training/train/forecaster.py | 2 +- src/anemoi/training/train/train.py | 2 +- 8 files changed, 147 insertions(+), 23 deletions(-) create mode 100644 src/anemoi/training/data/grid_indices.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c4e0466..0e52fe11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ Keep it human-readable, your future self will thank you! 
### Removed - Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120) +### Added + +- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) + ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 ### Changed diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 9513ecc7..8a52029e 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -40,6 +40,12 @@ limit_batches: test: 20 predict: 20 +# set a custom mask for grid points. +# Useful for LAM (dropping unconnected nodes from forcing dataset) +grid_indices: + _target_: anemoi.training.data.grid_indices.FullGrid + nodes_name: ${graph.data} + # ============ # Dataloader definitions # These follow the anemoi-datasets patterns diff --git a/src/anemoi/training/config/graph/limited_area.yaml b/src/anemoi/training/config/graph/limited_area.yaml index a22405b6..93600cb1 100644 --- a/src/anemoi/training/config/graph/limited_area.yaml +++ b/src/anemoi/training/config/graph/limited_area.yaml @@ -10,14 +10,14 @@ nodes: node_builder: _target_: anemoi.graphs.nodes.ZarrDatasetNodes dataset: ${dataloader.training.dataset} - attributes: ${graph.attributes.nodes} + attributes: ${graph.attributes.data_nodes} # Hidden nodes hidden: node_builder: _target_: anemoi.graphs.nodes.LimitedAreaTriNodes # options: ZarrDatasetNodes, NPZFileNodes, TriNodes resolution: 5 # grid resolution for npz (o32, o48, ...) reference_node_name: ${graph.data} - mask_attr_name: cutout + mask_attr_name: cutout_mask edges: # Encoder configuration @@ -26,6 +26,9 @@ edges: edge_builders: - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges cutoff_factor: 0.6 # only for cutoff method + - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges + cutoff_factor: 2 # only for cutoff method + source_mask_attr_name: boundary_mask attributes: ${graph.attributes.edges} # Processor configuration - source_name: ${graph.hidden} @@ -39,18 +42,29 @@ edges: target_name: ${graph.data} edge_builders: - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges - target_mask_attr_name: cutout + target_mask_attr_name: cutout_mask num_nearest_neighbours: 3 # only for knn method attributes: ${graph.attributes.edges} +post_processors: + - _target_: anemoi.graphs.processors.RemoveUnconnectedNodes + nodes_name: data + ignore: cutout_mask # optional + save_mask_indices_to_attr: indices_connected_nodes # optional + + attributes: - nodes: + data_nodes: area_weight: _target_: anemoi.graphs.nodes.attributes.AreaWeights # options: Area, Uniform norm: unit-max # options: l1, l2, unit-max, unit-sum, unit-std - cutout: + cutout_mask: _target_: anemoi.graphs.nodes.attributes.CutOutMask + boundary_mask: + _target_: anemoi.graphs.nodes.attributes.BooleanNot + masks: + _target_: anemoi.graphs.nodes.attributes.CutOutMask edges: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeLength diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index edcb7728..69538c67 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -7,15 +7,18 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
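Editor's note: the new `dataloader.grid_indices` entry above is resolved with hydra at runtime. A hedged sketch of how a limited-area setup might instantiate the masked variant in code, reusing the `indices_connected_nodes` attribute written by the `RemoveUnconnectedNodes` post-processor; the snippet is illustrative rather than copied from the repository.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Illustrative only: build the grid-index object that drops unconnected LAM nodes,
# mirroring the dataloader.grid_indices block introduced in this patch.
cfg = OmegaConf.create(
    {
        "_target_": "anemoi.training.data.grid_indices.MaskedGrid",
        "nodes_name": "data",
        "node_attribute_name": "indices_connected_nodes",
    },
)
grid_indices = instantiate(cfg, reader_group_size=1)
# grid_indices.setup(graph) would then read the node attribute from the graph object.
```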
+from __future__ import annotations import logging from functools import cached_property +from typing import TYPE_CHECKING from typing import Callable import pytorch_lightning as pl from anemoi.datasets.data import open_dataset from anemoi.models.data_indices.collection import IndexCollection from anemoi.utils.dates import frequency_to_seconds +from hydra.utils import instantiate from omegaconf import DictConfig from omegaconf import OmegaConf from torch.utils.data import DataLoader @@ -25,11 +28,16 @@ LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from torch_geometric.data import HeteroData + + from anemoi.training.data.grid_indices import BaseGridIndices + class AnemoiDatasetsDataModule(pl.LightningDataModule): """Anemoi Datasets data module for PyTorch Lightning.""" - def __init__(self, config: DictConfig) -> None: + def __init__(self, config: DictConfig, graph_data: HeteroData) -> None: """Initialize Anemoi Datasets data module. Parameters @@ -41,6 +49,7 @@ def __init__(self, config: DictConfig) -> None: super().__init__() self.config = config + self.graph_data = graph_data # Set the maximum rollout to be expected self.rollout = ( @@ -72,6 +81,13 @@ def metadata(self) -> dict: def data_indices(self) -> IndexCollection: return IndexCollection(self.config, self.ds_train.name_to_index) + @cached_property + def grid_indices(self) -> type[BaseGridIndices]: + reader_group_size = self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model) + grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size) + grid_indices.setup(self.graph_data) + return grid_indices + @cached_property def timeincrement(self) -> int: """Determine the step size relative to the data frequency.""" @@ -164,6 +180,7 @@ def _get_dataset( multistep=self.config.training.multistep_input, timeincrement=self.timeincrement, shuffle=shuffle, + grid_indices=self.grid_indices, label=label, effective_bs=effective_bs, ) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 69aa154c..32f6241d 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -13,6 +13,7 @@ import os import random from functools import cached_property +from typing import TYPE_CHECKING from typing import Callable import numpy as np @@ -26,6 +27,9 @@ LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from anemoi.training.data.grid_indices import BaseGridIndices + class NativeGridDataset(IterableDataset): """Iterable dataset for AnemoI data on the arbitrary grids.""" @@ -33,6 +37,7 @@ class NativeGridDataset(IterableDataset): def __init__( self, data_reader: Callable, + grid_indices: type[BaseGridIndices], rollout: int = 1, multistep: int = 1, timeincrement: int = 1, @@ -46,6 +51,8 @@ def __init__( ---------- data_reader : Callable user function that opens and returns the zarr array data + grid_indices : Type[BaseGridIndices] + indices of the grid to keep. Defaults to None, which keeps all spatial indices. rollout : int, optional length of rollout window, by default 12 timeincrement : int, optional @@ -66,6 +73,7 @@ def __init__( self.rollout = rollout self.timeincrement = timeincrement + self.grid_indices = grid_indices # lazy init self.n_samples_per_epoch_total: int = 0 @@ -90,8 +98,6 @@ def __init__( assert self.multi_step > 0, "Multistep value must be greater than zero." 
self.ensemble_dim: int = 2 self.ensemble_size = self.data.shape[self.ensemble_dim] - self.grid_dim: int = -1 - self.grid_size = self.data.shape[self.grid_dim] @cached_property def statistics(self) -> dict: @@ -160,14 +166,7 @@ def set_comm_group_info( self.reader_group_rank = reader_group_rank self.reader_group_size = reader_group_size - if self.reader_group_size > 1: - # get the grid shard size and start/end indices - grid_shard_size = self.grid_size // self.reader_group_size - self.grid_start = self.reader_group_rank * grid_shard_size - if self.reader_group_rank == self.reader_group_size - 1: - self.grid_end = self.grid_size - else: - self.grid_end = (self.reader_group_rank + 1) * grid_shard_size + assert self.reader_group_size >= 1, "reader_group_size must be positive" LOGGER.debug( "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " @@ -274,11 +273,9 @@ def __iter__(self) -> torch.Tensor: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement - if self.reader_group_size > 1: # read only a subset of the grid - x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end] - else: # read the full grid - x = self.data[start : end : self.timeincrement, :, :, :] - + grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank) + x = self.data[start : end : self.timeincrement, :, :, :] + x = x[..., grid_shard_indices] # select the grid shard x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 diff --git a/src/anemoi/training/data/grid_indices.py b/src/anemoi/training/data/grid_indices.py new file mode 100644 index 00000000..4e6f3f68 --- /dev/null +++ b/src/anemoi/training/data/grid_indices.py @@ -0,0 +1,86 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING +from typing import Union + +if TYPE_CHECKING: + from torch_geometric.data import HeteroData + +LOGGER = logging.getLogger(__name__) + +ArrayIndex = Union[slice, int, Sequence[int]] + + +class BaseGridIndices(ABC): + """Base class for custom grid indices.""" + + def __init__(self, nodes_name: str, reader_group_size: int) -> None: + self.nodes_name = nodes_name + self.reader_group_size = reader_group_size + + def setup(self, graph: HeteroData) -> None: + self.grid_size = self.compute_grid_size(graph) + + def split_seq_in_shards(self, reader_group_rank: int) -> tuple[int, int]: + """Get the indices to split a sequence into equal size shards.""" + grid_shard_size = self.grid_size // self.reader_group_size + grid_start = reader_group_rank * grid_shard_size + if reader_group_rank == self.reader_group_size - 1: + grid_end = self.grid_size + else: + grid_end = (reader_group_rank + 1) * grid_shard_size + + return slice(grid_start, grid_end) + + @abstractmethod + def compute_grid_size(self, graph: HeteroData) -> int: ... + + @abstractmethod + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: ... 
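Editor's note: a standalone sketch of the shard-splitting arithmetic defined in `split_seq_in_shards` above, with hypothetical sizes. Each reader rank receives a contiguous slice of the grid, and the last rank absorbs the remainder.

```python
# Hypothetical grid and reader-group sizes, used only to illustrate the slicing.
grid_size = 10
reader_group_size = 3

def shard_slice(rank: int) -> slice:
    shard = grid_size // reader_group_size
    start = rank * shard
    end = grid_size if rank == reader_group_size - 1 else (rank + 1) * shard
    return slice(start, end)

print([shard_slice(rank) for rank in range(reader_group_size)])
# [slice(0, 3, None), slice(3, 6, None), slice(6, 10, None)]
```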
+ + +class FullGrid(BaseGridIndices): + """The full grid is loaded.""" + + def compute_grid_size(self, graph: HeteroData) -> int: + return graph[self.nodes_name].num_nodes + + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: + return self.split_seq_in_shards(reader_group_rank) + + +class MaskedGrid(BaseGridIndices): + """Grid is masked based on a node attribute.""" + + def __init__(self, nodes_name: str, reader_group_size: int, node_attribute_name: str): + super().__init__(nodes_name, reader_group_size) + self.node_attribute_name = node_attribute_name + + def setup(self, graph: HeteroData) -> None: + LOGGER.info( + "The graph attribute %s of the %s nodes will be used to masking the spatial dimension.", + self.node_attribute_name, + self.nodes_name, + ) + self.spatial_indices = graph[self.nodes_name][self.node_attribute_name].squeeze().tolist() + super().setup(graph) + + def compute_grid_size(self, _graph: HeteroData) -> int: + return len(self.spatial_indices) + + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: + sequence_indices = self.split_seq_in_shards(reader_group_rank) + return self.spatial_indices[sequence_indices] diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index a31039b3..3421988c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -489,7 +489,7 @@ def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: torch.Tensor Allgathered (full) batch """ - grid_size = self.model.metadata["dataset"]["shape"][-1] + grid_size = len(self.latlons_data) # number of points if grid_size == batch.shape[-2]: return batch # already have the full grid diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index feb8dd7c..c236e4f5 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -81,7 +81,7 @@ def __init__(self, config: DictConfig) -> None: @cached_property def datamodule(self) -> AnemoiDatasetsDataModule: """DataModule instance and DataSets.""" - datamodule = AnemoiDatasetsDataModule(self.config) + datamodule = AnemoiDatasetsDataModule(self.config, self.graph_data) self.config.data.num_features = len(datamodule.ds_train.data.variables) LOGGER.info("Number of data variables: %s", str(len(datamodule.ds_train.data.variables))) LOGGER.debug("Variables: %s", str(datamodule.ds_train.data.variables)) From 881fa27c93df7abff3f4949e20c741c3cddcccd9 Mon Sep 17 00:00:00 2001 From: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:21:28 +0100 Subject: [PATCH 14/23] 161 documentation for anemoi training broken (#197) * wip: possibility to save checkpoint at the end of fitting loop * cleaning * fix for tzinfo type * remove changes in checkpoint file * add shield for info regarding docs building --- README.md | 3 +++ docs/conf.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e51c7697..c3e8468d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # anemoi-training +[![Documentation Status](https://readthedocs.org/projects/anemoi-training/badge/?version=latest)](https://anemoi-training.readthedocs.io/en/latest/?badge=latest) + + **DISCLAIMER** This project is **BETA** and will be **Experimental** for the foreseeable future. Interfaces and functionality are likely to change, and the project itself may be scrapped. 
diff --git a/docs/conf.py b/docs/conf.py index dc7a24d9..5d66dcfc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,7 +42,7 @@ author = "Anemoi contributors" -year = datetime.datetime.now(tz="UTC").year +year = datetime.datetime.now(tz=datetime.timezone.utc).year years = "2024" if year == 2024 else f"2024-{year}" copyright = f"{years}, Anemoi contributors" # noqa: A001 From a2d8e6da49523116f249fa061fc504efb3308f34 Mon Sep 17 00:00:00 2001 From: Jasper Wijnands <111133748+jswijnands@users.noreply.github.com> Date: Fri, 13 Dec 2024 14:12:22 +0100 Subject: [PATCH 15/23] Feature: MSE metrics inside/outside regional domain for stretched grid models using scalar functionality (KNMI) (#199) * Initial commit w-MSE splitting functionality * New graph name to fix anemoi compatibility issues with existing graph files * Code fixes for limitedarea loss * code quality improvements * expand scalar to prevent index out of bound error * Updated callbacks for CERRA config * Change naming of stretched grid metrics * Add all four new metrics to CERRA config * Code quality improvements * Reduced model complexity of CERRA config to run on a single GPU * Temporary solution to be able to log the overall MSE inside/outside with variable scaling * Removed CERRA config that was used for testing * Update limited area mask scalar based on review comments Co-authored-by: Harrison Cook * Added change log entry * Update test script after adding 'all' variable --------- Co-authored-by: Harrison Cook --- CHANGELOG.md | 1 + src/anemoi/training/losses/limitedarea.py | 109 ++++++++++++++++++++++ src/anemoi/training/train/forecaster.py | 16 +++- tests/train/test_loss_scaling.py | 1 + 4 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 src/anemoi/training/losses/limitedarea.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e52fe11..c10b0664 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you! - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations. - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) +- Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199) ### Changed diff --git a/src/anemoi/training/losses/limitedarea.py b/src/anemoi/training/losses/limitedarea.py new file mode 100644 index 00000000..3fb24409 --- /dev/null +++ b/src/anemoi/training/losses/limitedarea.py @@ -0,0 +1,109 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import logging + +import torch + +from anemoi.training.losses.weightedloss import BaseWeightedLoss + +LOGGER = logging.getLogger(__name__) + + +class WeightedMSELossLimitedArea(BaseWeightedLoss): + """Node-weighted MSE loss, calculated only within or outside the limited area. 
+ + Further, the loss can be computed for the specified region (default), + or as the contribution to the overall loss. + """ + + name = "wmse" + + def __init__( + self, + node_weights: torch.Tensor, + inside_lam: bool = True, + wmse_contribution: bool = False, + ignore_nans: bool = False, + **kwargs, + ) -> None: + """Node- and feature weighted MSE Loss. + + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + mask: torch.Tensor + the mask marking the indices of the regional data points (bool) + inside_lam: bool + compute the loss inside or outside the limited area, by default inside (True) + wmse_contribution: bool + compute loss as the contribution to the overall MSE, by default False + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + """ + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) + + self.inside_lam = inside_lam + self.wmse_contribution = wmse_contribution + + if inside_lam: + self.name += "_inside_lam" + else: + self.name += "_outside_lam" + if wmse_contribution: + self.name += "_contribution" + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + scalar_indices: torch.Tensor | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted MSE loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: + feature indices (relative to full model output) of the features passed in pred and target + + Returns + ------- + torch.Tensor + Weighted MSE loss + """ + out = torch.square(pred - target) + + limited_area_mask = self.scalar.subset("limited_area_mask").get_scalar(out.ndim, out.device) + + if not self.inside_lam: + limited_area_mask = ~limited_area_mask + + if not self.wmse_contribution: + self.node_weights *= limited_area_mask[0, 0, :, 0] + + out *= limited_area_mask + + out = self.scale(out, scalar_indices, without_scalars=["limited_area_mask"]) + + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 3421988c..8c660722 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -98,12 +98,23 @@ def __init__( _, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) + # Check if the model is a stretched grid + if "lam_resolution" in getattr(config.graph.nodes.hidden, "node_builder", []): + mask_name = config.graph.nodes.hidden.node_builder.mask_attr_name + limited_area_mask = graph_data[config.graph.data][mask_name].squeeze().bool() + else: + limited_area_mask = torch.ones((1,)) + # Kwargs to pass to the loss function loss_kwargs = {"node_weights": self.node_weights} # Scalars to include in the loss function, must be of form (dim, scalar) # Add mask multiplying NaN locations with zero. At this stage at [[1]]. # Filled after first application of preprocessor. dimension=[-2, -1] (latlon, n_outputs). 
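Editor's note: a toy illustration of the inside/outside-region masking performed by `WeightedMSELossLimitedArea`. Squared errors outside the limited area are zeroed before node weighting, and negating the mask scores the exterior instead; tensor shapes and values below are hypothetical.

```python
import torch

# (bs, ensemble, grid, vars) toy error tensor and a boolean regional mask over six points.
squared_error = torch.rand(1, 1, 6, 2)
limited_area_mask = torch.tensor([True, True, True, False, False, False])

inside = squared_error * limited_area_mask.view(1, 1, -1, 1)
outside = squared_error * (~limited_area_mask).view(1, 1, -1, 1)

print(inside[0, 0, :, 0])   # non-zero only on the first three (regional) points
print(outside[0, 0, :, 0])  # non-zero only outside the region
```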
- scalars = {"variable": (-1, variable_scaling), "loss_weights_mask": ((-2, -1), torch.ones((1, 1)))} + scalars = { + "variable": (-1, variable_scaling), + "loss_weights_mask": ((-2, -1), torch.ones((1, 1))), + "limited_area_mask": (2, limited_area_mask), + } self.updated_loss_mask = False self.loss = self.get_loss_function(config.training.training_loss, scalars=scalars, **loss_kwargs) @@ -276,6 +287,9 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> if key in config.training.metrics: metric_ranges_validation[key] = [idx] + # Add the full list of output indices + metric_ranges_validation["all"] = data_indices.internal_model.output.full.tolist() + return metric_ranges, metric_ranges_validation @staticmethod diff --git a/tests/train/test_loss_scaling.py b/tests/train/test_loss_scaling.py index 8dd3772a..4887f3fe 100644 --- a/tests/train/test_loss_scaling.py +++ b/tests/train/test_loss_scaling.py @@ -146,6 +146,7 @@ def test_metric_range(fake_data: tuple[DictConfig, IndexCollection]) -> None: metric_range, metric_ranges_validation = GraphForecaster.get_val_metric_ranges(config, data_indices) del metric_range["all"] + del metric_ranges_validation["all"] expected_metric_range_validation = { "pl_y": [ From 318c14c7278e1f4d48d60a1dc3711b1e3adebc77 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 13 Dec 2024 13:19:11 +0000 Subject: [PATCH 16/23] Remove excess metadata from variables (#201) Fix remove `metadata.dataset.specific.forward.forward.attrs.variables_metadata.` from logger params --- src/anemoi/training/diagnostics/mlflow/logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 7b8e5f53..03a4b6de 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -510,6 +510,7 @@ def _clean_params(params: dict[str, Any]) -> dict[str, Any]: "diagnostics", "metadata.config", "metadata.dataset.variables_metadata", + "metadata.dataset.specific.forward.forward.attrs.variables_metadata", ] keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)] for key in keys_to_remove: From c99069ee00147e889c947f712720a6352a9490e9 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:20:08 +0000 Subject: [PATCH 17/23] Store numpy arrays in checkpoints (#174) * Store numpy arrays in checkpoints * changelog * fix failing test * feat: save remove unconnected mask in supporting_arrays * feat: add output mask to supporting arrays * fix: keep get_node_weights as staticmethod --------- Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Co-authored-by: Mario Santa Cruz --- CHANGELOG.md | 1 + src/anemoi/training/data/datamodule.py | 4 ++++ src/anemoi/training/data/dataset.py | 5 +++++ src/anemoi/training/data/grid_indices.py | 16 +++++++++++++--- .../training/diagnostics/callbacks/checkpoint.py | 10 ++++++++-- src/anemoi/training/train/forecaster.py | 15 +++++++++------ src/anemoi/training/train/train.py | 5 +++++ src/anemoi/training/utils/masks.py | 8 ++++++++ tests/diagnostics/test_checkpoint.py | 1 + 9 files changed, 54 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c10b0664..ccbad57d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ Keep it human-readable, your future self will thank you! 
### Added +- Add supporting arrrays (numpy) to checkpoint - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 69538c67..e0502acd 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -77,6 +77,10 @@ def statistics(self) -> dict: def metadata(self) -> dict: return self.ds_train.metadata + @cached_property + def supporting_arrays(self) -> dict: + return self.ds_train.supporting_arrays | self.grid_indices.supporting_arrays + @cached_property def data_indices(self) -> IndexCollection: return IndexCollection(self.config, self.ds_train.name_to_index) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 32f6241d..431f0227 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -109,6 +109,11 @@ def metadata(self) -> dict: """Return dataset metadata.""" return self.data.metadata() + @cached_property + def supporting_arrays(self) -> dict: + """Return dataset supporting_arrays.""" + return self.data.supporting_arrays() + @cached_property def name_to_index(self) -> dict: """Return dataset statistics.""" diff --git a/src/anemoi/training/data/grid_indices.py b/src/anemoi/training/data/grid_indices.py index 4e6f3f68..91c638ed 100644 --- a/src/anemoi/training/data/grid_indices.py +++ b/src/anemoi/training/data/grid_indices.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING from typing import Union +import numpy as np + if TYPE_CHECKING: from torch_geometric.data import HeteroData @@ -45,6 +47,10 @@ def split_seq_in_shards(self, reader_group_rank: int) -> tuple[int, int]: return slice(grid_start, grid_end) + @property + def supporting_arrays(self) -> dict: + return {} + @abstractmethod def compute_grid_size(self, graph: HeteroData) -> int: ... 
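Editor's note: the supporting arrays written to the checkpoint are the union of what the dataset and the grid indices expose, as in the `supporting_arrays` properties above. A small sketch of that merge; the array names and contents are hypothetical, and the dict-union operator requires Python 3.9 or newer.

```python
import numpy as np

# Hypothetical supporting arrays from the dataset and from the grid indices.
dataset_arrays = {"latitudes": np.zeros(10), "longitudes": np.zeros(10)}
grid_index_arrays = {"grid_indices": np.arange(10, dtype=np.int64)}

supporting_arrays = dataset_arrays | grid_index_arrays
print(sorted(supporting_arrays))  # ['grid_indices', 'latitudes', 'longitudes']
```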
@@ -75,12 +81,16 @@ def setup(self, graph: HeteroData) -> None: self.node_attribute_name, self.nodes_name, ) - self.spatial_indices = graph[self.nodes_name][self.node_attribute_name].squeeze().tolist() + self.grid_indices = graph[self.nodes_name][self.node_attribute_name].squeeze().tolist() super().setup(graph) + @property + def supporting_arrays(self) -> dict: + return {"grid_indices": np.array(self.grid_indices, dtype=np.int64)} + def compute_grid_size(self, _graph: HeteroData) -> int: - return len(self.spatial_indices) + return len(self.grid_indices) def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: sequence_indices = self.split_seq_in_shards(reader_group_rank) - return self.spatial_indices[sequence_indices] + return self.grid_indices[sequence_indices] diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py index 2aba5246..88c8ecf2 100644 --- a/src/anemoi/training/diagnostics/callbacks/checkpoint.py +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -149,16 +149,22 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s tmp_metadata = model.metadata model.metadata = None - metadata = dict(**tmp_metadata) + tmp_supporting_arrays = model.supporting_arrays + model.supporting_arrays = None + + # Make sure we don't accidentally modidy these + metadata = tmp_metadata.copy() + supporting_arrays = tmp_supporting_arrays.copy() inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) torch.save(model, inference_checkpoint_filepath) - save_metadata(inference_checkpoint_filepath, metadata) + save_metadata(inference_checkpoint_filepath, metadata, supporting_arrays=supporting_arrays) model.config = save_config model.metadata = tmp_metadata + model.supporting_arrays = tmp_supporting_arrays self._last_global_step_saved = trainer.global_step diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 8c660722..6a6ec58a 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -50,6 +50,7 @@ def __init__( statistics: dict, data_indices: IndexCollection, metadata: dict, + supporting_arrays: dict, ) -> None: """Initialize graph neural network forecaster. 
@@ -65,16 +66,24 @@ def __init__( Indices of the training data, metadata : dict Provenance information + supporting_arrays : dict + Supporting NumPy arrays to store in the checkpoint """ super().__init__() graph_data = graph_data.to(self.device) + if config.model.get("output_mask", None) is not None: + self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) + else: + self.output_mask = NoOutputMask() + self.model = AnemoiModelInterface( statistics=statistics, data_indices=data_indices, metadata=metadata, + supporting_arrays=supporting_arrays | self.output_mask.supporting_arrays, graph_data=graph_data, config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))), ) @@ -85,11 +94,6 @@ def __init__( self.latlons_data = graph_data[config.graph.data].x self.node_weights = self.get_node_weights(config, graph_data) - - if config.model.get("output_mask", None) is not None: - self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) - else: - self.output_mask = NoOutputMask() self.node_weights = self.output_mask.apply(self.node_weights, dim=0, fill_value=0.0) self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled @@ -334,7 +338,6 @@ def get_variable_scaling( @staticmethod def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor: node_weighting = instantiate(config.training.node_loss_weights) - return node_weighting.weights(graph_data) def set_model_comm_group( diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index c236e4f5..694fb2da 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -150,6 +150,7 @@ def model(self) -> GraphForecaster: "graph_data": self.graph_data, "metadata": self.metadata, "statistics": self.datamodule.statistics, + "supporting_arrays": self.supporting_arrays, } model = GraphForecaster(**kwargs) @@ -249,6 +250,10 @@ def metadata(self) -> dict: }, ) + @cached_property + def supporting_arrays(self) -> dict: + return self.datamodule.supporting_arrays + @cached_property def profiler(self) -> PyTorchProfiler | None: """Returns a pytorch profiler object, if profiling is enabled.""" diff --git a/src/anemoi/training/utils/masks.py b/src/anemoi/training/utils/masks.py index fd0581a0..75fc888f 100644 --- a/src/anemoi/training/utils/masks.py +++ b/src/anemoi/training/utils/masks.py @@ -22,6 +22,10 @@ class BaseMask: """Base class for masking model output.""" + @property + def supporting_arrays(self) -> dict: + return {} + @abstractmethod def apply(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: error_message = "Method `apply` must be implemented in subclass." 
@@ -39,6 +43,10 @@ class Boolean1DMask(BaseMask): def __init__(self, values: torch.Tensor) -> None: self.mask = values.bool().squeeze() + @property + def supporting_arrays(self) -> dict: + return {"output_mask": self.mask.numpy()} + def broadcast_like(self, x: torch.Tensor, dim: int) -> torch.Tensor: assert x.shape[dim] == len( self.mask, diff --git a/tests/diagnostics/test_checkpoint.py b/tests/diagnostics/test_checkpoint.py index 63e6ccc9..354bda96 100644 --- a/tests/diagnostics/test_checkpoint.py +++ b/tests/diagnostics/test_checkpoint.py @@ -36,6 +36,7 @@ def __init__(self, *, config: DotDict, metadata: dict): self.config = config self.metadata = metadata + self.supporting_arrays = {} self.fc1 = nn.Linear(32, 5) self.fc2 = nn.Linear(5, 1) self.relu = nn.ReLU() From 90978df01acbbc706deecce5753c20ef303af83e Mon Sep 17 00:00:00 2001 From: Jasper Wijnands <111133748+jswijnands@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:38:34 +0100 Subject: [PATCH 18/23] Fix/stretched grid check (#204) * Identify stretched grid from graph * Code quality improvements * Added change log entry --- CHANGELOG.md | 1 + src/anemoi/training/train/forecaster.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccbad57d..2fb5c36f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you! - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) - Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) - Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic). +- Identify stretched grid models based on graph rather than configuration file [#204](https://github.com/ecmwf/anemoi-training/pull/204) ### Added - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting. 
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 6a6ec58a..08db6cd9 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -103,7 +103,7 @@ def __init__( _, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) # Check if the model is a stretched grid - if "lam_resolution" in getattr(config.graph.nodes.hidden, "node_builder", []): + if graph_data["hidden"].node_type == "StretchedTriNodes": mask_name = config.graph.nodes.hidden.node_builder.mask_attr_name limited_area_mask = graph_data[config.graph.data][mask_name].squeeze().bool() else: From 15312f9b2bd868159050a04b34fb132c86db3b3d Mon Sep 17 00:00:00 2001 From: Jasper Wijnands <111133748+jswijnands@users.noreply.github.com> Date: Tue, 17 Dec 2024 10:21:46 +0100 Subject: [PATCH 19/23] Fix/stretched grid check (#204) * Identify stretched grid from graph * Code quality improvements * Added change log entry From 6dd537ed24739a2e21b3c7c09d4e9732b7f26010 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 09:23:00 +0000 Subject: [PATCH 20/23] Fix 'all' validation metrics (#202) * Allow metrics in normalised space * Update ScaleTensor - Generalise `without` and `subset` - Allow removal by dim * Subset within the loss function * Use internal model mapping --- CHANGELOG.md | 1 + docs/modules/losses.rst | 27 +++- .../training/config/training/default.yaml | 20 ++- src/anemoi/training/losses/utils.py | 127 ++++++++++++++++-- src/anemoi/training/losses/weightedloss.py | 16 +-- src/anemoi/training/train/forecaster.py | 71 +++++++--- tests/train/test_scalar.py | 61 +++++++-- 7 files changed, 272 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fb5c36f..e7a0c56e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ Keep it human-readable, your future self will thank you! - Add supporting arrrays (numpy) to checkpoint - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) +- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202) ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 diff --git a/docs/modules/losses.rst b/docs/modules/losses.rst index 32ad9783..f8161a04 100644 --- a/docs/modules/losses.rst +++ b/docs/modules/losses.rst @@ -73,11 +73,36 @@ Currently, the following scalars are available for use: ******************** Validation metrics as defined in the config file at -``config.training.validation_metrics`` follow the same initialise +``config.training.validation_metrics`` follow the same initialisation behaviour as the loss function, but can be a list. In this case all losses are calculated and logged as a dictionary with the corresponding name +Scaling Validation Losses +========================= + +Validation metrics can **not** by default be scaled by scalars across +the variable dimension, but can be by all other scalars. If you want to +scale a validation metric by the variable weights, it must be added to +`config.training.scale_validation_metrics`. + +These metrics are then kept in the normalised, preprocessed space, and +thus the indexing of scalars aligns with the indexing of the tensors. + +By default, only `all` is kept in the normalised space and scaled. + +.. 
code:: yaml + + # List of validation metrics to keep in normalised space, and scalars to be applied + # Use '*' to reference all metrics, or a list of metric names. + # Unlike above, variable scaling is possible due to these metrics being + # calculated in the same way as the training loss, within the internal model space. + scale_validation_metrics: + scalars_to_apply: ['variable'] + metrics: + - 'all' + # - "*" + *********************** Custom Loss Functions *********************** diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index c397ff75..6c915eb5 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -58,16 +58,32 @@ loss_gradient_scaling: False # Validation metrics calculation, # This may be a list, in which case all metrics will be calculated -# and logged according to their name +# and logged according to their name. +# These metrics are calculated in the output model space, and thus +# have undergone postprocessing. validation_metrics: # loss class to initialise - _target_: anemoi.training.losses.mse.WeightedMSELoss # Scalars to include in loss calculation - # Available scalars include, 'variable' + # Cannot scale over the variable dimension due to possible remappings. + # Available scalars include: + # - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function + # Use the `scale_validation_metrics` section to apply variable scaling. scalars: [] # other kwargs ignore_nans: True +# List of validation metrics to keep in normalised space, and scalars to be applied +# Use '*' to reference all metrics, or a list of metric names. +# Unlike above, variable scaling is possible due to these metrics being +# calculated in the same way as the training loss, within the internal model space. +scale_validation_metrics: + scalars_to_apply: ['variable'] + metrics: + - 'all' + # - "*" + + # length of the "rollout" window (see Keisler's paper) rollout: start: 1 diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index e98e0bfe..4df5712f 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -186,7 +186,7 @@ def add_scalar( scalar: torch.Tensor, *, name: str | None = None, - ) -> None: + ) -> ScaleTensor: """Add new scalar to be applied along `dimension`. Dimension can be a single int even for a multi-dimensional scalar, @@ -201,6 +201,11 @@ Scalar tensor to apply name : str | None, optional Name of the scalar, by default None + + Returns + ------- + ScaleTensor + ScaleTensor with the scalar added """ if not isinstance(scalar, torch.Tensor): scalar = torch.tensor([scalar]) if isinstance(scalar, (int, float)) else torch.tensor(scalar) @@ -229,6 +234,62 @@ self.tensors[name] = (dimension, scalar) self._specified_dimensions[name] = dimension + return self + + def remove_scalar(self, scalar_to_remove: str | int) -> ScaleTensor: + """ + Remove scalar from ScaleTensor.
+ + Parameters + ---------- + scalar_to_remove : str | int + Name or index of tensor to remove + + Raises + ------ + ValueError + If the scalar is not in the scalars + + Returns + ------- + ScaleTensor + ScaleTensor with the scalar removed + """ + for scalar_to_pop in self.subset(scalar_to_remove).tensors: + self.tensors.pop(scalar_to_pop) + self._specified_dimensions.pop(scalar_to_pop) + return self + + def freeze_state(self) -> FrozenStateRecord: # noqa: F821 + """ + Freeze the state of the Scalar with a context manager. + + Any changes made will be reverted on exit. + + Returns + ------- + FrozenStateRecord + Context manager to freeze the state of this ScaleTensor + """ + record_of_scalars: dict = self.tensors.copy() + + class FrozenStateRecord: + """Freeze the state of the ScaleTensor. Any changes will be reverted on exit.""" + + def __enter__(self): + pass + + def __exit__(context_self, *a): # noqa: N805 + for key in list(self.tensors.keys()): + if key not in record_of_scalars: + self.remove_scalar(key) + + for key in record_of_scalars: + if key not in self: + self.add_scalar(*record_of_scalars[key], name=key) + + return FrozenStateRecord() + def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None: """Update an existing scalar maintaining original dimensions. @@ -300,32 +361,34 @@ def update(self, updated_scalars: dict[str, torch.Tensor] | None = None, overrid for name, tensor in kwargs.items(): self.update_scalar(name, tensor, override=override) - def subset(self, scalars: str | Sequence[str]) -> ScaleTensor: - """Get subset of the scalars, filtering by name. - - See `.subset_by_dim` for subsetting by affected dimensions. + def subset(self, scalar_identifier: str | Sequence[str] | int | Sequence[int]) -> ScaleTensor: + """Get subset of the scalars, filtering by name or dimension. Parameters ---------- - scalars : str | Sequence[str] - Name/s of the scalars to get + scalar_identifier : str | Sequence[str] | int | Sequence[int] + Name/s or dimension/s of the scalars to get Returns ------- ScaleTensor Subset of self """ - if isinstance(scalars, str): - scalars = [scalars] - return ScaleTensor(**{name: self.tensors[name] for name in scalars}) + if isinstance(scalar_identifier, (str, int)): + scalar_identifier = [scalar_identifier] + if any(isinstance(scalar, int) for scalar in scalar_identifier): + return self.subset_by_dim(scalar_identifier) + return self.subset_by_str(scalar_identifier) - def without(self, scalars: str | Sequence[str]) -> ScaleTensor: - """Get subset of the scalars, filtering out by name. + def subset_by_str(self, scalars: str | Sequence[str]) -> ScaleTensor: + """Get subset of the scalars, filtering by name. + + See `.subset_by_dim` for subsetting by affected dimensions. Parameters ---------- scalars : str | Sequence[str] - Name/s of the scalars to exclude + Name/s of the scalars to get Returns ------- @@ -334,7 +397,7 @@ def without(self, scalars: str | Sequence[str]) -> ScaleTensor: """ if isinstance(scalars, str): scalars = [scalars] - return ScaleTensor(**{name: tensor for name, tensor in self.tensors.items() if name not in scalars}) + return ScaleTensor(**{name: self.tensors[name] for name in scalars}) def subset_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: """Get subset of the scalars, filtering by dimension. 
@@ -364,6 +427,42 @@ def subset_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: return ScaleTensor(**subset_scalars) + def without(self, scalar_identifier: str | Sequence[str] | int | Sequence[int]) -> ScaleTensor: + """Get subset of the scalars, filtering out by name or dimension. + + Parameters + ---------- + scalar_identifier : str | Sequence[str] | int | Sequence[int] + Name/s or dimension/s of the scalars to exclude + + Returns + ------- + ScaleTensor + Subset of self + """ + if isinstance(scalar_identifier, (str, int)): + scalar_identifier = [scalar_identifier] + if any(isinstance(scalar, int) for scalar in scalar_identifier): + return self.without_by_dim(scalar_identifier) + return self.without_by_str(scalar_identifier) + + def without_by_str(self, scalars: str | Sequence[str]) -> ScaleTensor: + """Get subset of the scalars, filtering out by name. + + Parameters + ---------- + scalars : str | Sequence[str] + Name/s of the scalars to exclude + + Returns + ------- + ScaleTensor + Subset of self + """ + if isinstance(scalars, str): + scalars = [scalars] + return ScaleTensor(**{name: tensor for name, tensor in self.tensors.items() if name not in scalars}) + def without_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: """Get subset of the scalars, filtering out by dimension. diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py index 7ed97b21..5f71772a 100644 --- a/src/anemoi/training/losses/weightedloss.py +++ b/src/anemoi/training/losses/weightedloss.py @@ -72,7 +72,7 @@ def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = Fal def scale( self, x: torch.Tensor, - scalar_indices: tuple[int, ...] | None = None, + subset_indices: tuple[int, ...] | None = None, *, without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: @@ -82,8 +82,8 @@ def scale( ---------- x : torch.Tensor Tensor to be scaled, shape (bs, ensemble, lat*lon, n_outputs) - scalar_indices: tuple[int,...], optional - Indices to subset the calculated scalar with, by default None. + subset_indices: tuple[int,...], optional + Indices to subset the calculated scalar and `x` tensor with, by default None. without_scalars: list[str] | list[int] | None, optional list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. By default None @@ -93,8 +93,11 @@ def scale( torch.Tensor Scaled error tensor """ + if subset_indices is None: + subset_indices = [Ellipsis] + if len(self.scalar) == 0: - return x + return x[subset_indices] scale_tensor = self.scalar if without_scalars is not None and len(without_scalars) > 0: @@ -105,11 +108,8 @@ def scale( scalar = scale_tensor.get_scalar(x.ndim).to(x) - if scalar_indices is None: - return x * scalar - scalar = scalar.expand_as(x) - return x * scalar[scalar_indices] + return x[subset_indices] * scalar[subset_indices] def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor: """Scale a tensor by the node_weights. 
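Before moving on to the forecaster changes, a short sketch of how the reworked `ScaleTensor` API above fits together: chainable `add_scalar`, `remove_scalar`, name- or dimension-based `subset`/`without`, and the new `freeze_state` context manager. It only exercises behaviour visible in this diff and its tests; the scalar names and shapes are illustrative:

```python
import torch

from anemoi.training.losses.utils import ScaleTensor

scale = ScaleTensor(variable=(-1, torch.tensor([0.5, 1.0, 2.0])))
scale.add_scalar(-2, torch.ones(1), name="node_mask")  # add_scalar now returns self

assert "variable" in scale and -1 in scale  # membership checks work by name or by dimension

# Temporarily drop a scalar; freeze_state() restores the original set of scalars on exit.
with scale.freeze_state():
    scale.remove_scalar("variable")
    assert "variable" not in scale
assert "variable" in scale

only_variable = scale.subset("variable")   # filter by name
without_mask = scale.without("node_mask")  # or exclude by name (dimensions are accepted too)
```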
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 08db6cd9..0059d90a 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -87,7 +87,7 @@ def __init__( graph_data=graph_data, config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))), ) - + self.config = config self.data_indices = data_indices self.save_hyperparameters() @@ -100,7 +100,7 @@ def __init__( variable_scaling = self.get_variable_scaling(config, data_indices) - _, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) + self.internal_metric_ranges, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) # Check if the model is a stretched grid if graph_data["hidden"].node_type == "StretchedTriNodes": @@ -112,23 +112,24 @@ def __init__( # Kwargs to pass to the loss function loss_kwargs = {"node_weights": self.node_weights} # Scalars to include in the loss function, must be of form (dim, scalar) + # Use -1 for the variable dimension, -2 for the latlon dimension # Add mask multiplying NaN locations with zero. At this stage at [[1]]. # Filled after first application of preprocessor. dimension=[-2, -1] (latlon, n_outputs). - scalars = { + self.scalars = { "variable": (-1, variable_scaling), "loss_weights_mask": ((-2, -1), torch.ones((1, 1))), "limited_area_mask": (2, limited_area_mask), } self.updated_loss_mask = False - self.loss = self.get_loss_function(config.training.training_loss, scalars=scalars, **loss_kwargs) + self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs) - assert isinstance(self.loss, torch.nn.Module) and not isinstance( + assert isinstance(self.loss, BaseWeightedLoss) and not isinstance( self.loss, torch.nn.ModuleList, - ), f"Loss function must be a `torch.nn.Module`, not a {type(self.loss).__name__!r}" + ), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}" - self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=scalars, **loss_kwargs) + self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=self.scalars, **loss_kwargs) if not isinstance(self.metrics, torch.nn.ModuleList): self.metrics = torch.nn.ModuleList([self.metrics]) @@ -176,7 +177,7 @@ def get_loss_function( config: DictConfig, scalars: Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None] = None, # noqa: FA100 **kwargs, - ) -> Union[torch.nn.Module, torch.nn.ModuleList]: # noqa: FA100 + ) -> Union[BaseWeightedLoss, torch.nn.ModuleList]: # noqa: FA100 """Get loss functions from config. Can be ModuleList if multiple losses are specified. 
@@ -196,7 +197,7 @@ def get_loss_function( Returns ------- - Union[torch.nn.Module, torch.nn.ModuleList] + Union[BaseWeightedLoss, torch.nn.ModuleList] Loss function, or list of metrics Raises @@ -255,6 +256,8 @@ def training_weights_for_imputed_variables( loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask) # update scaler with loss_weights_mask retrieved from preprocessors self.loss.update_scalar(scalar=loss_weights_mask.cpu(), name="loss_weights_mask") + self.scalars["loss_weights_mask"] = ((-2, -1), loss_weights_mask.cpu()) + self.updated_loss_mask = True @staticmethod @@ -292,7 +295,7 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> metric_ranges_validation[key] = [idx] # Add the full list of output indices - metric_ranges_validation["all"] = data_indices.internal_model.output.full.tolist() + metric_ranges_validation["all"] = data_indices.model.output.full.tolist() return metric_ranges, metric_ranges_validation @@ -569,11 +572,36 @@ def calculate_val_metrics( continue for mkey, indices in self.val_metric_ranges.items(): - metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( - y_pred_postprocessed[..., indices], - y_postprocessed[..., indices], - scalar_indices=[..., indices] if -1 in metric.scalar else None, - ) + if "scale_validation_metrics" in self.config.training and ( + mkey in self.config.training.scale_validation_metrics.metrics + or "*" in self.config.training.scale_validation_metrics.metrics + ): + with metric.scalar.freeze_state(): + for key in self.config.training.scale_validation_metrics.scalars_to_apply: + metric.add_scalar(*self.scalars[key], name=key) + + # Use internal model space indices + internal_model_indices = self.internal_metric_ranges[mkey] + + metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( + y_pred, + y, + scalar_indices=[..., internal_model_indices], + ) + else: + if -1 in metric.scalar: + exception_msg = ( + "Validation metrics cannot be scaled over the variable dimension" + " in the post processed space. Please specify them in the config" + " at `scale_validation_metrics`." + ) + raise ValueError(exception_msg) + + metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( + y_pred_postprocessed, + y_postprocessed, + scalar_indices=[..., indices], + ) return metrics @@ -620,7 +648,20 @@ def on_train_epoch_end(self) -> None: self.rollout = min(self.rollout, self.rollout_max) def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: + """ + Calculate the loss over a validation batch using the training loss function. 
+ + Parameters + ---------- + batch : torch.Tensor + Validation batch + batch_idx : int + Batch inces + Returns + ------- + None + """ with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) diff --git a/tests/train/test_scalar.py b/tests/train/test_scalar.py index 9a37e353..b87619d2 100644 --- a/tests/train/test_scalar.py +++ b/tests/train/test_scalar.py @@ -168,33 +168,72 @@ def test_scale_tensor_two_dim( torch.testing.assert_close(scale.scale(input_tensor), output) -def test_scalar_subset() -> None: - scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) - subset = scale.subset("test") +@pytest.mark.parametrize("subset_id", ["test", 0]) +def test_scalar_subset(subset_id) -> None: # noqa: ANN001 + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.subset(subset_id) assert "test" in subset assert "wow" not in subset assert 0 in subset + assert 1 not in subset + + +@pytest.mark.parametrize("without_id", ["test", 0]) +def test_scalar_subset_without(without_id) -> None: # noqa: ANN001 + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.without(without_id) + assert "test" not in subset + assert "wow" in subset + assert 1 in subset -def test_scalar_subset_without() -> None: +@pytest.mark.parametrize("without_id", ["test"]) +def test_scalar_subset_without_overlap(without_id) -> None: # noqa: ANN001 scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) - subset = scale.without("test") + subset = scale.without(without_id) assert "test" not in subset assert "wow" in subset assert 0 in subset -def test_scalar_subset_by_dim() -> None: +def test_scalar_remove_str() -> None: scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) - subset = scale.subset_by_dim(0) + subset = scale.remove_scalar("wow") assert "test" in subset assert "wow" not in subset assert 0 in subset -def test_scalar_subset_by_dim_without() -> None: +def test_scalar_remove_int() -> None: scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) - subset = scale.without_by_dim(0) - assert "test" not in subset + subset = scale.remove_scalar(1) + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + assert 1 not in subset + + +def test_scalar_freeze_str() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + with scale.freeze_state(): + subset = scale.remove_scalar("wow") + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + assert 1 not in subset + + assert "wow" in subset + assert 1 in subset + + +def test_scalar_freeze_int() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + with scale.freeze_state(): + subset = scale.remove_scalar(1) + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + assert 1 not in subset + assert "wow" in subset - assert 0 not in subset + assert 1 in subset From 7887b954aa352f5cb46431a76431731ab4881a79 Mon Sep 17 00:00:00 2001 From: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:33:23 +0100 Subject: [PATCH 21/23] Fix/183 not saving last checkpoint when max steps area reached (#191) * wip: possibility to save checkpoint at the end of fitting loop * cleaning * update changelog and rename function * wip * wip * add docstrings and comments for better readability * remove change 
--- CHANGELOG.md | 1 + .../diagnostics/callbacks/checkpoint.py | 30 ++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7a0c56e..23894907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you! - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) - Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) - Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic). +- Fixes to checkpoint saving - ensure last checkpoint if saving when using max_steps [#191] (https://github.com/ecmwf/anemoi-training/pull/191) - Identify stretched grid models based on graph rather than configuration file [#204](https://github.com/ecmwf/anemoi-training/pull/204) ### Added diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py index 88c8ecf2..4da76327 100644 --- a/src/anemoi/training/diagnostics/callbacks/checkpoint.py +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -43,6 +43,7 @@ def __init__(self, config: OmegaConf, **kwargs: dict) -> None: """ super().__init__(**kwargs) + self.config = config self.start = time.time() self._model_metadata = None @@ -76,6 +77,34 @@ def model_metadata(self, model: torch.nn.Module) -> dict: return self._model_metadata + def _adjust_epoch_progress(self, trainer: pl.Trainer) -> None: + """ + Adjust the epoch progress when saving a mid-epoch checkpoint. + + Since PyTorch Lightning advances one epoch at the end of training (on_train_end), + we need to correct the checkpoint epoch progress to avoid inconsistencies. + """ + trainer.fit_loop.epoch_progress.current.processed = trainer.fit_loop.epoch_progress.current.processed - 1 + trainer.fit_loop.epoch_progress.current.completed = trainer.fit_loop.epoch_progress.current.completed - 1 + trainer.fit_loop.epoch_progress.total.processed = trainer.fit_loop.epoch_progress.total.processed - 1 + trainer.fit_loop.epoch_progress.total.completed = trainer.fit_loop.epoch_progress.total.completed - 1 + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """ + Save the last checkpoint at the end of training. + + If the candidates aren't better than the last checkpoint, then no checkpoints are saved. + Note: if this method is triggered when using max_epochs, it won't save any checkpoints, + since the monitor candidates won't show any changes with regard to the 'on_train_epoch_end' hook.
+ """ + del pl_module + if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): + if trainer.fit_loop.epoch_progress.current.completed == trainer.fit_loop.epoch_progress.current.ready: + self._adjust_epoch_progress(trainer) + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + def tracker_metadata(self, trainer: pl.Trainer) -> dict: if self._tracker_metadata is not None: return {self._tracker_name: self._tracker_metadata} @@ -169,7 +198,6 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s self._last_global_step_saved = trainer.global_step trainer.strategy.barrier() - # saving checkpoint used for pytorch-lightning based training trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) From 38b75fadd5fcdf935547e8239180a9280158a12d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Wed, 18 Dec 2024 12:36:53 +0100 Subject: [PATCH 22/23] fix(config): Default configs for Stretched & Limited Area Graphs (#173) * fix: log warning when no trainable tensors * feat: propose default configs * fix: update changelog * fix: update defaults * fix: update loss_scaling function * fix: update to latest config * fix: remove hidden node attributes (stretched) * fix: update plot of trainable params * fix: style --- CHANGELOG.md | 14 ++--- .../training/config/graph/limited_area.yaml | 11 ++-- .../training/config/graph/stretched_grid.yaml | 18 +++--- src/anemoi/training/config/lam.yaml | 36 +++++++++++ src/anemoi/training/config/stretched.yaml | 37 ++++++++++++ .../training/diagnostics/callbacks/plot.py | 59 +++++++++++++------ src/anemoi/training/diagnostics/plots.py | 38 ++++++------ 7 files changed, 156 insertions(+), 57 deletions(-) create mode 100644 src/anemoi/training/config/lam.yaml create mode 100644 src/anemoi/training/config/stretched.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index 23894907..286fd915 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you! - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) - Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) - Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic). +- GraphTrainableParameters callback will log a warning when no trainable parameters are specified [#173](https://github.com/ecmwf/anemoi-training/pull/173) - Fixes to checkpoint saving - ensure last checkpoint if saving when using max_steps [#191] (https://github.com/ecmwf/anemoi-training/pull/191) - Identify stretched grid models based on graph rather than configuration file [#204](https://github.com/ecmwf/anemoi-training/pull/204) @@ -23,26 +24,25 @@ Keep it human-readable, your future self will thank you! - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`. Used for experiment reproducibility across different computing configurations. 
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) +- Added default configuration files for stretched grid and limited area model experiments [173](https://github.com/ecmwf/anemoi-training/pull/173) - Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199) +- Add supporting arrrays (numpy) to checkpoint +- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) +- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202) ### Changed ### Removed - Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120) -### Added - -- Add supporting arrrays (numpy) to checkpoint -- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) -- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202) - ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 ### Changed - Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153) ### Fixed -- Update `n_pixel` used by datashader to better adapt across resolutions #152 + +- Update `n_pixel` used by datashader to better adapt across resolutions [#152](https://github.com/ecmwf/anemoi-training/pull/152) - Fixed bug in power spectra plotting for the n320 resolution. - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) diff --git a/src/anemoi/training/config/graph/limited_area.yaml b/src/anemoi/training/config/graph/limited_area.yaml index 93600cb1..90b42637 100644 --- a/src/anemoi/training/config/graph/limited_area.yaml +++ b/src/anemoi/training/config/graph/limited_area.yaml @@ -10,7 +10,7 @@ nodes: node_builder: _target_: anemoi.graphs.nodes.ZarrDatasetNodes dataset: ${dataloader.training.dataset} - attributes: ${graph.attributes.data_nodes} + attributes: ${graph.attributes.nodes} # Hidden nodes hidden: node_builder: @@ -26,8 +26,8 @@ edges: edge_builders: - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges cutoff_factor: 0.6 # only for cutoff method - - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges - cutoff_factor: 2 # only for cutoff method + - _target_: anemoi.graphs.edges.CutOffEdges # connects only boundary nodes + cutoff_factor: 1.5 # only for cutoff method source_mask_attr_name: boundary_mask attributes: ${graph.attributes.edges} # Processor configuration @@ -46,16 +46,15 @@ edges: num_nearest_neighbours: 3 # only for knn method attributes: ${graph.attributes.edges} - post_processors: - _target_: anemoi.graphs.processors.RemoveUnconnectedNodes nodes_name: data ignore: cutout_mask # optional save_mask_indices_to_attr: indices_connected_nodes # optional - attributes: - data_nodes: + nodes: + # Attributes for data nodes area_weight: _target_: anemoi.graphs.nodes.attributes.AreaWeights # options: Area, Uniform norm: unit-max # options: l1, l2, unit-max, unit-sum, unit-std diff --git a/src/anemoi/training/config/graph/stretched_grid.yaml b/src/anemoi/training/config/graph/stretched_grid.yaml index a92f319b..ae5a9342 100644 --- a/src/anemoi/training/config/graph/stretched_grid.yaml +++ b/src/anemoi/training/config/graph/stretched_grid.yaml @@ 
-11,12 +11,7 @@ nodes: node_builder: _target_: anemoi.graphs.nodes.ZarrDatasetNodes dataset: ${dataloader.training.dataset} - attributes: - area_weight: - _target_: anemoi.graphs.nodes.attributes.AreaWeights - norm: unit-max - cutout: - _target_: anemoi.graphs.nodes.attributes.CutOutMask + attributes: ${graph.attributes.nodes} hidden: node_builder: _target_: anemoi.graphs.nodes.StretchedTriNodes @@ -25,10 +20,6 @@ nodes: reference_node_name: ${graph.data} mask_attr_name: cutout margin_radius_km: 11 - attributes: - area_weights: - _target_: anemoi.graphs.nodes.attributes.AreaWeights - norm: unit-max edges: # Encoder @@ -54,6 +45,13 @@ edges: attributes: ${graph.attributes.edges} attributes: + nodes: + # Attributes for data nodes + area_weight: + _target_: anemoi.graphs.nodes.attributes.AreaWeights + norm: unit-max + cutout: + _target_: anemoi.graphs.nodes.attributes.CutOutMask edges: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeLength diff --git a/src/anemoi/training/config/lam.yaml b/src/anemoi/training/config/lam.yaml new file mode 100644 index 00000000..f476ae1f --- /dev/null +++ b/src/anemoi/training/config/lam.yaml @@ -0,0 +1,36 @@ +defaults: +- data: zarr +- dataloader: native_grid +- diagnostics: evaluation +- hardware: example +- graph: limited_area +- model: graphtransformer +- training: default +- _self_ + + +### This file is for local experimentation. +## When you commit your changes, assign the new features and keywords +## to the correct defaults. +# For example to change from default GPU count: +# hardware: +# num_gpus_per_node: 1 + +dataloader: + dataset: + cutout: + - dataset: ${hardware.paths.data}/${hardware.files.dataset} + thinning: ??? + - dataset: ${hardware.paths.data}/${hardware.files.forcing_dataset} + adjust: all + min_distance_km: 0 + grid_indices: + _target_: anemoi.training.data.grid_indices.MaskedGrid + nodes_name: data + node_attribute_name: indices_connected_nodes +model: + output_mask: cutout_mask # it must be a node attribute of the output nodes +hardware: + files: + dataset: ??? + forcing_dataset: ??? diff --git a/src/anemoi/training/config/stretched.yaml b/src/anemoi/training/config/stretched.yaml new file mode 100644 index 00000000..3aa5bbb8 --- /dev/null +++ b/src/anemoi/training/config/stretched.yaml @@ -0,0 +1,37 @@ +defaults: +- data: zarr +- dataloader: native_grid +- diagnostics: evaluation +- hardware: example +- graph: stretched_grid +- model: graphtransformer +- training: default +- _self_ + + +### This file is for local experimentation. +## When you commit your changes, assign the new features and keywords +## to the correct defaults. +# For example to change from default GPU count: +# hardware: +# num_gpus_per_node: 1 + +dataloader: + dataset: + cutout: + - dataset: ${hardware.paths.data}/${hardware.files.dataset} + thinning: ??? + - dataset: ${hardware.paths.data}/${hardware.files.forcing_dataset} + adjust: all + min_distance_km: 0 +training: + loss_scaling: + spatial: + _target_: anemoi.training.data.scaling.ReweightedGraphAttribute + target_nodes: ${graph.data} + scaled_attribute: area_weight # it must be a node attribute of the output nodes + cutout_weight_frac_of_global: ??? +hardware: + files: + dataset: ??? + forcing_dataset: ??? 
diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 3adcb55a..8b5383f6 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -29,6 +29,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +from anemoi.models.layers.mapper import GraphEdgeMixin from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only @@ -46,6 +47,7 @@ from typing import Any import pytorch_lightning as pl + from anemoi.models.layers.graph import NamedNodesAttributes from omegaconf import OmegaConf LOGGER = logging.getLogger(__name__) @@ -100,7 +102,7 @@ def _output_figure( exp_log_tag: str = "val_pred_sample", ) -> None: """Figure output: save to file and/or display in notebook.""" - if self.save_basedir is not None: + if self.save_basedir is not None and fig is not None: save_path = Path( self.save_basedir, "plots", @@ -645,6 +647,23 @@ def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None Override for frequency to plot at, by default None """ super().__init__(config, every_n_epochs=every_n_epochs) + self.q_extreme_limit = config.get("quantile_edges_to_represent", 0.05) + + def get_node_trainable_tensors(self, node_attributes: NamedNodesAttributes) -> dict[str, torch.Tensor]: + return { + name: tt.trainable for name, tt in node_attributes.trainable_tensors.items() if tt.trainable is not None + } + + def get_edge_trainable_modules(self, model: torch.nn.Module) -> dict[tuple[str, str], torch.Tensor]: + trainable_modules = { + (model._graph_name_data, model._graph_name_hidden): model.encoder, + (model._graph_name_hidden, model._graph_name_data): model.decoder, + } + + if isinstance(model.processor, GraphEdgeMixin): + trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor + + return {name: module for name, module in trainable_modules.items() if module.trainable.trainable is not None} @rank_zero_only def _plot( @@ -656,25 +675,31 @@ def _plot( _ = epoch model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - fig = plot_graph_node_features(model, datashader=self.datashader_plotting) + if len(node_trainable_tensors := self.get_node_trainable_tensors(model.node_attributes)): + fig = plot_graph_node_features(model, node_trainable_tensors, datashader=self.datashader_plotting) - self._output_figure( - trainer.logger, - fig, - epoch=trainer.current_epoch, - tag="node_trainable_params", - exp_log_tag="node_trainable_params", - ) + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="node_trainable_params", + exp_log_tag="node_trainable_params", + ) + else: + LOGGER.warning("There are no trainable node attributes to plot.") - fig = plot_graph_edge_features(model) + if len(edge_trainable_modules := self.get_edge_trainable_modules(model)): + fig = plot_graph_edge_features(model, edge_trainable_modules, q_extreme_limit=self.q_extreme_limit) - self._output_figure( - trainer.logger, - fig, - epoch=trainer.current_epoch, - tag="edge_trainable_params", - exp_log_tag="edge_trainable_params", - ) + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="edge_trainable_params", + exp_log_tag="edge_trainable_params", + ) + else: + LOGGER.warning("There are no trainable edge attributes to plot.") class PlotLoss(BasePerBatchPlotCallback): diff --git 
a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index cd99ce37..0ce55cdd 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -18,7 +18,6 @@ import matplotlib.style as mplstyle import numpy as np import pandas as pd -from anemoi.models.layers.mapper import GraphEdgeMixin from datashader.mpl_ext import dsshow from matplotlib.collections import LineCollection from matplotlib.collections import PathCollection @@ -35,7 +34,7 @@ if TYPE_CHECKING: from matplotlib.figure import Figure - from torch import nn + from torch import nn, Tensor from dataclasses import dataclass @@ -874,13 +873,19 @@ def edge_plot( fig.colorbar(psc, ax=ax) -def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figure: +def plot_graph_node_features( + model: nn.Module, + trainable_tensors: dict[str, Tensor], + datashader: bool = False, +) -> Figure: """Plot trainable graph node features. Parameters ---------- model: AneomiModelEncProcDec Model object + trainable_tensors: dict[str, torch.Tensor] + Node trainable tensors datashader: bool, optional Scatter plot, by default False @@ -889,14 +894,15 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu Figure Figure object handle """ - nrows = len(nodes_name := model._graph_data.node_types) - ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name) + nrows = len(trainable_tensors) + ncols = max(tt.shape[1] for tt in trainable_tensors.values()) + figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) - for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()): + for row, (mesh, trainable_tensor) in enumerate(trainable_tensors.items()): latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy() - node_features = trainable_tensor.trainable.cpu().detach().numpy() + node_features = trainable_tensor.cpu().detach().numpy() lat, lon = latlons[:, 0], latlons[:, 1] @@ -915,13 +921,19 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu return fig -def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure: +def plot_graph_edge_features( + model: nn.Module, + trainable_modules: dict[tuple[str, str], Tensor], + q_extreme_limit: float = 0.05, +) -> Figure: """Plot trainable graph edge features. Parameters ---------- model: AneomiModelEncProcDec Model object + trainable_modules: dict[tuple[str, str], torch.Tensor] + Edge trainable tensors. q_extreme_limit : float, optional Plot top & bottom quantile of edges trainable values, by default 0.05 (5%). 
@@ -930,16 +942,8 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure Figure object handle """ - trainable_modules = { - (model._graph_name_data, model._graph_name_hidden): model.encoder, - (model._graph_name_hidden, model._graph_name_data): model.decoder, - } - - if isinstance(model.processor, GraphEdgeMixin): - trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor - - ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values()) nrows = len(trainable_modules) + ncols = max(tt.trainable.trainable.shape[1] for tt in trainable_modules.values()) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) From 638d1b470626b75ff09169bdc81c758e3d258f30 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:06:30 +0100 Subject: [PATCH 23/23] fix: set anemoi-models=0.4.1 as the minimum required version (#209) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 10d8efda..df6fc0a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dynamic = [ "version" ] dependencies = [ "anemoi-datasets>=0.5.2", "anemoi-graphs>=0.4.1", - "anemoi-models>=0.3", + "anemoi-models>=0.4.1", "anemoi-utils[provenance]>=0.4.4", "datashader>=0.16.3", "einops>=0.6.1",
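Returning to the trainable-parameter plotting changes above: the callback now keeps only nodes and mappers that actually carry trainable tensors, warns when nothing is left, and the plotting helpers size their subplot grid from whatever survives the filter. A rough, self-contained sketch of that filtering step (the `_TT` holder is a stand-in for the real trainable-tensor objects, not an anemoi class):

```python
import torch


class _TT:
    """Stand-in for a trainable-tensor holder exposing a `.trainable` attribute."""

    def __init__(self, trainable) -> None:
        self.trainable = trainable


trainable_tensors = {
    "data": _TT(None),                   # nothing trainable on the data nodes
    "hidden": _TT(torch.zeros(100, 8)),  # 8 trainable features on the hidden nodes
}

# Mirrors get_node_trainable_tensors: keep only entries that actually have a trainable tensor.
node_trainable = {name: tt.trainable for name, tt in trainable_tensors.items() if tt.trainable is not None}

if node_trainable:
    ncols = max(t.shape[1] for t in node_trainable.values())  # subplot columns, as in plot_graph_node_features
    print(list(node_trainable), ncols)  # ['hidden'] 8
else:
    print("There are no trainable node attributes to plot.")  # the callback logs a warning instead
```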