Skip to content

Commit

Permalink
Add VAE_NAME output to the Parameter Generator node and add `vae_…
Browse files Browse the repository at this point in the history
…name` input to the `Prompt Saver` node #39
  • Loading branch information
receyuki committed Jan 5, 2024
1 parent 3b3da02 commit 500a1fd
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def IS_CHANGED(s, image, parameter_index):


class SDPromptSaver:
hash_dict = {}
model_hash_dict = {}
vae_hash_dict = {}

def __init__(self):
self.output_dir = folder_paths.get_output_directory()
Expand All @@ -261,6 +262,7 @@ def INPUT_TYPES(s):
"path": ("STRING", {"default": "%date/", "multiline": False}),
"model_name": (folder_paths.get_filename_list("checkpoints"),),
# "model_name_str": ("STRING", {"default": ""}),
"vae_name": (folder_paths.get_filename_list("vae"),),
"seed": (
"INT",
{
Expand Down Expand Up @@ -298,7 +300,7 @@ def INPUT_TYPES(s):
"positive": ("STRING", {"default": "", "multiline": True}),
"negative": ("STRING", {"default": "", "multiline": True}),
"extension": (["png", "jpg", "webp"],),
"calculate_model_hash": ("BOOLEAN", {"default": False}),
"calculate_hash": ("BOOLEAN", {"default": False}),
"lossless_webp": ("BOOLEAN", {"default": True}),
"jpg_webp_quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"date_format": (
Expand Down Expand Up @@ -330,6 +332,7 @@ def save_images(
path: str = "%date/",
model_name: str = "",
model_name_str: str = "",
vae_name: str = "",
seed: int = 0,
steps: int = 0,
cfg: float = 0.0,
Expand All @@ -342,7 +345,7 @@ def save_images(
positive: str = "",
negative: str = "",
extension: str = "png",
calculate_model_hash: bool = False,
calculate_hash: bool = False,
lossless_webp: bool = True,
jpg_webp_quality: int = 100,
date_format: str = "%Y-%m-%d",
Expand Down Expand Up @@ -401,15 +404,24 @@ def save_images(
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None

model_hash_str = ""
vae_hash_str = ""
vae_str = ""

if vae_name:
vae_str = f"VAE: {vae_name}"

hashes = {}
if calculate_model_hash:
model_hash = self.calculate_model_hash(model_name_real)
if calculate_hash:
model_hash = self.calculate_hash(model_name_real, "model")
model_hash_str = f"Model hash: {model_hash}, "
hashes["model"] = model_hash
else:
model_hash_str = ""
if vae_name:
vae_hash = self.calculate_hash(vae_name, "vae")
vae_hash_str = f"VAE hash: {vae_hash}, "
hashes["vae"] = vae_hash

hashes_str = f"Hashes: {json.dumps(hashes)}, " if hashes else ""
hashes_str = f", Hashes: {json.dumps(hashes)}" if hashes else ""

comment = (
f"{positive}\n"
Expand All @@ -421,6 +433,8 @@ def save_images(
f"Size: {img.width if width==0 else width}x{img.height if height==0 else height}, "
f"{model_hash_str}"
f"Model: {Path(model_name_real).stem}, "
f"{vae_hash_str}"
f"{vae_str}"
f"Version: ComfyUI"
f"{hashes_str}"
f"{extra_info_real}"
Expand Down Expand Up @@ -477,20 +491,29 @@ def save_images(
return {"ui": {"images": results}, "result": (files, file_paths, comments)}

@staticmethod
def calculate_model_hash(model_name):
if hash_value := SDPromptSaver.hash_dict.get(model_name):
def calculate_hash(name, hash_type):
match hash_type:
case "model":
hash_dict = SDPromptSaver.model_hash_dict
file_name = folder_paths.get_full_path("checkpoints", name)
case "vae":
hash_dict = SDPromptSaver.vae_hash_dict
file_name = folder_paths.get_full_path("vae", name)
case _:
return ""

if hash_value := hash_dict.get(name):
return hash_value

hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
file_name = folder_paths.get_full_path("checkpoints", model_name)

with open(file_name, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)

hash_value = hash_sha256.hexdigest()[:10]
SDPromptSaver.hash_dict[model_name] = hash_value
hash_dict[name] = hash_value

return hash_value

Expand Down Expand Up @@ -639,6 +662,7 @@ def INPUT_TYPES(s):

RETURN_TYPES = (
folder_paths.get_filename_list("checkpoints"),
folder_paths.get_filename_list("vae"),
"MODEL",
"CLIP",
"VAE",
Expand All @@ -658,6 +682,7 @@ def INPUT_TYPES(s):

RETURN_NAMES = (
"MODEL_NAME",
"VAE_NAME",
"MODEL",
"CLIP",
"VAE",
Expand Down Expand Up @@ -721,10 +746,15 @@ def generate_parameter(
)[:3]

if vae_name != "baked VAE":
vae_name_real = vae_name
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
checkpoint = (*checkpoint[:2], vae)
vae_str = f"VAE: {vae_name}, \n"
else:
vae_str = ""
vae_name_real = ""

if aspect_ratio != "custom":
aspect_ratio_value = aspect_ratio.split(" - ")[0]
Expand All @@ -750,6 +780,7 @@ def generate_parameter(

parameters = (
f"Model: {ckpt_name},\n"
f"{vae_str}"
f"Seed: {str(seed)},\n"
f"Steps: {str(steps)},\n"
f"CFG scale: {str(cfg)},\n"
Expand All @@ -776,7 +807,10 @@ def generate_parameter(
)
},
"result": (
(ckpt_name,)
(
ckpt_name,
vae_name_real,
)
+ checkpoint
+ (
seed,
Expand Down

0 comments on commit 500a1fd

Please sign in to comment.