Skip to content

Commit

Permalink
Fix for use_safetensors parameters, allow use of parameter on loading…
Browse files Browse the repository at this point in the history
… submodels (#9576)
  • Loading branch information
elismasilva committed Oct 4, 2024
1 parent 99f6082 commit 751893d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_safetensors = kwargs.pop("use_safetensors", None)

allow_pickle = False
if use_safetensors is None:
if use_safetensors is None or use_safetensors:
use_safetensors = True
allow_pickle = True

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def load_sub_model(
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""

Expand Down Expand Up @@ -670,6 +671,7 @@ def load_sub_model(
loading_kwargs["offload_folder"] = offload_folder
loading_kwargs["offload_state_dict"] = offload_state_dict
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors

if from_flax:
loading_kwargs["from_flax"] = True
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ def load_module(name, value):
variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
Expand Down

0 comments on commit 751893d

Please sign in to comment.