Backward Compat with Torchmetrics (#2046)
* backwards compatibility

* fix test

* debug

* safe load

* add logs

* remove print

* add print

* flip check

* fix eval

* add lint
mvpatel2000 authored Mar 9, 2023
1 parent 25380a8 commit 8e04fd5
Showing 5 changed files with 51 additions and 15 deletions.
31 changes: 23 additions & 8 deletions composer/core/state.py
@@ -18,6 +18,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, Dataset
 from torchmetrics import Metric
+from torchmetrics.metric import jit_distributed_available
 
 from composer.core.data_spec import DataSpec
 from composer.core.event import Event
@@ -149,12 +150,26 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
     # v0.4.1 removed the leading underscores for the keys in the state_dict
     # It also renamed _is_model_ddp_wrapped to is_model_ddp
     state = {}
-    for k, v in state_dict.items():
-        if k == '_is_model_ddp_wrapped':
-            k = 'is_model_ddp'
-        if k.startswith('_'):
-            k = k[1:]
-        state[k] = v
+    for attribute_name, serialized_value in state_dict.items():
+        if attribute_name == '_is_model_ddp_wrapped':
+            attribute_name = 'is_model_ddp'
+        if attribute_name.startswith('_'):
+            attribute_name = attribute_name[1:]
+        # Torchmetrics adds a new attribute as of 0.11 which must be added to deserialized metrics
+        if attribute_name == 'train_metrics':
+            for metric_name in serialized_value.keys():
+                metric = serialized_value[metric_name]
+                if not hasattr(metric, 'distributed_available_fn'):
+                    metric.distributed_available_fn = jit_distributed_available
+                serialized_value[metric_name] = metric
+        elif attribute_name == 'eval_metrics':
+            for evaluator_name, eval_metrics in serialized_value.items():
+                for metric_name in eval_metrics.keys():
+                    metric = eval_metrics[metric_name]
+                    if not hasattr(metric, 'distributed_available_fn'):
+                        metric.distributed_available_fn = jit_distributed_available
+                    serialized_value[evaluator_name][metric_name] = metric
+        state[attribute_name] = serialized_value
     return state
 
 
@@ -1049,14 +1064,14 @@ def load_state_dict(
             elif attribute_name == 'train_metrics':
                 state_field_value = getattr(self, attribute_name)
                 for metric_name, metric in serialized_value.items():
-                    state_field_value[metric_name] = metric
                     metric._device = self.device._device
+                    state_field_value[metric_name] = metric
             elif attribute_name == 'eval_metrics':
                 state_field_value = getattr(self, attribute_name)
                 for eval_key, eval_metrics in serialized_value.items():
                     for metric_name, metric in eval_metrics.items():
-                        state_field_value[eval_key][metric_name] = metric
                         metric._device = self.device._device
+                        state_field_value[eval_key][metric_name] = metric
             elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
                 state_field_value = getattr(self, attribute_name)
                 for target in ensure_tuple(state_field_value):
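Note (not part of the commit): the torchmetrics shim added above reduces to the following standalone sketch, where `patch_metric` is a hypothetical helper name used only for illustration.

```python
from torchmetrics.metric import jit_distributed_available


def patch_metric(metric):
    # Metrics pickled with torchmetrics < 0.11 were serialized without
    # `distributed_available_fn`; newer torchmetrics expects the attribute,
    # so fall back to the library default when it is missing.
    if not hasattr(metric, 'distributed_available_fn'):
        metric.distributed_available_fn = jit_distributed_available
    return metric
```

The state-dict hook applies this fix to every metric under `train_metrics` and to every metric of every evaluator under `eval_metrics` before the values reach `load_state_dict`.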
6 changes: 3 additions & 3 deletions composer/models/huggingface.py
@@ -19,7 +19,7 @@
 
 from composer.metrics import InContextLearningMetric
 from composer.models.base import ComposerModel
-from composer.utils import MissingConditionalImportError, get_file, import_object
+from composer.utils import MissingConditionalImportError, get_file, import_object, safe_torch_load
 
 if TYPE_CHECKING:
     import transformers
@@ -226,7 +226,7 @@ def hf_from_composer_checkpoint(
         get_file(checkpoint_path, str(local_checkpoint_save_location))
 
         # load the state dict in
-        loaded_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu')
+        loaded_state_dict = safe_torch_load(local_checkpoint_save_location)
 
         hf_state = loaded_state_dict['state']['integrations']['huggingface']
         hf_model_state = hf_state['model']
@@ -512,7 +512,7 @@ def write_huggingface_pretrained_from_composer_checkpoint(
     # download the checkpoint file
     get_file(str(checkpoint_path), str(local_checkpoint_save_location))
 
-    composer_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu')
+    composer_state_dict = safe_torch_load(local_checkpoint_save_location)
 
     config = get_hf_config_from_composer_state_dict(composer_state_dict)
     config.save_pretrained(output_folder)
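Note (illustrative, not part of the commit): the export path touched above can be exercised roughly as follows. The checkpoint path and output directory are hypothetical, and the keyword names are inferred from the diff.

```python
from composer.models.huggingface import write_huggingface_pretrained_from_composer_checkpoint

# Export a Hugging Face config/model folder from a Composer checkpoint; the
# checkpoint is now read with safe_torch_load instead of a bare torch.load.
write_huggingface_pretrained_from_composer_checkpoint(
    checkpoint_path='ep5-ba1000-rank0.pt',  # hypothetical Composer checkpoint
    output_folder='./hf_export',            # hypothetical output directory
)
```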
3 changes: 2 additions & 1 deletion composer/utils/__init__.py
@@ -6,7 +6,7 @@
 from composer.utils.auto_log_hparams import (convert_flat_dict_to_nested_dict, convert_nested_dict_to_flat_dict,
                                              extract_hparams)
 from composer.utils.batch_helpers import batch_get, batch_set
-from composer.utils.checkpoint import PartialFilePath, load_checkpoint, save_checkpoint
+from composer.utils.checkpoint import PartialFilePath, load_checkpoint, safe_torch_load, save_checkpoint
 from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
                                         get_composer_env_dict, print_env)
 from composer.utils.device import get_device, is_tpu_installed
@@ -48,6 +48,7 @@
     'StringEnum',
     'load_checkpoint',
     'save_checkpoint',
+    'safe_torch_load',
     'ensure_folder_is_empty',
     'ensure_folder_has_no_conflicting_files',
     'export_for_inference',
22 changes: 21 additions & 1 deletion composer/utils/checkpoint.py
@@ -14,6 +14,7 @@
 import tempfile
 import textwrap
 import warnings
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -399,6 +400,25 @@ def filter_func(state_dict: Dict) -> None:
     return filter_func
 
 
+def safe_torch_load(composer_states_filepath: Union[Path, str], map_location: str = 'cpu'):
+    """Load a torch checkpoint, catching errors due to backwards compatibility issues.
+
+    Args:
+        composer_states_filepath: The path to the checkpoint file.
+        map_location: The location to load the checkpoint to.
+    """
+    try:
+        state_dict = torch.load(composer_states_filepath, map_location=map_location)
+        return state_dict
+    except TypeError as e:
+        if 'Accuracy.__new__() missing 1 required positional argument' in str(e):
+            raise Exception('As of v0.10.0, torchmetrics introduces a new required argument to Accuracy which '
+                            'breaks backwards compatibility. Unfortunately, this means that older checkpoints '
+                            'cannot be loaded with the metrics. In order to successfully load this model, please '
+                            'pass `load_ignore_keys = ["state/train_metrics/*", "state/eval_metrics/*"]`.') from e
+        raise e
+
+
 def _restore_checkpoint(
     state: State,
     logger: Logger,
@@ -413,7 +433,7 @@ def _restore_checkpoint(
 ) -> Optional[List[Dict[str, Any]]]:
     """Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
     # Now, all ranks load the checkpoint that local rank zero downloaded
-    state_dict = torch.load(composer_states_filepath, map_location='cpu')
+    state_dict = safe_torch_load(composer_states_filepath)
     if ignore_keys:
         # Filter provided list of key paths
         if not callable(ignore_keys):
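Note (illustrative, not part of the commit): a minimal sketch of using the new helper in place of a bare `torch.load`; the checkpoint filename is hypothetical.

```python
from composer.utils import safe_torch_load

# Behaves like torch.load(..., map_location='cpu') for healthy checkpoints, but
# raises an actionable error when serialized metrics predate the torchmetrics
# v0.10.0 Accuracy signature change.
state_dict = safe_torch_load('ep5-ba1000-rank0.pt')  # hypothetical checkpoint file
model_weights = state_dict['state']['model']

# If the old-Accuracy error is raised, its message suggests skipping the
# serialized metrics on load:
#   load_ignore_keys=['state/train_metrics/*', 'state/eval_metrics/*']
```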
4 changes: 2 additions & 2 deletions composer/utils/inference.py
@@ -19,7 +19,7 @@
 import torch.nn as nn
 
 from composer.utils import dist
-from composer.utils.checkpoint import download_checkpoint
+from composer.utils.checkpoint import download_checkpoint, safe_torch_load
 from composer.utils.device import get_device
 from composer.utils.iter_helpers import ensure_tuple
 from composer.utils.misc import is_model_ddp, is_model_deepspeed, model_eval_mode
@@ -172,7 +172,7 @@ def export_for_inference(
                                                              node_checkpoint_folder=tempdir,
                                                              object_store=load_object_store,
                                                              progress_bar=True)
-            state_dict = torch.load(composer_states_filepath, map_location='cpu')
+            state_dict = safe_torch_load(composer_states_filepath)
             missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict)
             if len(missing_keys) > 0:
                 log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
