Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Upgrade configs to anemoi-graphs 0.4.1 (#159)
Browse files Browse the repository at this point in the history
* chore: update configs to anemoi-graphs=0.4.1

* feat: bump anemoi-graphs version requirement to >= 0.4.1

* fix: target_mask_attr_name inside edge builder

* fix: remove from default

* Update CHANGELOG.md

* Update pyproject.toml

* Update CHANGELOG.md

* Update CHANGELOG.md

* fix: lam plotting

* fix: cast to DotDict

* fix: add import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: anemoi-graphs 0.4.1 format

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
JPXKQX and pre-commit-ci[bot] authored Nov 28, 2024
1 parent 112d78f commit ac38f92
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 33 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ Keep it human-readable, your future self will thank you!

### Added
- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)


- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)
- Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159)

### Changed
## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ classifiers = [
dynamic = [ "version" ]

dependencies = [
"anemoi-datasets>=0.4",
"anemoi-graphs>=0.4",
"anemoi-datasets>=0.5.2",
"anemoi-graphs>=0.4.1",
"anemoi-models>=0.3",
"anemoi-utils[provenance]>=0.4.4",
"datashader>=0.16.3",
Expand Down
10 changes: 5 additions & 5 deletions src/anemoi/training/config/graph/encoder_decoder_only.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ edges:
# Encoder configuration
- source_name: ${graph.data}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
cutoff_factor: 0.6 # only for cutoff method
attributes: ${graph.attributes.edges}
- source_name: ${graph.hidden}
# Decoder configuration
- source_name: ${graph.hidden}
target_name: ${graph.data}
edge_builder:
_target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
num_nearest_neighbours: 3 # only for knn method
attributes: ${graph.attributes.edges}

Expand Down
14 changes: 7 additions & 7 deletions src/anemoi/training/config/graph/limited_area.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,23 @@ edges:
# Encoder configuration
- source_name: ${graph.data}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
cutoff_factor: 0.6 # only for cutoff method
attributes: ${graph.attributes.edges}
# Processor configuration
- source_name: ${graph.hidden}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.MultiScaleEdges
edge_builders:
- _target_: anemoi.graphs.edges.MultiScaleEdges
x_hops: 1
attributes: ${graph.attributes.edges}
# Decoder configuration
- source_name: ${graph.hidden}
target_name: ${graph.data}
target_mask_attr_name: cutout
edge_builder:
_target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
target_mask_attr_name: cutout
num_nearest_neighbours: 3 # only for knn method
attributes: ${graph.attributes.edges}

Expand Down
16 changes: 8 additions & 8 deletions src/anemoi/training/config/graph/multi_scale.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ edges:
# Encoder configuration
- source_name: ${graph.data}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
cutoff_factor: 0.6 # only for cutoff method
attributes: ${graph.attributes.edges}
- source_name: ${graph.hidden}
# Processor configuration
- source_name: ${graph.hidden}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.MultiScaleEdges
edge_builders:
- _target_: anemoi.graphs.edges.MultiScaleEdges
x_hops: 1
attributes: ${graph.attributes.edges}
- source_name: ${graph.hidden}
# Decoder configuration
- source_name: ${graph.hidden}
target_name: ${graph.data}
edge_builder:
_target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
num_nearest_neighbours: 3 # only for knn method
attributes: ${graph.attributes.edges}

Expand Down
12 changes: 6 additions & 6 deletions src/anemoi/training/config/graph/stretched_grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ edges:
# Encoder
- source_name: ${graph.data}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.KNNEdges
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges
num_nearest_neighbours: 12
attributes: ${graph.attributes.edges}
# Processor
- source_name: ${graph.hidden}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.MultiScaleEdges
edge_builders:
- _target_: anemoi.graphs.edges.MultiScaleEdges
x_hops: 1
attributes: ${graph.attributes.edges}
# Decoder
- source_name: ${graph.hidden}
target_name: ${graph.data}
edge_builder:
_target_: anemoi.graphs.edges.KNNEdges
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges
num_nearest_neighbours: 3
attributes: ${graph.attributes.edges}

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def _plot(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
)
output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
data = data.numpy()

Expand Down Expand Up @@ -999,7 +999,7 @@ def process(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
)
output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
data = data.numpy()
return data, output_tensor
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/training/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.utils.config import DotDict
from anemoi.utils.provenance import gather_provenance_info
from omegaconf import DictConfig
from omegaconf import OmegaConf
Expand Down Expand Up @@ -128,7 +129,8 @@ def graph_data(self) -> HeteroData:

from anemoi.graphs.create import GraphCreator

return GraphCreator(config=self.config.graph).create(
graph_config = DotDict(OmegaConf.to_container(self.config.graph, resolve=True))
return GraphCreator(config=graph_config).create(
save_path=graph_filename,
overwrite=self.config.graph.overwrite,
)
Expand Down

0 comments on commit ac38f92

Please sign in to comment.