Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature: Option to set / change area weighting outside of graph-creation #136

Merged
merged 24 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c7f4867
Implementation of aw_rescaling
havardhhaugen Nov 11, 2024
17be84b
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 11, 2024
e99a5a7
Pre-commit
havardhhaugen Nov 11, 2024
c576efb
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 12, 2024
6a527a9
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 12, 2024
8dc5e11
Updated implementation based on feedback
havardhhaugen Nov 14, 2024
cc4f38b
Small fixes - training now worked for all cases
havardhhaugen Nov 14, 2024
5df91e1
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 15, 2024
d0d2b57
Docstrings GraphNodeAttributes, minor fixes
havardhhaugen Nov 15, 2024
bc91253
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 15, 2024
8502ebc
Update changelog
havardhhaugen Nov 15, 2024
bb4969b
Removed obsolete config options
havardhhaugen Nov 15, 2024
569316f
Docstrings
havardhhaugen Nov 15, 2024
5e850ce
Unit testing
havardhhaugen Nov 19, 2024
f29b83a
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 19, 2024
d2fec0b
Updated documentation
havardhhaugen Nov 21, 2024
9c0ac29
area_weights uses AreaWeights from anemoi-graphs
havardhhaugen Nov 22, 2024
cb2bf36
Merge remote-tracking branch 'origin/develop' into pr/aw_rescale
havardhhaugen Nov 25, 2024
f4bf9c0
pre-commit
havardhhaugen Nov 25, 2024
c159e27
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 25, 2024
7bb2919
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 26, 2024
87262ee
if test to check for scaled_attribute
havardhhaugen Nov 26, 2024
796ad56
Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen Nov 26, 2024
2c7c548
Merge remote-tracking branch 'origin/develop' into pr/aw_rescale
havardhhaugen Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,8 @@ pressure_level_scaler:
_target_: anemoi.training.data.scaling.ReluPressureLevelScaler
minimum: 0.2
slope: 0.001

node_loss_weights:
_target_: anemoi.traininig.losses.nodeweights.GraphNodeAttribute
HCookie marked this conversation as resolved.
Show resolved Hide resolved
target_nodes: ${graph.data}
node_attribute: area_weight
89 changes: 89 additions & 0 deletions src/anemoi/training/losses/nodeweigths.py
HCookie marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

import numpy as np
import torch
from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
from scipy.spatial import SphericalVoronoi
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class GraphNodeAttribute:
"""Method to load and optionally change the weighting of node attributes in the graph."""

def __init__(self, target_nodes: str, node_attribute: str):
self.target = target_nodes
self.node_attribute = node_attribute

def area_weights(self, graph_data: HeteroData) -> np.ndarray:
lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
area_weights = sv.calculate_areas()

return area_weights / np.max(area_weights)

def weights(self, graph_data: HeteroData) -> torch.Tensor:
try:
attr_weight = graph_data[self.target][self.node_attribute].squeeze()

LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
except KeyError:
attr_weight = torch.from_numpy(self.global_area_weights(graph_data))

LOGGER.info(
"Node attribute %s not found in graph. Default area weighting will be used",
self.node_attribute,
)

return attr_weight


class ReweightedGraphNodeAttribute(GraphNodeAttribute):
"""Method to reweight a subset of the target nodes defined by scaled_attributes.

Subset nodes will be scaled such that their weight sum equals weight_frac_of_total of the sum
over all nodes.
"""

def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float):
super().__init__(target_nodes=target_nodes, node_attribute=node_attribute)
self.scaled_attribute = scaled_attribute
self.fraction = weight_frac_of_total

def weights(self, graph_data: HeteroData) -> torch.Tensor:
try:
attr_weight = graph_data[self.target][self.node_attribute].squeeze()

LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
except KeyError:
attr_weight = torch.from_numpy(self.global_area_weights(graph_data))

LOGGER.info(
"Node attribute %s not found in graph. Default area weighting will be used",
self.node_attribute,
)
HCookie marked this conversation as resolved.
Show resolved Hide resolved

mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()

unmasked_sum = torch.sum(attr_weight[~mask])
weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
attr_weight[mask] = weight_per_masked_node
LOGGER.info(
"Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
self.node_attribute,
self.fraction,
)

return attr_weight
8 changes: 7 additions & 1 deletion src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.save_hyperparameters()

self.latlons_data = graph_data[config.graph.data].x
self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
self.node_weights = self.get_node_weights(config, graph_data)

if config.model.get("output_mask", None) is not None:
self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask])
Expand Down Expand Up @@ -290,6 +290,12 @@ def get_feature_weights(

return torch.from_numpy(loss_scaling)

@staticmethod
def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
node_weighting = instantiate(config.training.node_loss_weights)

return node_weighting.weights(graph_data)

def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
LOGGER.debug("set_model_comm_group: %s", model_comm_group)
self.model_comm_group = model_comm_group
Expand Down
Loading