Backward Compat with Torchmetrics #2046

Merged · 10 commits · Mar 9, 2023
31 changes: 23 additions & 8 deletions composer/core/state.py
@@ -18,6 +18,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
+from torchmetrics.metric import jit_distributed_available

from composer.core.data_spec import DataSpec
from composer.core.event import Event
@@ -149,12 +150,26 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
    # v0.4.1 removed the leading underscores for the keys in the state_dict
    # It also renamed _is_model_ddp_wrapped to is_model_ddp
    state = {}
-    for k, v in state_dict.items():
-        if k == '_is_model_ddp_wrapped':
-            k = 'is_model_ddp'
-        if k.startswith('_'):
-            k = k[1:]
-        state[k] = v
+    for attribute_name, serialized_value in state_dict.items():
+        if attribute_name == '_is_model_ddp_wrapped':
+            attribute_name = 'is_model_ddp'
+        if attribute_name.startswith('_'):
+            attribute_name = attribute_name[1:]
+        # As of torchmetrics 0.11, metrics carry a new ``distributed_available_fn`` attribute,
+        # which must be backfilled on metrics deserialized from older checkpoints
+        if attribute_name == 'train_metrics':
+            for metric_name in serialized_value.keys():
+                metric = serialized_value[metric_name]
+                if not hasattr(metric, 'distributed_available_fn'):
+                    metric.distributed_available_fn = jit_distributed_available
+                serialized_value[metric_name] = metric
+        elif attribute_name == 'eval_metrics':
+            for evaluator_name, eval_metrics in serialized_value.items():
+                for metric_name in eval_metrics.keys():
+                    metric = eval_metrics[metric_name]
+                    if not hasattr(metric, 'distributed_available_fn'):
+                        metric.distributed_available_fn = jit_distributed_available
+                    serialized_value[evaluator_name][metric_name] = metric
+        state[attribute_name] = serialized_value
    return state
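
For reference, the backfill this hunk performs can be reproduced standalone. A minimal sketch, assuming torchmetrics >= 0.11 is installed and using BinaryAccuracy purely as an illustration:

# Sketch: backfill the attribute newer torchmetrics versions expect on a
# metric deserialized from an older checkpoint. BinaryAccuracy is illustrative.
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.metric import jit_distributed_available

metric = BinaryAccuracy()
if not hasattr(metric, 'distributed_available_fn'):
    metric.distributed_available_fn = jit_distributed_available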


@@ -1049,14 +1064,14 @@ def load_state_dict(
            elif attribute_name == 'train_metrics':
                state_field_value = getattr(self, attribute_name)
                for metric_name, metric in serialized_value.items():
-                    state_field_value[metric_name] = metric
                    metric._device = self.device._device
+                    state_field_value[metric_name] = metric
            elif attribute_name == 'eval_metrics':
                state_field_value = getattr(self, attribute_name)
                for eval_key, eval_metrics in serialized_value.items():
                    for metric_name, metric in eval_metrics.items():
-                        state_field_value[eval_key][metric_name] = metric
                        metric._device = self.device._device
+                        state_field_value[eval_key][metric_name] = metric
            elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
                state_field_value = getattr(self, attribute_name)
                for target in ensure_tuple(state_field_value):
6 changes: 3 additions & 3 deletions composer/models/huggingface.py
@@ -19,7 +19,7 @@

from composer.metrics import InContextLearningMetric
from composer.models.base import ComposerModel
-from composer.utils import MissingConditionalImportError, get_file, import_object
+from composer.utils import MissingConditionalImportError, get_file, import_object, safe_torch_load

if TYPE_CHECKING:
    import transformers
@@ -226,7 +226,7 @@ def hf_from_composer_checkpoint(
        get_file(checkpoint_path, str(local_checkpoint_save_location))

        # load the state dict in
-        loaded_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu')
+        loaded_state_dict = safe_torch_load(local_checkpoint_save_location)

        hf_state = loaded_state_dict['state']['integrations']['huggingface']
        hf_model_state = hf_state['model']
@@ -512,7 +512,7 @@ def write_huggingface_pretrained_from_composer_checkpoint(
    # download the checkpoint file
    get_file(str(checkpoint_path), str(local_checkpoint_save_location))

-    composer_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu')
+    composer_state_dict = safe_torch_load(local_checkpoint_save_location)

    config = get_hf_config_from_composer_state_dict(composer_state_dict)
    config.save_pretrained(output_folder)
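For reference, a sketch of calling the exporter touched above. The checkpoint_path and output_folder arguments appear in the hunk; the paths themselves are placeholders:

# Placeholder paths; assumes a local Composer checkpoint.
from composer.models.huggingface import write_huggingface_pretrained_from_composer_checkpoint

write_huggingface_pretrained_from_composer_checkpoint(
    checkpoint_path='ep1-rank0.pt',
    output_folder='./hf_export',
)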
3 changes: 2 additions & 1 deletion composer/utils/__init__.py
@@ -6,7 +6,7 @@
from composer.utils.auto_log_hparams import (convert_flat_dict_to_nested_dict, convert_nested_dict_to_flat_dict,
                                             extract_hparams)
from composer.utils.batch_helpers import batch_get, batch_set
-from composer.utils.checkpoint import PartialFilePath, load_checkpoint, save_checkpoint
+from composer.utils.checkpoint import PartialFilePath, load_checkpoint, safe_torch_load, save_checkpoint
from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
                                        get_composer_env_dict, print_env)
from composer.utils.device import get_device, is_tpu_installed
@@ -48,6 +48,7 @@
    'StringEnum',
    'load_checkpoint',
    'save_checkpoint',
+    'safe_torch_load',
    'ensure_folder_is_empty',
    'ensure_folder_has_no_conflicting_files',
    'export_for_inference',
22 changes: 21 additions & 1 deletion composer/utils/checkpoint.py
@@ -14,6 +14,7 @@
import tempfile
import textwrap
import warnings
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -399,6 +400,25 @@ def filter_func(state_dict: Dict) -> None:
    return filter_func


+def safe_torch_load(composer_states_filepath: Union[Path, str], map_location: str = 'cpu'):
+    """Load a torch checkpoint, catching errors due to backwards compatibility issues.
+
+    Args:
+        composer_states_filepath (Union[Path, str]): The path to the checkpoint file.
+        map_location (str): The device to map the loaded tensors to. Default: ``'cpu'``.
+    """
+    try:
+        state_dict = torch.load(composer_states_filepath, map_location=map_location)
+        return state_dict
+    except TypeError as e:
+        if 'Accuracy.__new__() missing 1 required positional argument' in str(e):
+            raise Exception('As of v0.10.0, torchmetrics introduces a new required argument to Accuracy which '
+                            'breaks backwards compatibility. Unfortunately, this means that older checkpoints '
+                            'cannot be loaded with the metrics. In order to successfully load this model, please '
+                            'pass `load_ignore_keys = ["state/train_metrics/*", "state/eval_metrics/*"]`.') from e
+        raise e
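
A hypothetical usage sketch of the new helper (the checkpoint path is a placeholder): on checkpoints that predate the torchmetrics Accuracy change, the raised error spells out the `load_ignore_keys` workaround.

# Placeholder path; assumes a Composer checkpoint saved to local disk.
from composer.utils import safe_torch_load

state_dict = safe_torch_load('ep1-rank0.pt')
model_state = state_dict['state']['model']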


def _restore_checkpoint(
    state: State,
    logger: Logger,
@@ -413,7 +433,7 @@
) -> Optional[List[Dict[str, Any]]]:
"""Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = torch.load(composer_states_filepath, map_location='cpu')
state_dict = safe_torch_load(composer_states_filepath)
if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
4 changes: 2 additions & 2 deletions composer/utils/inference.py
@@ -19,7 +19,7 @@
import torch.nn as nn

from composer.utils import dist
-from composer.utils.checkpoint import download_checkpoint
+from composer.utils.checkpoint import download_checkpoint, safe_torch_load
from composer.utils.device import get_device
from composer.utils.iter_helpers import ensure_tuple
from composer.utils.misc import is_model_ddp, is_model_deepspeed, model_eval_mode
@@ -172,7 +172,7 @@ def export_for_inference(
node_checkpoint_folder=tempdir,
object_store=load_object_store,
progress_bar=True)
-state_dict = torch.load(composer_states_filepath, map_location='cpu')
+state_dict = safe_torch_load(composer_states_filepath)
missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict)
if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
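The hunk above ends with a pattern worth isolating: load weights non-strictly, then report whatever did not match. A minimal standalone sketch (the model and checkpoint contents are stand-ins, not Composer's API):

# Non-strict loading with key reporting; the model and state are stand-ins.
import logging

import torch
import torch.nn as nn

log = logging.getLogger(__name__)

model = nn.Linear(4, 2)
state = {'weight': torch.zeros(2, 4)}  # 'bias' is deliberately missing
missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
if len(missing_keys) > 0:
    log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if len(unexpected_keys) > 0:
    log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")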