Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature: Option to set / change area weighting outside of graph-creation #136

Merged
merged 24 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c7f4867
Implementation of aw_rescaling
havardhhaugen Nov 11, 2024
17be84b
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 11, 2024
e99a5a7
Pre-commit
havardhhaugen Nov 11, 2024
c576efb
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 12, 2024
6a527a9
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 12, 2024
8dc5e11
Updated implementation based on feedback
havardhhaugen Nov 14, 2024
cc4f38b
Small fixes - training now worked for all cases
havardhhaugen Nov 14, 2024
5df91e1
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 15, 2024
d0d2b57
Docstrings GraphNodeAttributes, minor fixes
havardhhaugen Nov 15, 2024
bc91253
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 15, 2024
8502ebc
Update changelog
havardhhaugen Nov 15, 2024
bb4969b
Removed obsolete config options
havardhhaugen Nov 15, 2024
569316f
Docstrings
havardhhaugen Nov 15, 2024
5e850ce
Unit testing
havardhhaugen Nov 19, 2024
f29b83a
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 19, 2024
d2fec0b
Updated documentation
havardhhaugen Nov 21, 2024
9c0ac29
area_weights uses AreaWeights from anemoi-graphs
havardhhaugen Nov 22, 2024
cb2bf36
Merge remote-tracking branch 'origin/develop' into pr/aw_rescale
havardhhaugen Nov 25, 2024
f4bf9c0
pre-commit
havardhhaugen Nov 25, 2024
c159e27
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 25, 2024
7bb2919
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 26, 2024
87262ee
if test to check for scaled_attribute
havardhhaugen Nov 26, 2024
796ad56
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 26, 2024
2c7c548
Merge remote-tracking branch 'origin/develop' into pr/aw_rescale
havardhhaugen Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Keep it human-readable, your future self will thank you!
- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)

- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136](https://github.com/ecmwf/anemoi-training/pull/136)

- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)

Expand Down
22 changes: 22 additions & 0 deletions docs/user-guide/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,28 @@ levels nearer to the surface). By default anemoi-training uses a ReLU
Pressure Level scaler with a minimum weighting of 0.2 (i.e. no pressure
level has a weighting less than 0.2).

The loss is also scaled by assigning a weight to each node on the output
grid. These weights are calculated during graph creation and stored as
an attribute in the graph object. The configuration option
``config.training.node_loss_weights`` is used to specify the node
attribute used as weights in the loss function. By default
anemoi-training uses area weighting, where each node is weighted
according to the size of the geographical area it represents. If the
requested attribute is not found in the graph, area weights are computed
at the start of training and used instead.

It is also possible to rescale the weight of a subset of nodes after
they are loaded from the graph. For instance, in a stretched grid setup
the weights of the nodes in the limited area can be rescaled such that
their sum equals 0.25 of the sum of all node weights, using the
following configuration:

.. code:: yaml

node_loss_weights:
_target_: anemoi.training.losses.nodeweights.ReweightedGraphNodeAttribute
target_nodes: data
node_attribute: area_weight
scaled_attribute: cutout
weight_frac_of_total: 0.25

***************
Learning rate
***************
Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/training/config/model/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ attributes:
- edge_dirs
nodes: []

node_loss_weight: area_weight

# Bounding configuration
bounding: #These are applied in order

Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/training/config/model/graphtransformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ attributes:
- edge_dirs
nodes: []

node_loss_weight: area_weight

# Bounding configuration
bounding: #These are applied in order

Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/training/config/model/transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ attributes:
- edge_dirs
nodes: []

node_loss_weight: area_weight

# Bounding configuration
bounding: #These are applied in order

Expand Down
5 changes: 5 additions & 0 deletions src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,8 @@ pressure_level_scaler:
_target_: anemoi.training.data.scaling.ReluPressureLevelScaler
minimum: 0.2
slope: 0.001

node_loss_weights:
_target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
JPXKQX marked this conversation as resolved.
Show resolved Hide resolved
target_nodes: ${graph.data}
node_attribute: area_weight
141 changes: 141 additions & 0 deletions src/anemoi/training/losses/nodeweights.py
HCookie marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

