This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Merge branch 'develop' into 150-n320-callback-number-of-pixels-used-by-datashader
anaprietonem committed Nov 25, 2024
2 parents 0c1f138 + 0608f21 commit da1d79b
Showing 16 changed files with 522 additions and 136 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -13,7 +13,13 @@ Keep it human-readable, your future self will thank you!
- Update `n_pixel` used by datashader to better adapt across resolutions #152


Fixed bug in power spectra plotting for the n320 resolution.

### Added
- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)


- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)

### Changed
## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14
@@ -49,6 +55,7 @@ Keep it human-readable, your future self will thank you!
- Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65)
- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)

### Changed

@@ -114,6 +121,7 @@ Keep it human-readable, your future self will thank you!
- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)
- Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86)
- Long Rollout Plots
- Mask NaN values in training loss function [#72](https://github.com/ecmwf/anemoi-training/pull/72) and [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)

### Fixed

4 changes: 4 additions & 0 deletions docs/user-guide/distributed.rst
@@ -45,6 +45,10 @@ number of GPUs you wish to shard the model across. It is recommended to
only shard if the model does not fit in GPU memory, as data distribution
is a much more efficient way to parallelise the training.

When using model sharding, ``config.dataloader.read_group_size`` allows
for sharded data loading in subgroups. This should be set to the number
of GPUs per model for optimal performance.
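
The relationship between the GPUs in one model group and their reader subgroups can be sketched as follows. This is only an illustration: ``reader_group_layout`` is a hypothetical helper, not part of anemoi-training, and it assumes that consecutive ranks within a model communication group form one reader group.

```python
def reader_group_layout(num_gpus_per_model: int, read_group_size: int) -> list[tuple[int, int]]:
    """Return (reader_group_id, reader_group_rank) for every rank of one model group.

    Hypothetical helper for illustration only; not part of anemoi-training.
    """
    if read_group_size < 1 or num_gpus_per_model % read_group_size != 0:
        msg = (
            f"GPUs per model ({num_gpus_per_model}) must be divisible "
            f"by read_group_size ({read_group_size})"
        )
        raise ValueError(msg)
    # Assumption: consecutive ranks of a model group share a reader group.
    return [
        (rank // read_group_size, rank % read_group_size)
        for rank in range(num_gpus_per_model)
    ]


layout = reader_group_layout(num_gpus_per_model=4, read_group_size=2)
print(layout)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

With ``read_group_size`` equal to the number of GPUs per model, every rank lands in a single reader group and each one reads only its own fraction of the data.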

*********
Example
*********
15 changes: 10 additions & 5 deletions docs/user-guide/training.rst
@@ -188,10 +188,11 @@ level has a weighting less than 0.2).
***************

Anemoi training uses the ``CosineLRScheduler`` from timm (PyTorch Image Models) as its
learning rate scheduler. The user can configure the maximum learning
rate by setting ``config.training.lr.rate``. Note that this learning
rate is scaled by the number of GPUs used for `data parallelism
<distributed>`_.
learning rate scheduler. Documentation for this scheduler can be found at
https://github.com/huggingface/pytorch-image-models/blob/main/timm/scheduler/cosine_lr.py.
The user can configure the maximum learning rate by setting
``config.training.lr.rate``. Note that this learning rate is scaled by
the number of GPUs used for `data parallelism <distributed>`_.

.. code:: yaml
@@ -201,7 +202,11 @@ The user can also control the rate at which the learning rate decreases
by setting the total number of iterations through
``config.training.lr.iterations`` and the minimum learning rate reached
through ``config.training.lr.min``. Note that the minimum learning rate
is not scaled by the number of GPUs.
is not scaled by the number of GPUs. The user can also control the
warmup period by setting ``config.training.lr.warmup_t``. If the warmup
period is set to 0, the learning rate will start at the maximum learning
rate. If no warmup period is defined, a default warmup period of 1000
iterations is used.
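
The combined effect of the warmup period and the cosine decay can be sketched in plain Python. This is a simplified illustration, not the scheduler's implementation: ``lr_at_step`` is a hypothetical function, and it assumes that the warmup ramps linearly from the minimum learning rate up to the maximum.

```python
import math


def lr_at_step(step: int, lr_max: float, lr_min: float, iterations: int, warmup_t: int) -> float:
    """Simplified learning rate at a given step (hypothetical sketch)."""
    if warmup_t > 0 and step < warmup_t:
        # Assumption: linear ramp from lr_min towards lr_max during warmup.
        return lr_min + (lr_max - lr_min) * step / warmup_t
    # Cosine decay from lr_max down to lr_min over `iterations` steps.
    progress = min(step, iterations) / iterations
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# With warmup_t=0 the schedule starts directly at the maximum learning rate.
print(lr_at_step(0, lr_max=1e-3, lr_min=3e-7, iterations=10_000, warmup_t=0))
# With a warmup of 1000 iterations it starts at the assumed initial value instead.
print(lr_at_step(0, lr_max=1e-3, lr_min=3e-7, iterations=10_000, warmup_t=1000))
```

In this sketch the cosine phase is evaluated on the absolute step count, so the value just after warmup sits slightly below the maximum.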

*********
Rollout
11 changes: 11 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
@@ -1,6 +1,17 @@
prefetch_factor: 2
pin_memory: True

# ============
# read_group_size:
# Form subgroups of model comm groups that read data together.
# Each reader in the group only reads 1/read_group_size of the data
# which is then all-gathered between the group.
# This can reduce CPU memory usage as well as increase dataloader throughput.
# The number of GPUs per model must be divisible by read_group_size.
# To disable, set to 1.
# ============
read_group_size: ${hardware.num_gpus_per_model}

num_workers:
training: 8
validation: 8
5 changes: 4 additions & 1 deletion src/anemoi/training/config/training/default.yaml
@@ -48,7 +48,9 @@ training_loss:
# Scalars to include in loss calculation
# Available scalars include:
# - 'variable': See `variable_loss_scaling` for more information
scalars: ['variable']
# - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function
scalars: ['variable', 'loss_weights_mask']

ignore_nans: False

loss_gradient_scaling: False
@@ -81,6 +83,7 @@ lr:
rate: 0.625e-4 #local_lr
iterations: ${training.max_steps} # NOTE: When max_epochs < max_steps, scheduler will run for max_steps
min: 3e-7 #Not scaled by #GPU
warmup_t: 1000

# Changes in per-gpu batch_size should come with a rescaling of the local_lr
# in order to keep a constant global_lr
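
The rescaling mentioned in the comment above can be sketched as follows. This assumes that the global learning rate is the local rate multiplied by the number of data-parallel model groups; ``global_lr`` is a hypothetical helper and its parameter names simply mirror the ``hardware`` config keys used elsewhere in this commit.

```python
def global_lr(local_lr: float, num_nodes: int, num_gpus_per_node: int, num_gpus_per_model: int) -> float:
    """Scale the per-GPU (local) learning rate to a global one (hypothetical sketch)."""
    total_gpus = num_nodes * num_gpus_per_node
    if total_gpus % num_gpus_per_model != 0:
        msg = f"GPUs per model {num_gpus_per_model} does not divide total GPUs {total_gpus}"
        raise ValueError(msg)
    # Assumption: one data-parallel replica per model communication group.
    num_model_groups = total_gpus // num_gpus_per_model
    return local_lr * num_model_groups


# Two nodes with four GPUs each, model sharded over two GPUs: four replicas.
print(global_lr(0.625e-4, num_nodes=2, num_gpus_per_node=4, num_gpus_per_model=2))
```

Under this assumption, halving the number of replicas would require doubling the local rate to keep the global rate constant.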
39 changes: 6 additions & 33 deletions src/anemoi/training/data/datamodule.py
@@ -9,7 +9,6 @@


import logging
import os
from functools import cached_property
from typing import Callable

@@ -43,31 +42,6 @@ def __init__(self, config: DictConfig) -> None:

self.config = config

self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank
self.model_comm_group_id = (
self.global_rank // self.config.hardware.num_gpus_per_model
) # id of the model communication group the rank is participating in
self.model_comm_group_rank = (
self.global_rank % self.config.hardware.num_gpus_per_model
) # rank within one model communication group
total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
assert (
total_gpus
) % self.config.hardware.num_gpus_per_model == 0, (
f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
)
self.model_comm_num_groups = (
self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
) # number of model communication groups
LOGGER.debug(
"Rank %d model communication group number %d, with local model communication group rank %d",
self.global_rank,
self.model_comm_group_id,
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
@@ -142,10 +116,12 @@ def ds_train(self) -> NativeGridDataset:
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
f"validation start date {self.config.dataloader.validation.start}"
)
if not self.config.dataloader.training.end < self.config.dataloader.validation.start:
LOGGER.warning(
"Training end date %s is not before validation start date %s.",
self.config.dataloader.training.end,
self.config.dataloader.validation.start,
)
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
shuffle=False,
@@ -182,9 +158,6 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
shuffle=shuffle,
label=label,
)
82 changes: 67 additions & 15 deletions src/anemoi/training/data/dataset.py
@@ -36,9 +36,6 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
) -> None:
@@ -54,12 +51,6 @@ def __init__(
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
device group ID, default 0
model_comm_num_groups : int, optional
total number of device groups, by default 1
shuffle : bool, optional
Shuffle batches, by default True
label : str, optional
@@ -77,11 +68,14 @@ def __init__(
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# DDP-relevant info
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.model_comm_group_id = model_comm_group_id
self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
# lazy init model and reader group info, will be set by the DDPGroupStrategy:
self.model_comm_group_rank = 0
self.model_comm_num_groups = 1
self.model_comm_group_id = 0
self.global_rank = 0

self.reader_group_rank = 0
self.reader_group_size = 1

# additional state vars (lazy init)
self.n_samples_per_worker = 0
@@ -93,6 +87,8 @@ def __init__(
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_dim: int = -1
self.grid_size = self.data.shape[self.grid_dim]

@cached_property
def statistics(self) -> dict:
@@ -128,6 +124,58 @@ def valid_date_indices(self) -> np.ndarray:
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)

def set_comm_group_info(
self,
global_rank: int,
model_comm_group_id: int,
model_comm_group_rank: int,
model_comm_num_groups: int,
reader_group_rank: int,
reader_group_size: int,
) -> None:
"""Set model and reader communication group information (called by DDPGroupStrategy).
Parameters
----------
global_rank : int
Global rank
model_comm_group_id : int
Model communication group ID
model_comm_group_rank : int
Model communication group rank
model_comm_num_groups : int
Number of model communication groups
reader_group_rank : int
Reader group rank
reader_group_size : int
Reader group size
"""
self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

if self.reader_group_size > 1:
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
self.grid_end = self.grid_size
else:
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size

LOGGER.debug(
"NativeGridDataset.set_comm_group_info(): global_rank %d, model_comm_group_id %d, "
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
)

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.
@@ -233,7 +281,11 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
if self.reader_group_size > 1: # read only a subset of the grid
x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
else: # read the full grid
x = self.data[start : end : self.timeincrement, :, :, :]

x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
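
The grid sharding used by the reader groups in this file can be reproduced in isolation. The sketch below follows the same index arithmetic as ``set_comm_group_info`` (integer division, with the last reader taking the remainder); ``grid_shard_bounds`` is a hypothetical standalone function and a plain list stands in for the grid dimension of the dataset.

```python
def grid_shard_bounds(grid_size: int, reader_group_size: int, reader_group_rank: int) -> tuple[int, int]:
    """Return the [start, end) grid indices read by one reader (hypothetical sketch)."""
    grid_shard_size = grid_size // reader_group_size
    grid_start = reader_group_rank * grid_shard_size
    if reader_group_rank == reader_group_size - 1:
        # The last reader also picks up any remainder of the grid.
        grid_end = grid_size
    else:
        grid_end = (reader_group_rank + 1) * grid_shard_size
    return grid_start, grid_end


grid = list(range(10))  # stand-in for the grid dimension
shards = [grid[slice(*grid_shard_bounds(len(grid), 3, rank))] for rank in range(3)]
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]

# Concatenating the shards in rank order recovers the full grid.
assert sum(shards, []) == grid
```

Because the shards are contiguous and ordered by reader rank, gathering them in rank order restores the full grid.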

4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/evaluation.py
@@ -15,7 +15,6 @@

import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

if TYPE_CHECKING:
import pytorch_lightning as pl
@@ -103,7 +102,6 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict,
rank_zero_only=True,
)

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
@@ -114,6 +112,8 @@ def on_validation_batch_end(
) -> None:
del outputs # outputs are not used
if batch_idx % self.every_n_batches == 0:
batch = pl_module.allgather_batch(batch)

precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
16 changes: 13 additions & 3 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -241,7 +241,6 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
super().__init__(config)
self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
@@ -251,7 +250,16 @@ def on_validation_batch_end(
batch_idx: int,
**kwargs,
) -> None:
if (
self.config.diagnostics.plot.asynchronous
and self.config.dataloader.read_group_size > 1
and pl_module.local_rank == 0
):
LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.")

if batch_idx % self.every_n_batches == 0:
batch = pl_module.allgather_batch(batch)

self.plot(
trainer,
pl_module,
@@ -383,7 +391,6 @@ def __init__(
every_n_epochs,
)

@rank_zero_only
def _plot(
self,
trainer: pl.Trainer,
@@ -480,6 +487,7 @@ def _plot(

LOGGER.info("Time taken to plot/animate samples for longer rollout: %d seconds", int(time.time() - start_time))

@rank_zero_only
def _plot_rollout_step(
self,
pl_module: pl.LightningModule,
@@ -539,6 +547,7 @@ def _store_video_frame_data(
vmax[:] = np.maximum(vmax, np.nanmax(data_over_time[-1], axis=1))
return data_over_time, vmin, vmax

@rank_zero_only
def _generate_video_rollout(
self,
data_0: np.ndarray,
@@ -595,7 +604,6 @@ def _generate_video_rollout(
tag=f"gnn_pred_val_animation_{variable_name}_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
)

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
Expand All @@ -605,6 +613,8 @@ def on_validation_batch_end(
batch_idx: int,
) -> None:
if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0:
batch = pl_module.allgather_batch(batch)

precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,