diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5604879f40ab..ebd356d981d6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -221,7 +221,7 @@ def _fetch_index_file( local_files_only=local_files_only, token=token, revision=revision, - subfolder=subfolder, + subfolder=None, user_agent=user_agent, commit_hash=commit_hash, ) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ce90fb09193b..7ecb7de89cd3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,10 +455,13 @@ def _get_checkpoint_shard_files( # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames + if subfolder is not None: + allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] + ignore_patterns = ["*.json", "*.md"] if not local_files_only: # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path) + model_files_info = model_info(pretrained_model_name_or_path, revision=revision) for shard_file in original_shard_filenames: shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) if not shard_file_present: @@ -481,6 +484,8 @@ def _get_checkpoint_shard_files( ignore_patterns=ignore_patterns, user_agent=user_agent, ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # we don't have to catch them here. We have also dealt with EntryNotFoundError. diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 63e66dabf0c8..a84968e613b5 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1045,6 +1045,18 @@ def test_load_sharded_checkpoint_from_hub(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet" + ) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common()