import torch
from anemoi.graphs.nodes.attributes import AreaWeights
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class GraphNodeAttribute:
    """Loader for a per-node weight attribute stored in the graph.

    Attributes
    ----------
    target: str
        name of the target nodes, used as key into the HeteroData graph object
    node_attribute: str
        name of the node weight attribute looked up on the target nodes

    Methods
    -------
    area_weights(self, graph_data)
        Compute area weights for the target nodes with anemoi-graphs' AreaWeights.
    weights(self, graph_data)
        Return the stored node attribute, or computed area weights when the
        attribute is absent from the graph.
    """

    def __init__(self, target_nodes: str, node_attribute: str):
        """Store which nodes and which attribute the weights are read from.

        Parameters
        ----------
        target_nodes: str
            name of nodes, key in HeteroData graph object
        node_attribute: str
            name of node weight attribute, key in HeteroData graph object
        """
        self.node_attribute = node_attribute
        self.target = target_nodes

    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
        """Compute area weights for the target nodes.

        Delegates to ``AreaWeights`` from anemoi-graphs, configured with
        ``norm="unit-max"`` and ``fill_value=0``.

        Parameters
        ----------
        graph_data: HeteroData
            graph object

        Returns
        -------
        torch.Tensor
            area weights of the target nodes
        """
        calculator = AreaWeights(norm="unit-max", fill_value=0)
        return calculator.compute(graph_data, self.target)

    def weights(self, graph_data: HeteroData) -> torch.Tensor:
        """Return the squeezed weights ``self.node_attribute`` of nodes ``self.target``.

        The attribute is read from ``graph_data`` when present; otherwise area
        weights are computed for the target nodes and returned instead.

        Parameters
        ----------
        graph_data: HeteroData
            graph object

        Returns
        -------
        torch.Tensor
            weight of target nodes
        """
        node_store = graph_data[self.target]

        # Fallback first: the configured attribute is missing from the graph.
        if self.node_attribute not in node_store:
            computed = self.area_weights(graph_data).squeeze()
            LOGGER.info(
                "Node attribute %s not found in graph. Default area weighting will be used",
                self.node_attribute,
            )
            return computed

        stored = node_store[self.node_attribute].squeeze()
        LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
        return stored


class ReweightedGraphNodeAttribute(GraphNodeAttribute):
    """Reweight a subset of the target nodes selected by ``scaled_attribute``.

    Every node in the subset is assigned the same weight, chosen such that the
    weight sum of the subset equals ``weight_frac_of_total`` of the sum over all
    nodes after rescaling. Nodes outside the subset keep their weight.
    """

    def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float):
        """Initialize reweighted graph node attribute.

        Parameters
        ----------
        target_nodes: str
            name of nodes, key in HeteroData graph object
        node_attribute: str
            name of node weight attribute, key in HeteroData graph object
        scaled_attribute: str
            name of node attribute defining the subset of nodes to be scaled, key in HeteroData graph object
        weight_frac_of_total: float
            sum of weight of subset nodes as a fraction of sum of weight of all nodes after rescaling,
            must satisfy 0 <= weight_frac_of_total < 1

        Raises
        ------
        ValueError
            if weight_frac_of_total is outside [0, 1)
        """
        # A fraction of 1 divides by zero below and values outside [0, 1) would
        # produce negative weights, so reject them up front with a clear message.
        if not 0.0 <= weight_frac_of_total < 1.0:
            error_msg = f"weight_frac_of_total must be in the interval [0, 1), got {weight_frac_of_total}"
            raise ValueError(error_msg)

        super().__init__(target_nodes=target_nodes, node_attribute=node_attribute)
        self.scaled_attribute = scaled_attribute
        self.fraction = weight_frac_of_total

    def weights(self, graph_data: HeteroData) -> torch.Tensor:
        """Return the node weights with the ``scaled_attribute`` subset rescaled.

        The graph object itself is left untouched; the rescaling is applied to a copy
        of the weights.

        Parameters
        ----------
        graph_data: HeteroData
            graph object

        Returns
        -------
        torch.Tensor
            weight of target nodes, with the nodes flagged by ``scaled_attribute`` rescaled

        Raises
        ------
        KeyError
            if ``scaled_attribute`` is not an attribute of the target nodes
        """
        node_store = graph_data[self.target]
        if self.scaled_attribute not in node_store:
            error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph_object"
            raise KeyError(error_msg)

        # The parent returns a squeezed view of the tensor stored in the graph; clone it
        # so the masked assignment below does not overwrite the graph's own attribute.
        attr_weight = super().weights(graph_data).clone()
        mask = node_store[self.scaled_attribute].squeeze().bool()

        # Tensor reduction instead of the builtin sum(), which loops over every node in Python.
        num_masked = mask.sum()
        if num_masked == 0:
            LOGGER.warning(
                "No nodes selected by %s. Node weights are left unchanged",
                self.scaled_attribute,
            )
            return attr_weight

        unmasked_sum = attr_weight[~mask].sum()
        # Solve w * n / (w * n + unmasked_sum) == fraction for the per-node weight w.
        weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / num_masked
        attr_weight[mask] = weight_per_masked_node

        LOGGER.info(
            "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
            self.scaled_attribute,
            self.fraction,
        )

        return attr_weight
8 changes: 7 additions & 1 deletion src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
self.save_hyperparameters()

self.latlons_data = graph_data[config.graph.data].x
self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
self.node_weights = self.get_node_weights(config, graph_data)

