Skip to content

Commit

Permalink
Use a temp path to save local checkpoints for remote save path (#3673)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Oct 22, 2024
1 parent 94a80f2 commit ef42f54
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,15 @@ def __init__(
num_concurrent_uploads: int = 1,
upload_timeout_in_seconds: int = 3600,
):

backend, _, local_folder = parse_uri(str(folder))
if local_folder == '':
local_folder = '.'

is_remote_folder = backend != ''
if is_remote_folder: # If uploading to a remote path, use a temporary directory to save local checkpoints.
local_folder = os.path.join(tempfile.mkdtemp(), local_folder)

filename = str(filename)
remote_file_name = str(remote_file_name) if remote_file_name is not None else None
latest_filename = str(latest_filename) if latest_filename is not None else None
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,8 @@ def _get_tmp_dir(self):

if delete_local:
# delete files locally, forcing trainer to look in object store
shutil.rmtree('first')
assert trainer_1._checkpoint_saver is not None
shutil.rmtree(trainer_1._checkpoint_saver.folder)

trainer_2 = self.get_trainer(
latest_filename=latest_filename,
Expand Down

0 comments on commit ef42f54

Please sign in to comment.