Skip to content

Commit

Permalink
add simple error check to model loading (comfyanonymous#4950)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmonkey4eva authored Sep 17, 2024
1 parent 0b7dfa9 commit 254838f
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 23 deletions.
2 changes: 1 addition & 1 deletion comfy_extras/nodes_hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_photomaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def INPUT_TYPES(s):
CATEGORY = "_for_testing/photomaker"

def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:
Expand Down
6 changes: 3 additions & 3 deletions comfy_extras/nodes_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/loaders"

def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)

Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_upscale_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_video_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders/video_models"

def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])

Expand Down
8 changes: 8 additions & 0 deletions folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:

return None


def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path


def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
Expand Down
32 changes: 16 additions & 16 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def INPUT_TYPES(s):

def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))

class CheckpointLoaderSimple:
Expand All @@ -536,7 +536,7 @@ def INPUT_TYPES(s):
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."

def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]

Expand Down Expand Up @@ -578,7 +578,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out

Expand Down Expand Up @@ -625,7 +625,7 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)

lora_path = folder_paths.get_full_path("loras", lora_name)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
Expand Down Expand Up @@ -704,11 +704,11 @@ def load_taesd(name):
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]

dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]

Expand Down Expand Up @@ -739,7 +739,7 @@ def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
Expand All @@ -755,7 +755,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,)

Expand All @@ -771,7 +771,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,)

Expand Down Expand Up @@ -871,7 +871,7 @@ def load_unet(self, unet_name, weight_dtype):
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2

unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)

Expand All @@ -896,7 +896,7 @@ def load_clip(self, clip_name, type="stable_diffusion"):
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

clip_path = folder_paths.get_full_path("clip", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)

Expand All @@ -913,8 +913,8 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/loaders"

def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
Expand All @@ -936,7 +936,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,)

Expand Down Expand Up @@ -966,7 +966,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,)

Expand Down Expand Up @@ -1031,7 +1031,7 @@ def INPUT_TYPES(s):
CATEGORY = "loaders"

def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)

Expand Down

0 comments on commit 254838f

Please sign in to comment.