diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index 661b3046ba..26ea64090a 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -294,10 +294,15 @@ def __init__( num_concurrent_uploads: int = 1, upload_timeout_in_seconds: int = 3600, ): + backend, _, local_folder = parse_uri(str(folder)) if local_folder == '': local_folder = '.' + is_remote_folder = backend != '' + if is_remote_folder: # If uploading to a remote path, use a temporary directory to save local checkpoints. + local_folder = os.path.join(tempfile.mkdtemp(), local_folder) + filename = str(filename) remote_file_name = str(remote_file_name) if remote_file_name is not None else None latest_filename = str(latest_filename) if latest_filename is not None else None diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 3e93ce56b3..8ae247fabf 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -882,7 +882,8 @@ def _get_tmp_dir(self): if delete_local: # delete files locally, forcing trainer to look in object store - shutil.rmtree('first') + assert trainer_1._checkpoint_saver is not None + shutil.rmtree(trainer_1._checkpoint_saver.folder) trainer_2 = self.get_trainer( latest_filename=latest_filename,