From 75682d2c00d3ec41c8c58b54c06621aab1c968c7 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sat, 6 Jul 2024 11:32:04 -1000 Subject: [PATCH] fix loading sharded checkpoints from subfolder (#8798) * fix load sharded checkpoints from subfolder * style * os.path.join * add a small test --------- Co-authored-by: sayakpaul --- src/diffusers/models/model_loading_utils.py | 2 +- src/diffusers/utils/hub_utils.py | 7 ++++++- tests/models/unets/test_models_unet_2d_condition.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5604879f40ab..ebd356d981d6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -221,7 +221,7 @@ def _fetch_index_file( local_files_only=local_files_only, token=token, revision=revision, - subfolder=subfolder, + subfolder=None, user_agent=user_agent, commit_hash=commit_hash, ) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ce90fb09193b..7ecb7de89cd3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,10 +455,13 @@ def _get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames + if subfolder is not None: + allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] + ignore_patterns = ["*.json", "*.md"] if not local_files_only: # `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path) + model_files_info = model_info(pretrained_model_name_or_path, revision=revision) for shard_file in original_shard_filenames: shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) if not shard_file_present: @@ -481,6 +484,8 @@ def _get_checkpoint_shard_files( ignore_patterns=ignore_patterns, user_agent=user_agent, ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # we don't have to catch them here. We have also dealt with EntryNotFoundError. diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 63e66dabf0c8..a84968e613b5 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1045,6 +1045,18 @@ def test_load_sharded_checkpoint_from_hub(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet" + ) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common()