if config.model.get("output_mask", None) is not None:
self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask])
Expand Down Expand Up @@ -315,6 +315,12 @@ def get_variable_scaling(

return torch.from_numpy(variable_loss_scaling)

@staticmethod
def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
    """Return the node weights used in the loss.

    Instantiates the object configured under ``config.training.node_loss_weights``
    and returns the result of its ``weights`` method for ``graph_data``.
    """
    return instantiate(config.training.node_loss_weights).weights(graph_data)

def set_model_comm_group(
self,
model_comm_group: ProcessGroup,
Expand Down
91 changes: 91 additions & 0 deletions tests/train/test_nodeweights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import pytest
import torch
from anemoi.graphs.nodes.attributes import AreaWeights
from torch_geometric.data import HeteroData

from anemoi.training.losses.nodeweights import GraphNodeAttribute
from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute


def fake_graph() -> HeteroData:
    """Build a four-node graph with coordinates, a cutout mask and unit area weights."""
    latitudes = torch.tensor([-3.12, -1.56, 1.56, 3.12])
    longitudes = torch.tensor([1.56, 3.12, 4.68, 6.24])
    # Only the second node lies inside the cutout; stored as a column vector.
    in_cutout = torch.tensor([False, True, False, False]).unsqueeze(1)

    graph = HeteroData()
    graph["data"]["x"] = torch.stack((latitudes, longitudes), dim=1)
    graph["data"]["cutout"] = in_cutout
    graph["data"]["area_weight"] = torch.ones(in_cutout.shape)
    return graph


def fake_sv_area_weights() -> torch.Tensor:
    """Return the squeezed AreaWeights result for the "data" nodes of the fake graph."""
    calculator = AreaWeights(norm="unit-max", fill_value=0)
    computed = calculator.compute(fake_graph(), "data")
    return computed.squeeze()


def fake_reweighted_sv_area_weights(frac: float) -> torch.Tensor:
    """Return the computed area weights with the cutout nodes rescaled to ``frac`` of the total."""
    column_weights = fake_sv_area_weights().unsqueeze(1)
    in_cutout = fake_graph()["data"]["cutout"]

    outside_total = torch.sum(column_weights[~in_cutout])
    ratio = frac / (1.0 - frac)
    column_weights[in_cutout] = ratio * outside_total / sum(in_cutout)

    return column_weights.squeeze()


@pytest.mark.parametrize(
    ("target_nodes", "node_attribute", "fake_graph", "expected_weights"),
    [
        # The stored attribute is a column vector; squeeze it so the expected value has the
        # same shape as the weights returned by GraphNodeAttribute.
        ("data", "area_weight", fake_graph(), fake_graph()["data"]["area_weight"].squeeze()),
        ("data", "non_existent_attr", fake_graph(), fake_sv_area_weights()),
    ],
)
def test_grap_node_attributes(
    target_nodes: str,
    node_attribute: str,
    fake_graph: HeteroData,
    expected_weights: torch.Tensor,
) -> None:
    """Check that GraphNodeAttribute returns the stored attribute or the area-weight fallback."""
    weights = GraphNodeAttribute(target_nodes=target_nodes, node_attribute=node_attribute).weights(fake_graph)
    assert isinstance(weights, torch.Tensor)
    # torch.allclose broadcasts its inputs, so compare shapes explicitly as well.
    assert weights.shape == expected_weights.shape
    assert torch.allclose(weights, expected_weights)


@pytest.mark.parametrize(
    ("target_nodes", "node_attribute", "scaled_attribute", "weight_frac_of_total", "fake_graph", "expected_weights"),
    [
        ("data", "area_weight", "cutout", 0.0, fake_graph(), torch.tensor([1.0, 0.0, 1.0, 1.0])),
        ("data", "area_weight", "cutout", 0.5, fake_graph(), torch.tensor([1.0, 3.0, 1.0, 1.0])),
        ("data", "area_weight", "cutout", 0.97, fake_graph(), torch.tensor([1.0, 97.0, 1.0, 1.0])),
        ("data", "non_existent_attr", "cutout", 0.0, fake_graph(), fake_reweighted_sv_area_weights(0.0)),
        ("data", "non_existent_attr", "cutout", 0.5, fake_graph(), fake_reweighted_sv_area_weights(0.5)),
        ("data", "non_existent_attr", "cutout", 0.99, fake_graph(), fake_reweighted_sv_area_weights(0.99)),
    ],
)
def test_graph_node_attributes(
    target_nodes: str,
    node_attribute: str,
    scaled_attribute: str,
    weight_frac_of_total: float,
    fake_graph: HeteroData,
    expected_weights: torch.Tensor,
) -> None:
    """Check ReweightedGraphNodeAttribute for stored and fallback weights at several fractions."""
    node_weighting = ReweightedGraphNodeAttribute(
        target_nodes=target_nodes,
        node_attribute=node_attribute,
        scaled_attribute=scaled_attribute,
        weight_frac_of_total=weight_frac_of_total,
    )
    weights = node_weighting.weights(graph_data=fake_graph)

    assert isinstance(weights, torch.Tensor)
    assert torch.allclose(weights, expected_weights)