Feature/transfer-learning (#166)
* Introduced a resume flag and checkpoint loading for transfer learning, removed metadata saving in checkpoints due to a corruption error on big models, and fixed logging to work in the transfer-learning setting

* Added dataset length computed dynamically

* debugging validation

* Small changes

* Removed prints

* Not working

* small changes

* Imputer changes

* Added checkpoint sanitisation, effective batch size, git pre-commit

* gpc

* gpc

* New implementation: do not store modified checkpoint, load it directly after changing it

* Added logging

* Transfer learning working: implemented checkpoint cleaning with large models

* Reverted some changes concerning imputer issues

* Reverted some changes concerning imputer issues

* Cleaned code for final review

* Changed changelog and assigned TODO correctly

* Changed changelog and assigned TODO correctly

* Addressed review: copy checkpoint before removing metadata file

* gpc passed

* Removed logger in debugging mode

* Removed dataset length due to checkpointing issues

* Reintroduced correct config on graphtransformer

* gpc passed

* Removed patch for issue #57; code expects a patched checkpoint already

* Removed new path name for patched checkpoint (fully ignoring issue #57) + removed fix for missing config

* Adapted changelog

* Switched logging to info from debug
icedoom888 authored Dec 6, 2024
1 parent 5c4ac3f commit 891405e
Showing 7 changed files with 78 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,10 @@ Keep it human-readable, your future self will thank you!
- Don't crash when using the profiler if certain env vars aren't set [#180](https://github.com/ecmwf/anemoi-training/pull/180)

### Added
- Introduced the config variable `transfer_learning` (bool): set to `True` to load a checkpoint in a transfer-learning setting.
- **TRANSFER LEARNING**: new functionality enabled. You can now load checkpoints from different models and from different training runs.
- Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`.
  Used for experiment reproducibility across different computing configurations (a worked example follows this diff).
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)

### Changed
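For context on the effective-batch-size entry above, here is a worked example; the numbers are illustrative assumptions, not values taken from this commit:

```python
# Illustrative values only: 2 nodes with 4 GPUs each, a configured training
# batch size of 4, and each model instance sharded across 2 GPUs.
batch_size_training = 4
num_gpus_per_node = 4
num_nodes = 2
num_gpus_per_model = 2

effective_bs = (batch_size_training * num_gpus_per_node * num_nodes) // num_gpus_per_model
print(effective_bs)  # 16
```

With the same values, the result stays at 16 whether the run uses 2 nodes of 4 GPUs or 4 nodes of 2 GPUs, which is what makes experiments reproducible across computing configurations.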
1 change: 1 addition & 0 deletions src/anemoi/training/config/training/default.yaml
@@ -2,6 +2,7 @@
run_id: null
fork_run_id: null
load_weights_only: null # only load model weights, do not restore optimiser states etc.
transfer_learning: null # activate to perform transfer learning

# run in deterministic mode ; slows down
deterministic: False
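A minimal sketch of how the new flag combines with the existing options in this file, based only on the branch added to `train.py` later in this commit; the `OmegaConf` construction and the placeholder run id are illustrative assumptions, not the project's documented entry point:

```python
from omegaconf import OmegaConf

# Key names mirror config/training/default.yaml in this commit; values are placeholders.
overrides = OmegaConf.create(
    {
        "training": {
            "fork_run_id": "<parent-run-id>",  # placeholder: run whose checkpoint is warm-started from
            "load_weights_only": True,  # required so the weight-loading branch is taken
            "transfer_learning": True,  # routes loading through transfer_learning_loading()
        },
    },
)
```

Judging from `train.py`, `transfer_learning` only has an effect when `load_weights_only` is also set, since the transfer-learning check is nested inside that branch.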
11 changes: 11 additions & 0 deletions src/anemoi/training/data/datamodule.py
@@ -147,14 +147,25 @@ def _get_dataset(
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:

r = max(rollout, self.rollout)

# Compute effective batch size
effective_bs = (
self.config.dataloader.batch_size["training"]
* self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
)

return NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
shuffle=shuffle,
label=label,
effective_bs=effective_bs,
)

def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
5 changes: 4 additions & 1 deletion src/anemoi/training/data/dataset.py
@@ -38,6 +38,7 @@ def __init__(
timeincrement: int = 1,
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1,
) -> None:
"""Initialize (part of) the dataset state.
@@ -55,9 +56,11 @@ def __init__(
Shuffle batches, by default True
label : str, optional
label for the dataset, by default "generic"
effective_bs : int, default 1
effective batch size, used to compute the length of the dataset
"""
self.label = label
self.effective_bs = effective_bs

self.data = data_reader

4 changes: 4 additions & 0 deletions src/anemoi/training/train/forecaster.py
@@ -603,8 +603,10 @@ def on_train_epoch_end(self) -> None:
self.rollout = min(self.rollout, self.rollout_max)

def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:

with torch.no_grad():
val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True)

self.log(
f"val_{getattr(self.loss, 'name', self.loss.__class__.__name__.lower())}",
val_loss,
@@ -615,6 +617,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
batch_size=batch.shape[0],
sync_dist=True,
)

for mname, mvalue in metrics.items():
self.log(
"val_" + mname,
@@ -626,6 +629,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
batch_size=batch.shape[0],
sync_dist=True,
)

return val_loss, y_preds

def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]:
32 changes: 27 additions & 5 deletions src/anemoi/training/train/train.py
@@ -34,6 +34,7 @@
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.checkpoint import transfer_learning_loading
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed

@@ -62,9 +63,8 @@ def __init__(self, config: DictConfig) -> None:
OmegaConf.resolve(config)
self.config = config

# Default to not warm-starting from a checkpoint
self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)
self.load_weights_only = config.training.load_weights_only
self.load_weights_only = self.config.training.load_weights_only
self.parent_uuid = None

self.config.training.run_id = self.run_id
@@ -83,6 +83,8 @@ def datamodule(self) -> AnemoiDatasetsDataModule:
"""DataModule instance and DataSets."""
datamodule = AnemoiDatasetsDataModule(self.config)
self.config.data.num_features = len(datamodule.ds_train.data.variables)
LOGGER.info("Number of data variables: %s", str(len(datamodule.ds_train.data.variables)))
LOGGER.debug("Variables: %s", str(datamodule.ds_train.data.variables))
return datamodule

@cached_property
@@ -145,10 +147,21 @@ def model(self) -> GraphForecaster:
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
}

model = GraphForecaster(**kwargs)

if self.load_weights_only:
# Sanitise the checkpoint for transfer learning
if self.config.training.transfer_learning:
LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
return transfer_learning_loading(model, self.last_checkpoint)

LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
return GraphForecaster(**kwargs)

return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)

LOGGER.info("Model initialised from scratch.")
return model

@rank_zero_only
def _get_mlflow_run_id(self) -> str:
@@ -200,6 +213,7 @@ def last_checkpoint(self) -> str | None:
fork_id or self.lineage_run,
self.config.hardware.files.warm_start or "last.ckpt",
)

# Check if the last checkpoint exists
if Path(checkpoint).exists():
LOGGER.info("Resuming training from last checkpoint: %s", checkpoint)
@@ -296,11 +310,15 @@ def _log_information(self) -> None:
* self.config.hardware.num_gpus_per_node
/ self.config.hardware.num_gpus_per_model
)

LOGGER.debug(
"Total GPU count / model group size: %d - NB: the learning rate will be scaled by this factor!",
total_number_of_model_instances,
)
LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate)
LOGGER.debug(
"Effective learning rate: %.3e",
int(total_number_of_model_instances) * self.config.training.lr.rate,
)
LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start)

if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1):
@@ -352,6 +370,8 @@ def strategy(self) -> DDPGroupStrategy:

def train(self) -> None:
"""Training entry point."""
LOGGER.debug("Setting up trainer..")

trainer = pl.Trainer(
accelerator=self.accelerator,
callbacks=self.callbacks,
@@ -378,6 +398,8 @@
enable_progress_bar=self.config.diagnostics.enable_progress_bar,
)

LOGGER.debug("Starting training..")

trainer.fit(
self.model,
datamodule=self.datamodule,
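One more worked arithmetic example, tied to the `_log_information` hunk above; the numbers are illustrative assumptions rather than values from the commit:

```python
# Illustrative values: 2 nodes with 4 GPUs each, each model sharded across
# 2 GPUs, and a base learning rate of 1e-4 (config.training.lr.rate).
num_nodes, num_gpus_per_node, num_gpus_per_model = 2, 4, 2
base_lr = 1e-4

total_number_of_model_instances = num_nodes * num_gpus_per_node / num_gpus_per_model  # 4.0
effective_lr = int(total_number_of_model_instances) * base_lr  # 4e-4, as logged above
```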
28 changes: 27 additions & 1 deletion src/anemoi/training/utils/checkpoint.py
@@ -7,16 +7,19 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from __future__ import annotations

import logging
from pathlib import Path

import torch
import torch.nn as nn
from anemoi.utils.checkpoints import save_metadata

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)


def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]:
"""Load the lightning checkpoint and extract the pytorch model and its metadata.
@@ -65,3 +68,26 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path:
torch.save(model, inference_filepath)
save_metadata(inference_filepath, metadata)
return inference_filepath


def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> nn.Module:
    """Load a checkpoint into the model, skipping parameters whose shapes do not match."""
# Load the checkpoint
checkpoint = torch.load(ckpt_path, map_location=model.device)

# Filter out layers with size mismatch
state_dict = checkpoint["state_dict"]

model_state_dict = model.state_dict()

for key in state_dict.copy():
if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
LOGGER.info("Skipping loading parameter: %s", key)
LOGGER.info("Checkpoint shape: %s", str(state_dict[key].shape))
LOGGER.info("Model shape: %s", str(model_state_dict[key].shape))

del state_dict[key] # Remove the mismatched key

# Load the filtered state_dict into the model
model.load_state_dict(state_dict, strict=False)
return model
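To make the filtering idea above concrete, here is a self-contained toy sketch of the same technique; the `Tiny` module and all sizes are invented for illustration and are not anemoi code:

```python
import torch.nn as nn

class Tiny(nn.Module):
    """Toy model: one layer whose shape matches across variants, one that does not."""

    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.shared = nn.Linear(8, 8)            # same shape in both models -> loaded
        self.head = nn.Linear(8, out_features)   # shape differs -> skipped

pretrained = Tiny(out_features=4)   # stands in for the source checkpoint
target = Tiny(out_features=6)       # stands in for the new model

state_dict = pretrained.state_dict()
model_state_dict = target.state_dict()
for key in list(state_dict):
    if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
        del state_dict[key]          # drops head.weight and head.bias here

# shared.* is restored from the "checkpoint"; head.* keeps its fresh initialisation.
target.load_state_dict(state_dict, strict=False)
```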
