diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc2e9f14..7f2d4c19 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -79,6 +79,7 @@ Keep it human-readable, your future self will thank you!
 - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
 - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
+- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136](https://github.com/ecmwf/anemoi-training/pull/136)
 - Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)
diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst
index 588b34d9..e90b1583 100644
--- a/docs/user-guide/training.rst
+++ b/docs/user-guide/training.rst
@@ -183,6 +183,28 @@ levels nearer to the surface). By default anemoi-training uses a ReLU
 Pressure Level scaler with a minimum weighting of 0.2 (i.e. no pressure
 level has a weighting less than 0.2).
 
+The loss is also scaled by assigning a weight to each node on the output
+grid. These weights are calculated during graph creation and stored as
+an attribute in the graph object. The configuration option
+``config.training.node_loss_weights`` specifies the node attribute used
+as weights in the loss function. By default anemoi-training uses area
+weighting, where each node is weighted according to the size of the
+geographical area it represents.
+
+It is also possible to rescale the weights of a subset of nodes after
+they are loaded from the graph. For instance, for a stretched grid setup
+we can rescale the weights of the nodes in the limited area such that
+their sum equals 0.25 of the sum of all node weights, with the following
+config:
+
+.. code:: yaml
+
+   node_loss_weights:
+     _target_: anemoi.training.losses.nodeweights.ReweightedGraphNodeAttribute
+     target_nodes: data
+     node_attribute: area_weight
+     scaled_attribute: cutout
+     weight_frac_of_total: 0.25
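+
+For reference, the default (non-rescaled) configuration in
+``training/default.yaml`` loads the plain area weights computed during
+graph creation:
+
+.. code:: yaml
+
+   node_loss_weights:
+     _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
+     target_nodes: ${graph.data}
+     node_attribute: area_weight
+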
 ***************
  Learning rate
 ***************
diff --git a/src/anemoi/training/config/model/gnn.yaml b/src/anemoi/training/config/model/gnn.yaml
index 4f4c176c..92a17fd4 100644
--- a/src/anemoi/training/config/model/gnn.yaml
+++ b/src/anemoi/training/config/model/gnn.yaml
@@ -45,8 +45,6 @@ attributes:
     - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order
diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml
index 5c2e819a..9c48967b 100644
--- a/src/anemoi/training/config/model/graphtransformer.yaml
+++ b/src/anemoi/training/config/model/graphtransformer.yaml
@@ -50,8 +50,6 @@ attributes:
     - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order
diff --git a/src/anemoi/training/config/model/transformer.yaml b/src/anemoi/training/config/model/transformer.yaml
index b26c9ecc..cd6a1e7b 100644
--- a/src/anemoi/training/config/model/transformer.yaml
+++ b/src/anemoi/training/config/model/transformer.yaml
@@ -49,8 +49,6 @@ attributes:
     - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order
diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml
index af168ecc..66d0631c 100644
--- a/src/anemoi/training/config/training/default.yaml
+++ b/src/anemoi/training/config/training/default.yaml
@@ -118,3 +118,8 @@ pressure_level_scaler:
   _target_: anemoi.training.data.scaling.ReluPressureLevelScaler
   minimum: 0.2
   slope: 0.001
+
+node_loss_weights:
+  _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
+  target_nodes: ${graph.data}
+  node_attribute: area_weight
diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py
new file mode 100644
index 00000000..ed4afaf4
--- /dev/null
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -0,0 +1,141 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+
+import torch
+from anemoi.graphs.nodes.attributes import AreaWeights
+from torch_geometric.data import HeteroData
+
+LOGGER = logging.getLogger(__name__)
+
+
+class GraphNodeAttribute:
+    """Base class to load and optionally change the weight attribute of nodes in the graph.
+
+    Attributes
+    ----------
+    target: str
+        name of target nodes, key in HeteroData graph object
+    node_attribute: str
+        name of node weight attribute, key in HeteroData graph object
+
+    Methods
+    -------
+    weights(self, graph_data)
+        Load node weight attribute. Compute area weights if it cannot be found in the
+        graph object.
+    """
+
+    def __init__(self, target_nodes: str, node_attribute: str):
+        """Initialize graph node attribute with target nodes and node attribute.
+
+        Parameters
+        ----------
+        target_nodes: str
+            name of nodes, key in HeteroData graph object
+        node_attribute: str
+            name of node weight attribute, key in HeteroData graph object
+        """
+        self.target = target_nodes
+        self.node_attribute = node_attribute
+
+    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
+        """Nodes weighted by the size of the geographical area they represent.
+
+        Parameters
+        ----------
+        graph_data: HeteroData
+            graph object
+
+        Returns
+        -------
+        torch.Tensor
+            area weights of the target nodes
+        """
+        return AreaWeights(norm="unit-max", fill_value=0).compute(graph_data, self.target)
+
+    def weights(self, graph_data: HeteroData) -> torch.Tensor:
+        """Return the weights of nodes self.target, read from attribute self.node_attribute.
+
+        Attempts to load the attribute from graph_data and falls back to computing
+        area weights for the target nodes if it does not exist.
+
+        Parameters
+        ----------
+        graph_data: HeteroData
+            graph object
+
+        Returns
+        -------
+        torch.Tensor
+            weights of the target nodes
+        """
+        if self.node_attribute in graph_data[self.target]:
+            attr_weight = graph_data[self.target][self.node_attribute].squeeze()
+
+            LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
+        else:
+            attr_weight = self.area_weights(graph_data).squeeze()
+
+            LOGGER.info(
+                "Node attribute %s not found in graph. Default area weighting will be used",
+                self.node_attribute,
+            )
+
+        return attr_weight
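+
+
+# Usage sketch (illustrative only; ``graph_data`` stands for the HeteroData
+# graph loaded at the start of training). The forecaster instantiates the
+# configured class via hydra and calls ``weights``:
+#
+#     node_weighting = GraphNodeAttribute(target_nodes="data", node_attribute="area_weight")
+#     node_weights = node_weighting.weights(graph_data)  # 1D torch.Tensor of node weights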
+
+
+class ReweightedGraphNodeAttribute(GraphNodeAttribute):
+    """Reweight a subset of the target nodes defined by scaled_attribute.
+
+    Subset nodes will be rescaled such that the sum of their weights equals
+    weight_frac_of_total of the sum over all nodes.
+    """
+
+    def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float):
+        """Initialize reweighted graph node attribute.
+
+        Parameters
+        ----------
+        target_nodes: str
+            name of nodes, key in HeteroData graph object
+        node_attribute: str
+            name of node weight attribute, key in HeteroData graph object
+        scaled_attribute: str
+            name of node attribute defining the subset of nodes to be scaled, key in HeteroData graph object
+        weight_frac_of_total: float
+            sum of the weights of the subset nodes as a fraction of the sum of all node weights after rescaling
+
+        """
+        super().__init__(target_nodes=target_nodes, node_attribute=node_attribute)
+        self.scaled_attribute = scaled_attribute
+        self.fraction = weight_frac_of_total
+
+    def weights(self, graph_data: HeteroData) -> torch.Tensor:
+        attr_weight = super().weights(graph_data)
+
+        if self.scaled_attribute in graph_data[self.target]:
+            mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
+        else:
+            error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph object"
+            raise KeyError(error_msg)
+
+        # Each masked node gets the same weight, chosen such that the masked sum
+        # is fraction / (1 - fraction) times the unmasked sum, i.e. the requested
+        # fraction of the total after rescaling.
+        unmasked_sum = torch.sum(attr_weight[~mask])
+        weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / mask.sum()
+        attr_weight[mask] = weight_per_masked_node
+
+        LOGGER.info(
+            "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
+            self.scaled_attribute,
+            self.fraction,
+        )
+
+        return attr_weight
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index f92050cf..717e88c3 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -84,7 +84,7 @@ def __init__(
         self.save_hyperparameters()
 
         self.latlons_data = graph_data[config.graph.data].x
-        self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
+        self.node_weights = self.get_node_weights(config, graph_data)
 
         if config.model.get("output_mask", None) is not None:
             self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask])
@@ -315,6 +315,12 @@ def get_variable_scaling(
 
         return torch.from_numpy(variable_loss_scaling)
 
+    @staticmethod
+    def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
+        node_weighting = instantiate(config.training.node_loss_weights)
+
+        return node_weighting.weights(graph_data)
+
     def set_model_comm_group(
         self,
         model_comm_group: ProcessGroup,
diff --git a/tests/train/test_nodeweights.py b/tests/train/test_nodeweights.py
new file mode 100644
index 00000000..00b1f41d
--- /dev/null
+++ b/tests/train/test_nodeweights.py
@@ -0,0 +1,91 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
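+
+# Worked example of the rescaling checked below: fake_graph() has four nodes of
+# weight 1.0 with a single node masked by "cutout", so the unmasked sum is 3.0.
+# For weight_frac_of_total = 0.5, the masked node is assigned
+# 0.5 / (1 - 0.5) * 3.0 / 1 = 3.0, giving weights [1.0, 3.0, 1.0, 1.0].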
+
+
+import pytest
+import torch
+from anemoi.graphs.nodes.attributes import AreaWeights
+from torch_geometric.data import HeteroData
+
+from anemoi.training.losses.nodeweights import GraphNodeAttribute
+from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute
+
+
+def fake_graph() -> HeteroData:
+    hdata = HeteroData()
+    lons = torch.tensor([1.56, 3.12, 4.68, 6.24])
+    lats = torch.tensor([-3.12, -1.56, 1.56, 3.12])
+    cutout_mask = torch.tensor([False, True, False, False]).unsqueeze(1)
+    area_weights = torch.ones(cutout_mask.shape)
+    hdata["data"]["x"] = torch.stack((lats, lons), dim=1)
+    hdata["data"]["cutout"] = cutout_mask
+    hdata["data"]["area_weight"] = area_weights
+
+    return hdata
+
+
+def fake_sv_area_weights() -> torch.Tensor:
+    return AreaWeights(norm="unit-max", fill_value=0).compute(fake_graph(), "data").squeeze()
+
+
+def fake_reweighted_sv_area_weights(frac: float) -> torch.Tensor:
+    weights = fake_sv_area_weights().unsqueeze(1)
+    cutout_mask = fake_graph()["data"]["cutout"]
+    unmasked_sum = torch.sum(weights[~cutout_mask])
+    weight_per_masked_node = frac / (1.0 - frac) * unmasked_sum / cutout_mask.sum()
+    weights[cutout_mask] = weight_per_masked_node
+
+    return weights.squeeze()
+
+
+@pytest.mark.parametrize(
+    ("target_nodes", "node_attribute", "fake_graph", "expected_weights"),
+    [
+        ("data", "area_weight", fake_graph(), fake_graph()["data"]["area_weight"]),
+        ("data", "non_existent_attr", fake_graph(), fake_sv_area_weights()),
+    ],
+)
+def test_graph_node_attributes(
+    target_nodes: str,
+    node_attribute: str,
+    fake_graph: HeteroData,
+    expected_weights: torch.Tensor,
+) -> None:
+    weights = GraphNodeAttribute(target_nodes=target_nodes, node_attribute=node_attribute).weights(fake_graph)
+    assert isinstance(weights, torch.Tensor)
+    assert torch.allclose(weights, expected_weights)
+
+
+@pytest.mark.parametrize(
+    ("target_nodes", "node_attribute", "scaled_attribute", "weight_frac_of_total", "fake_graph", "expected_weights"),
+    [
+        ("data", "area_weight", "cutout", 0.0, fake_graph(), torch.tensor([1.0, 0.0, 1.0, 1.0])),
+        ("data", "area_weight", "cutout", 0.5, fake_graph(), torch.tensor([1.0, 3.0, 1.0, 1.0])),
+        ("data", "area_weight", "cutout", 0.97, fake_graph(), torch.tensor([1.0, 97.0, 1.0, 1.0])),
+        ("data", "non_existent_attr", "cutout", 0.0, fake_graph(), fake_reweighted_sv_area_weights(0.0)),
+        ("data", "non_existent_attr", "cutout", 0.5, fake_graph(), fake_reweighted_sv_area_weights(0.5)),
+        ("data", "non_existent_attr", "cutout", 0.99, fake_graph(), fake_reweighted_sv_area_weights(0.99)),
+    ],
+)
+def test_reweighted_graph_node_attributes(
+    target_nodes: str,
+    node_attribute: str,
+    scaled_attribute: str,
+    weight_frac_of_total: float,
+    fake_graph: HeteroData,
+    expected_weights: torch.Tensor,
+) -> None:
+    weights = ReweightedGraphNodeAttribute(
+        target_nodes=target_nodes,
+        node_attribute=node_attribute,
+        scaled_attribute=scaled_attribute,
+        weight_frac_of_total=weight_frac_of_total,
+    ).weights(graph_data=fake_graph)
+    assert isinstance(weights, torch.Tensor)
+    assert torch.allclose(weights, expected_weights)