Skip to content

Commit

Permalink
add FileSystemReaderWithValidation
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Dec 13, 2023
1 parent 7e7b13f commit 0e042c0
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch
from packaging import version
from torch.distributed.checkpoint.metadata import Metadata

from composer.utils import dist, reproducibility
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, format_name_with_dist,
Expand Down Expand Up @@ -67,6 +68,9 @@ def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Pat
Args:
checkpoint_filepath (Union[Path,str]): The path to the checkpoint file.
Raises:
ValueError if checkpoint file is invalid.
"""
fn_name = os.environ.get('CHECKPOINT_VALIDATION_FUNCTION', None)

Expand Down Expand Up @@ -381,7 +385,6 @@ def load_sharded_checkpoint(
_validate_load_planner(load_planner)

from torch.distributed import checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner

Expand All @@ -395,8 +398,35 @@ def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = set(rng_inds)
return len(rng_inds)

class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
    """FileSystemReader that validates checkpoint files prior to reading.

    Each file is passed to ``_ensure_valid_checkpoint`` before the parent
    reader is asked to read it.
    """

    def __init__(self, path: str):
        """Creates the reader.

        Args:
            path (str): Directory that contains the sharded checkpoint files.
        """
        super().__init__(path)

    def read_data(self, plan: LoadPlan, planner: LoadPlanner):
        """Reads data file.

        Every data file referenced by ``plan`` is validated before the read is
        delegated to the parent class.

        Args:
            plan (LoadPlan): The load plan whose items are about to be read.
            planner (LoadPlanner): The planner forwarded to the parent reader.

        Returns:
            Whatever the parent ``read_data`` returns.

        Raises:
            ValueError if the data file is invalid.
        """
        # Several read items may resolve to the same data file; remember which
        # paths were already checked so each file is validated only once.
        validated_checkpoint_paths = set()
        for read_item in plan.items:
            data_path = self.path / self.storage_data[read_item.storage_index].relative_path
            if data_path in validated_checkpoint_paths:
                continue
            _ensure_valid_checkpoint(data_path)
            validated_checkpoint_paths.add(data_path)
        return super().read_data(plan, planner)

    def read_metadata(self) -> Metadata:
        """Reads metadata file.

        The ``.metadata`` file under the reader's path is validated before the
        read is delegated to the parent class.

        Returns:
            Metadata: The metadata returned by the parent ``read_metadata``.

        Raises:
            ValueError if the metadata file is invalid.
        """
        metadata_file_path = self.path / '.metadata'
        _ensure_valid_checkpoint(metadata_file_path)
        return super().read_metadata()

# A subclass of FileSystemReader that downloads files from the object store before reading them from the local filesystem.
class DistCPObjectStoreReader(dist_cp.FileSystemReader):
class DistCPObjectStoreReader(FileSystemReaderWithValidation):

def __init__(self, source_path: str, destination_path: str, object_store):
self.source_path = source_path
Expand Down Expand Up @@ -458,7 +488,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
Path(rank0_download_tempdir) / Path('checkpoints')),
object_store=object_store)
else:
storage_reader = dist_cp.FileSystemReader(source_path)
storage_reader = FileSystemReaderWithValidation(source_path)

# We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph.
with torch.no_grad():
Expand Down

0 comments on commit 0e042c0

Please sign in to comment.