From 57224023f3c2c530ae4f90f32b1577adf3dc9286 Mon Sep 17 00:00:00 2001 From: btian Date: Mon, 9 Dec 2024 15:08:07 +0800 Subject: [PATCH] feat: merged from archive/diffus-20241012 --- .gitignore | 2 + api_server/routes/internal/internal_routes.py | 40 +- api_server/services/file_service.py | 12 +- app/app_settings.py | 10 +- app/user_manager.py | 154 +- comfy/utils.py | 8 +- comfy_execution/caching.py | 58 +- comfy_execution/graph.py | 23 +- comfy_extras/nodes_audio.py | 34 +- comfy_extras/nodes_clip_sdxl.py | 4 +- comfy_extras/nodes_cond.py | 1 + comfy_extras/nodes_custom_sampler.py | 14 +- comfy_extras/nodes_flux.py | 2 +- comfy_extras/nodes_hooks.py | 48 +- comfy_extras/nodes_hunyuan.py | 2 +- comfy_extras/nodes_hypernetwork.py | 12 +- comfy_extras/nodes_images.py | 17 +- comfy_extras/nodes_lora_extract.py | 11 +- comfy_extras/nodes_model_merging.py | 61 +- comfy_extras/nodes_photomaker.py | 13 +- comfy_extras/nodes_sd3.py | 18 +- comfy_extras/nodes_upscale_model.py | 12 +- comfy_extras/nodes_video_model.py | 26 +- comfy_extras/nodes_webcam.py | 6 +- diffus/__init__.py | 0 diffus/database.py | 41 + diffus/decoded_params.py | 1831 +++++++++++++++++ diffus/image_gallery.py | 174 ++ diffus/message.py | 23 + diffus/models.py | 31 + diffus/redis_client.py | 21 + diffus/repository.py | 119 ++ diffus/service_registrar.py | 63 + diffus/system_monitor.py | 324 +++ diffus/task_queue.py | 244 +++ execution.py | 69 +- execution_context.py | 152 ++ folder_paths.py | 150 +- latent_preview.py | 10 +- main.py | 65 +- node_helpers.py | 32 +- nodes.py | 332 +-- requirements.txt | 5 + server.py | 244 ++- 44 files changed, 3963 insertions(+), 555 deletions(-) create mode 100644 diffus/__init__.py create mode 100644 diffus/database.py create mode 100644 diffus/decoded_params.py create mode 100644 diffus/image_gallery.py create mode 100644 diffus/message.py create mode 100644 diffus/models.py create mode 100644 diffus/redis_client.py create mode 100644 diffus/repository.py create mode 100644 diffus/service_registrar.py create mode 100644 diffus/system_monitor.py create mode 100644 diffus/task_queue.py create mode 100644 execution_context.py diff --git a/.gitignore b/.gitignore index 61881b8a4f3..431ffbb13d5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ venv/ *.log web_custom_versions/ .DS_Store + +custom_nodes.bak* diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index aaefa9335f5..db22566ec4a 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -1,6 +1,8 @@ from aiohttp import web from typing import Optional -from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths + +import execution_context +from folder_paths import get_models_dir, get_user_directory, get_output_directory, folder_names_and_paths from api_server.services.file_service import FileService from api_server.services.terminal_service import TerminalService import app.logger @@ -16,9 +18,9 @@ def __init__(self, prompt_server): self.routes: web.RouteTableDef = web.RouteTableDef() self._app: Optional[web.Application] = None self.file_service = FileService({ - "models": models_dir, - "user": user_directory, - "output": output_directory + # "models": get_models_dir, + "user": get_user_directory, + "output": get_output_directory }) self.prompt_server = prompt_server self.terminal_service = TerminalService(prompt_server) @@ -27,39 +29,43 @@ def setup_routes(self): 
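This hunk swaps FileService's static directory paths for per-user factories: each registry entry is now a callable taking the request's user hash. A minimal sketch of the lookup, with a hypothetical on-disk layout standing in for the patched folder_paths helpers:

```python
from typing import Callable, Dict

def get_user_directory(user_hash: str) -> str:
    # hypothetical stand-in for the patched folder_paths.get_user_directory
    return f"/data/users/{user_hash}/user"

allowed_directories: Dict[str, Callable[[str], str]] = {
    "user": get_user_directory,
}

def resolve(directory_key: str, user_hash: str) -> str:
    # list_files now calls the mapped factory with the caller's user hash
    # instead of reading a fixed module-level path
    if directory_key not in allowed_directories:
        raise ValueError("Invalid directory key")
    return allowed_directories[directory_key](user_hash)

print(resolve("user", "ab12cd"))  # -> /data/users/ab12cd/user
```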
@self.routes.get('/files') async def list_files(request): directory_key = request.query.get('directory', '') + context = execution_context.ExecutionContext(request) try: - file_list = self.file_service.list_files(directory_key) + file_list = self.file_service.list_files(context, directory_key) return web.json_response({"files": file_list}) except ValueError as e: return web.json_response({"error": str(e)}, status=400) except Exception as e: return web.json_response({"error": str(e)}, status=500) - @self.routes.get('/logs') async def get_logs(request): - return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()])) + # return web.json_response(app.logger.get_logs()) + return web.json_response([]) @self.routes.get('/logs/raw') async def get_logs(request): self.terminal_service.update_size() + # return web.json_response({ + # "entries": list(app.logger.get_logs()), + # "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows} + # }) return web.json_response({ - "entries": list(app.logger.get_logs()), - "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows} + "entries": [], + "size": {"cols": 0, "rows": 0} }) @self.routes.patch('/logs/subscribe') async def subscribe_logs(request): - json_data = await request.json() - client_id = json_data["clientId"] - enabled = json_data["enabled"] - if enabled: - self.terminal_service.subscribe(client_id) - else: - self.terminal_service.unsubscribe(client_id) + # json_data = await request.json() + # client_id = json_data["clientId"] + # enabled = json_data["enabled"] + # if enabled: + # self.terminal_service.subscribe(client_id) + # else: + # self.terminal_service.unsubscribe(client_id) return web.Response(status=200) - @self.routes.get('/folder_paths') async def get_folder_paths(request): response = {} diff --git a/api_server/services/file_service.py b/api_server/services/file_service.py index 394571084e9..6e31fc6fbac 100644 --- a/api_server/services/file_service.py +++ b/api_server/services/file_service.py @@ -1,13 +1,15 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Callable + +import execution_context from api_server.utils.file_operations import FileSystemOperations, FileSystemItem class FileService: - def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None): - self.allowed_directories: Dict[str, str] = allowed_directories + def __init__(self, allowed_directories: Dict[str, Callable], file_system_ops: Optional[FileSystemOperations] = None): + self.allowed_directories: Dict[str, Callable] = allowed_directories self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations() - def list_files(self, directory_key: str) -> List[FileSystemItem]: + def list_files(self, context: execution_context.ExecutionContext, directory_key: str) -> List[FileSystemItem]: if directory_key not in self.allowed_directories: raise ValueError("Invalid directory key") - directory_path: str = self.allowed_directories[directory_key] + directory_path: str = self.allowed_directories[directory_key](context.user_hash) return self.file_system_ops.walk_directory(directory_path) \ No newline at end of file diff --git a/app/app_settings.py b/app/app_settings.py index 8c6edc56c1d..6339a4edc3b 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -8,8 +8,12 @@ def __init__(self, user_manager): self.user_manager = user_manager def get_settings(self, request): - file = 
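Each route now builds an execution_context.ExecutionContext from the incoming request and threads it through the service layer. The class is added by this commit in execution_context.py, which this excerpt does not include; a plausible minimal shape, with the header name invented for illustration:

```python
class ExecutionContext:
    """Hypothetical minimal reconstruction; the real class lives in the
    new execution_context.py (not shown in this excerpt)."""

    def __init__(self, request):
        # assumption: the fronting service identifies the tenant via a
        # request header; the header name here is made up
        self.user_hash = request.headers.get("X-Diffus-User-Hash", "")

    def validate_model(self, folder_name: str, filename: str) -> None:
        # invoked from VALIDATE_INPUTS hooks later in this patch; a real
        # implementation would check the user's entitlement to the file
        raise NotImplementedError
```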
self.user_manager.get_request_user_filepath( - request, "comfy.settings.json") + # use user private settings + file = self.user_manager.get_request_user_filepath(request, "comfy.settings.json") + if not os.path.isfile(file): + # use default user settings + file = self.user_manager.get_default_user_filepath("comfy.settings.json") + if os.path.isfile(file): with open(file) as f: return json.load(f) @@ -51,4 +55,4 @@ async def post_setting(request): settings = self.get_settings(request) settings[setting_id] = await request.json() self.save_settings(request, settings) - return web.Response(status=200) \ No newline at end of file + return web.Response(status=200) diff --git a/app/user_manager.py b/app/user_manager.py index e863b93dd29..fa3ec4307bc 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -8,6 +8,8 @@ import logging from aiohttp import web from urllib import parse + +import execution_context from comfy.cli_args import args import folder_paths from .app_settings import AppSettings @@ -32,26 +34,27 @@ def get_file_info(path: str, relative_to: str) -> FileInfo: class UserManager(): def __init__(self): - user_directory = folder_paths.get_user_directory() + # user_directory = folder_paths.get_user_directory('') self.settings = AppSettings(self) - if not os.path.exists(user_directory): - os.makedirs(user_directory, exist_ok=True) - if not args.multi_user: - print("****** User settings have been changed to be stored on the server instead of browser storage. ******") - print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") - - if args.multi_user: - if os.path.isfile(self.get_users_file()): - with open(self.get_users_file()) as f: - self.users = json.load(f) - else: - self.users = {} - else: - self.users = {"default": "default"} - - def get_users_file(self): - return os.path.join(folder_paths.get_user_directory(), "users.json") + # if not os.path.exists(user_directory): + # os.mkdir(user_directory) + # if not args.multi_user: + # print("****** User settings have been changed to be stored on the server instead of browser storage. ******") + # print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. 
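The AppSettings hunk above turns settings resolution into a two-step lookup: the caller's private comfy.settings.json wins, then the shared "default" profile, then an empty dict. Condensed:

```python
import json
import os

def load_settings(user_file: str, default_file: str) -> dict:
    # mirrors the patched AppSettings.get_settings fallback order
    for candidate in (user_file, default_file):
        if os.path.isfile(candidate):
            with open(candidate) as f:
                return json.load(f)
    return {}
```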
******") + + # if args.multi_user: + # if os.path.isfile(self.get_users_file()): + # with open(self.get_users_file()) as f: + # self.users = json.load(f) + # else: + # self.users = {} + # else: + # self.users = {"default": "default"} + self.users = {"default": "default"} + + def get_users_file(self, context: execution_context.ExecutionContext): + return os.path.join(folder_paths.get_user_directory(context.user_hash), "users.json") def get_request_user_id(self, request): user = "default" @@ -63,16 +66,25 @@ def get_request_user_id(self, request): return user + def get_default_user_filepath(self, file, type="userdata"): + return self._get_request_user_filepath("default", file, type, False) + def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() + context = execution_context.ExecutionContext(request) + if not context.user_hash: + raise Exception("user hash is not provided") + return self._get_request_user_filepath(context.user_hash, file, type, create_dir) + + def _get_request_user_filepath(self, user_hash, file, type="userdata", create_dir=True): + user_directory = folder_paths.get_user_directory(user_hash) if type == "userdata": root_dir = user_directory else: raise KeyError("Unknown filepath type:" + type) - user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + # user = self.get_request_user_id(request) + path = user_root = os.path.abspath(root_dir) # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -95,7 +107,7 @@ def get_request_user_filepath(self, request, file, type="userdata", create_dir=T return path - def add_user(self, name): + def add_user(self, context: execution_context.ExecutionContext, name): name = name.strip() if not name: raise ValueError("username not provided") @@ -104,7 +116,7 @@ def add_user(self, name): self.users[user_id] = name - with open(self.get_users_file(), "w") as f: + with open(self.get_users_file(context), "w") as f: json.dump(self.users, f) return user_id @@ -125,13 +137,14 @@ async def get_users(request): @routes.post("/users") async def post_users(request): - body = await request.json() - username = body["username"] - if username in self.users.values(): - return web.json_response({"error": "Duplicate username."}, status=400) - - user_id = self.add_user(username) - return web.json_response(user_id) + # body = await request.json() + # username = body["username"] + # if username in self.users.values(): + # return web.json_response({"error": "Duplicate username."}, status=400) + # + # user_id = self.add_user(username) + # return web.json_response(user_id) + return web.Response(status=403, text="Forbidden") @routes.get("/userdata") async def listuserdata(request): @@ -270,61 +283,32 @@ async def post_userdata(request): @routes.delete("/userdata/{file}") async def delete_userdata(request): - path = get_user_data_path(request, check_exists=True) - if not isinstance(path, str): - return path - - os.remove(path) - - return web.Response(status=204) + # path = get_user_data_path(request, check_exists=True) + # if not isinstance(path, str): + # return path + # + # os.remove(path) + # + # return web.Response(status=204) + return web.Response(status=403, text="Forbidden") @routes.post("/userdata/{file}/move/{dest}") async def move_userdata(request): - """ - Move or rename a user data file. 
- - This endpoint handles moving or renaming files within a user's data directory, with options for - controlling overwrite behavior and response format. - - Path Parameters: - - file: The source file path (URL encoded if necessary) - - dest: The destination file path (URL encoded if necessary) - - Query Parameters: - - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true". - - full_info (optional): If "true", returns detailed file information (path, size, modified time). - If "false", returns only the relative file path. - - Returns: - - 400: If either 'file' or 'dest' parameter is missing - - 403: If either requested path is not allowed - - 404: If the source file does not exist - - 409: If overwrite=false and the destination file already exists - - 200: JSON response with either: - - Full file information (if full_info=true) - - Relative file path (if full_info=false) - """ - source = get_user_data_path(request, check_exists=True) - if not isinstance(source, str): - return source - - dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): - return dest - - overwrite = request.query.get("overwrite", 'true') != "false" - full_info = request.query.get('full_info', 'false').lower() == "true" - - if not overwrite and os.path.exists(dest): - return web.Response(status=409, text="File already exists") - - logging.info(f"moving '{source}' -> '{dest}'") - shutil.move(source, dest) - - user_path = self.get_request_user_filepath(request, None) - if full_info: - resp = get_file_info(dest, user_path) - else: - resp = os.path.relpath(dest, user_path) - - return web.json_response(resp) + # source = get_user_data_path(request, check_exists=True) + # if not isinstance(source, str): + # return source + # + # dest = get_user_data_path(request, check_exists=False, param="dest") + # if not isinstance(source, str): + # return dest + # + # overwrite = request.query["overwrite"] != "false" + # if not overwrite and os.path.exists(dest): + # return web.Response(status=409) + # + # print(f"moving '{source}' -> '{dest}'") + # shutil.move(source, dest) + # + # resp = os.path.relpath(dest, self.get_request_user_filepath(request, None)) + # return web.json_response(resp) + return web.Response(status=403, text="Forbidden") diff --git a/comfy/utils.py b/comfy/utils.py index 985cd9a1bcf..12c13e1f10f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -30,7 +30,13 @@ def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") - if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): + if hasattr(ckpt, 'is_safetensors') and hasattr(ckpt, 'filename'): + is_safetensors = ckpt.is_safetensors + ckpt = ckpt.filename + else: + is_safetensors = ckpt.lower().endswith(".safetensors") + + if is_safetensors: sd = safetensors.torch.load_file(ckpt, device=device.type) else: if safe_load: diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 630f280fc5e..668ff18c022 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,5 +1,8 @@ import itertools from typing import Sequence, Mapping, Dict + +import execution_context +import node_helpers from comfy_execution.graph import DynamicPrompt import nodes @@ -9,11 +12,12 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} -def include_unique_id_in_input(class_type: str) -> bool: +def include_unique_id_in_input(context: execution_context.ExecutionContext, class_type: str) -> bool: if class_type in 
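The load_torch_file change above duck-types its argument: any object exposing is_safetensors and filename is unwrapped before loading, so a diffus model record (shape assumed here) can be passed wherever a path string used to go. Plain string paths still work and are keyed on the .safetensors suffix:

```python
import types

import comfy.utils

# a SimpleNamespace stands in for a diffus repository record; only the
# two attributes the patched function checks are required (assumed shape)
ckpt = types.SimpleNamespace(
    is_safetensors=True,
    filename="/models/checkpoints/example.safetensors",  # hypothetical path
)
sd = comfy.utils.load_torch_file(ckpt, safe_load=True)
```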
NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() + class_inputs = node_helpers.get_node_input_types(context, class_def) + NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_inputs.get("hidden", {}).values() return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class CacheKeySet: @@ -21,7 +25,7 @@ def __init__(self, dynprompt, node_ids, is_changed_cache): self.keys = {} self.subcache_keys = {} - def add_keys(self, node_ids): + def add_keys(self, context: execution_context.ExecutionContext, node_ids): raise NotImplementedError() def all_node_ids(self): @@ -57,12 +61,12 @@ def to_hashable(obj): return Unhashable() class CacheKeySetID(CacheKeySet): - def __init__(self, dynprompt, node_ids, is_changed_cache): + def __init__(self, context: execution_context.ExecutionContext, dynprompt, node_ids, is_changed_cache): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt - self.add_keys(node_ids) + self.add_keys(context, node_ids) - def add_keys(self, node_ids): + def add_keys(self, context: execution_context.ExecutionContext, node_ids): for node_id in node_ids: if node_id in self.keys: continue @@ -73,42 +77,42 @@ def add_keys(self, node_ids): self.subcache_keys[node_id] = (node_id, node["class_type"]) class CacheKeySetInputSignature(CacheKeySet): - def __init__(self, dynprompt, node_ids, is_changed_cache): + def __init__(self, context: execution_context.ExecutionContext, dynprompt, node_ids, is_changed_cache): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache - self.add_keys(node_ids) + self.add_keys(context, node_ids) def include_node_id_in_input(self) -> bool: return False - def add_keys(self, node_ids): + def add_keys(self, context: execution_context.ExecutionContext, node_ids): for node_id in node_ids: if node_id in self.keys: continue if not self.dynprompt.has_node(node_id): continue node = self.dynprompt.get_node(node_id) - self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.keys[node_id] = self.get_node_signature(context, self.dynprompt, node_id) self.subcache_keys[node_id] = (node_id, node["class_type"]) - def get_node_signature(self, dynprompt, node_id): + def get_node_signature(self, context: execution_context.ExecutionContext, dynprompt, node_id): signature = [] ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) - signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + signature.append(self.get_immediate_node_signature(context, dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: - signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + signature.append(self.get_immediate_node_signature(context, dynprompt, ancestor_id, order_mapping)) return to_hashable(signature) - def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + def get_immediate_node_signature(self, context: execution_context.ExecutionContext, dynprompt, node_id, ancestor_order_mapping): if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. 
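Cache-key construction here (and graph.py below) no longer calls class_def.INPUT_TYPES() directly; it goes through node_helpers.get_node_input_types(context, class_def). node_helpers.py is modified by this commit but falls outside this excerpt; a plausible reconstruction, assuming it forwards the context only to INPUT_TYPES signatures that declare a parameter for it:

```python
import inspect

def get_node_input_types(context, class_def):
    # hypothetical reconstruction: patched nodes declare
    # INPUT_TYPES(s, context); stock nodes declare INPUT_TYPES(s)
    if len(inspect.signature(class_def.INPUT_TYPES).parameters) > 0:
        return class_def.INPUT_TYPES(context)
    return class_def.INPUT_TYPES()
```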
return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, self.is_changed_cache.get(node_id)] - if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): + signature = [class_type, self.is_changed_cache.get(context, node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(context, class_type): signature.append(node_id) inputs = node["inputs"] for key in sorted(inputs.keys()): @@ -150,9 +154,9 @@ def __init__(self, key_class): self.cache = {} self.subcaches = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): + def set_prompt(self, context: execution_context.ExecutionContext, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt - self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.cache_key_set = self.key_class(context, dynprompt, node_ids, is_changed_cache) self.is_changed_cache = is_changed_cache self.initialized = True @@ -201,13 +205,13 @@ def _get_immediate(self, node_id): else: return None - def _ensure_subcache(self, node_id, children_ids): + def _ensure_subcache(self, context, node_id, children_ids): subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) self.subcaches[subcache_key] = subcache - subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + subcache.set_prompt(context, self.dynprompt, children_ids, self.is_changed_cache) return subcache def _get_subcache(self, node_id): @@ -259,10 +263,10 @@ def set(self, node_id, value): assert cache is not None cache._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + def ensure_subcache_for(self, context, node_id, children_ids): cache = self._get_cache_for(node_id) assert cache is not None - return cache._ensure_subcache(node_id, children_ids) + return cache._ensure_subcache(context, node_id, children_ids) class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): @@ -273,8 +277,8 @@ def __init__(self, key_class, max_size=100): self.used_generation = {} self.children = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): - super().set_prompt(dynprompt, node_ids, is_changed_cache) + def set_prompt(self, context, dynprompt, node_ids, is_changed_cache): + super().set_prompt(context, dynprompt, node_ids, is_changed_cache) self.generation += 1 for node_id in node_ids: self._mark_used(node_id) @@ -303,11 +307,11 @@ def set(self, node_id, value): self._mark_used(node_id) return self._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + def ensure_subcache_for(self, context: execution_context.ExecutionContext, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes - super()._ensure_subcache(node_id, children_ids) + super()._ensure_subcache(context, node_id, children_ids) - self.cache_key_set.add_keys(children_ids) + self.cache_key_set.add_keys(context, children_ids) self._mark_used(node_id) cache_key = self.cache_key_set.get_data_key(node_id) self.children[cache_key] = [] diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0b5bf189906..2b5738c66dd 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,3 +1,4 @@ +import 
node_helpers import nodes from comfy_execution.graph_utils import is_link @@ -54,8 +55,8 @@ def all_node_ids(self): def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name): - valid_inputs = class_def.INPUT_TYPES() +def get_input_info(context, class_def, input_name): + valid_inputs = node_helpers.get_node_input_types(context, class_def) input_info = None input_category = None if "required" in valid_inputs and input_name in valid_inputs["required"]: @@ -83,12 +84,12 @@ def __init__(self, dynprompt): self.blockCount = {} # Number of nodes this node is directly blocked by self.blocking = {} # Which nodes are blocked by this node - def get_input_info(self, unique_id, input_name): + def get_input_info(self, context, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - return get_input_info(class_def, input_name) + return get_input_info(context, class_def, input_name) - def make_input_strong_link(self, to_node_id, to_input): + def make_input_strong_link(self, context, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] if to_input not in inputs: raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") @@ -96,17 +97,17 @@ def make_input_strong_link(self, to_node_id, to_input): if not is_link(value): raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") from_node_id, from_socket = value - self.add_strong_link(from_node_id, from_socket, to_node_id) + self.add_strong_link(context, from_node_id, from_socket, to_node_id) - def add_strong_link(self, from_node_id, from_socket, to_node_id): + def add_strong_link(self, context, from_node_id, from_socket, to_node_id): if not self.is_cached(from_node_id): - self.add_node(from_node_id) + self.add_node(context, from_node_id) if to_node_id not in self.blocking[from_node_id]: self.blocking[from_node_id][to_node_id] = {} self.blockCount[to_node_id] += 1 self.blocking[from_node_id][to_node_id][from_socket] = True - def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): + def add_node(self, context, node_unique_id, include_lazy=False, subgraph_nodes=None): node_ids = [node_unique_id] links = [] @@ -126,14 +127,14 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + input_type, input_category, input_info = self.get_input_info(context, unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: - self.add_strong_link(*link) + self.add_strong_link(context, *link) def is_cached(self, node_id): return False diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index e5cc4dffeb0..916bdd9e8a2 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -1,6 +1,7 @@ import torchaudio import torch import comfy.model_management +import execution_context import folder_paths import os import io @@ -121,7 +122,6 @@ def insert_or_replace_vorbis_comment(flac_io, comment_dict): class SaveAudio: def __init__(self): - 
self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" @@ -129,7 +129,7 @@ def __init__(self): def INPUT_TYPES(s): return {"required": { "audio": ("AUDIO", ), "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}, } RETURN_TYPES = () @@ -139,9 +139,13 @@ def INPUT_TYPES(s): CATEGORY = "audio" - def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext=None): filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + if self.type == "output": + output_dir = folder_paths.get_output_directory(context.user_hash) + else: + output_dir = folder_paths.get_temp_directory(context.user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) results = list() metadata = {} @@ -175,7 +179,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginf class PreviewAudio(SaveAudio): def __init__(self): - self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) @@ -183,38 +186,39 @@ def __init__(self): def INPUT_TYPES(s): return {"required": {"audio": ("AUDIO", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}, } class LoadAudio: @classmethod - def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + input_dir = folder_paths.get_input_directory(context.user_hash) files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + return {"required": {"audio": (sorted(files), {"audio_upload": True})}, + "hidden": {"context": "EXECUTION_CONTEXT"}} CATEGORY = "audio" RETURN_TYPES = ("AUDIO", ) FUNCTION = "load" - def load(self, audio): - audio_path = folder_paths.get_annotated_filepath(audio) + def load(self, audio, context: execution_context.ExecutionContext): + audio_path = folder_paths.get_annotated_filepath(audio, context.user_hash) waveform, sample_rate = torchaudio.load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} return (audio, ) @classmethod - def IS_CHANGED(s, audio): - image_path = folder_paths.get_annotated_filepath(audio) + def IS_CHANGED(s, audio, context: execution_context.ExecutionContext): + image_path = folder_paths.get_annotated_filepath(audio, context.user_hash) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): - if not folder_paths.exists_annotated_filepath(audio): + def VALIDATE_INPUTS(s, audio, context: execution_context.ExecutionContext): + if not folder_paths.exists_annotated_filepath(audio, context.user_hash): return "Invalid audio file: {}".format(audio) return True diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py 
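SaveAudio and its PreviewAudio subclass stop caching an output directory in __init__; the directory is resolved on every call from the execution context, branching on self.type exactly as save_audio does above. The recurring shape, reduced:

```python
import folder_paths

class PerUserSave:
    def __init__(self):
        # no cached output_dir any more: it depends on who is asking
        self.type = "output"

    def _resolve_dir(self, context):
        # the branch the patched save nodes repeat throughout this commit
        if self.type == "output":
            return folder_paths.get_output_directory(context.user_hash)
        return folder_paths.get_temp_directory(context.user_hash)
```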
index b8e241578e7..779adfe6808 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -17,7 +17,7 @@ def INPUT_TYPES(s): def encode(self, clip, ascore, width, height, text): tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height, "_origin_text_": text}), ) class CLIPTextEncodeSDXL: @classmethod @@ -46,7 +46,7 @@ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_heigh tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height, "_origin_text_": text_g + " " + text_l}), ) NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 4c3a1d5bf63..4f619135953 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -17,6 +17,7 @@ def encode(self, clip, conditioning, text): n = [t[0], t[1].copy()] n[1]['cross_attn_controlnet'] = cond n[1]['pooled_output_controlnet'] = pooled + n[1]['_origin_text_'] = text c.append(n) return (c, ) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c7ff9a4d8f9..34607cd7c6f 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -443,7 +443,8 @@ def INPUT_TYPES(s): "sampler": ("SAMPLER", ), "sigmas": ("SIGMAS", ), "latent_image": ("LATENT", ), - } + }, + "hidden": {"context": "EXECUTION_CONTEXT"} } RETURN_TYPES = ("LATENT","LATENT") @@ -453,7 +454,7 @@ def INPUT_TYPES(s): CATEGORY = "sampling/custom_sampling" - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image, context): latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -470,7 +471,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, noise_mask = latent["noise_mask"] x0_output = {} - callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + callback = latent_preview.prepare_callback(context, model, sigmas.shape[-1] - 1, x0_output) disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) @@ -605,7 +606,8 @@ def INPUT_TYPES(s): "sampler": ("SAMPLER", ), "sigmas": ("SIGMAS", ), "latent_image": ("LATENT", ), - } + }, + "hidden": {"context": "EXECUTION_CONTEXT"} } RETURN_TYPES = ("LATENT","LATENT") @@ -615,7 +617,7 @@ def INPUT_TYPES(s): CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): + def sample(self, noise, guider, sampler, sigmas, latent_image, context): latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -627,7 +629,7 
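The encode changes in this and the neighboring text-encoder files all stash the raw prompt in the conditioning options under "_origin_text_". Any downstream consumer of CONDITIONING can read it back; a minimal reader matching the [cond, options] pair layout the encoders emit:

```python
def origin_texts(conditioning):
    # CONDITIONING is a list of [tensor, options] pairs; this commit adds
    # "_origin_text_" to options alongside keys like pooled_output
    return [options.get("_origin_text_", "") for _, options in conditioning]
```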
@@ def sample(self, noise, guider, sampler, sigmas, latent_image): noise_mask = latent["noise_mask"] x0_output = {} - callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + callback = latent_preview.prepare_callback(context, guider.model_patcher, sigmas.shape[-1] - 1, x0_output) disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 2ae23f73550..5ac77b09b33 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -18,7 +18,7 @@ def encode(self, clip, clip_l, t5xxl, guidance): tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance, "_origin_text_": clip_l + " " + t5xxl}), ) class FluxGuidance: @classmethod diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index f73a0e9b0f7..961f8e5ce85 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -3,6 +3,8 @@ import torch from collections.abc import Iterable +import execution_context + if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.sd import CLIP @@ -292,15 +294,18 @@ def __init__(self): self.loaded_lora = None @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { "required": { - "lora_name": (folder_paths.get_filename_list("loras"), ), + "lora_name": (folder_paths.get_filename_list(context, "loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }, "optional": { "prev_hooks": ("HOOKS",) + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } @@ -309,7 +314,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/hooks/create" FUNCTION = "create_hook" - def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None): + def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None, context: execution_context.ExecutionContext=None): if prev_hooks is None: prev_hooks = comfy.hooks.HookGroup() prev_hooks.clone() @@ -317,7 +322,7 @@ def create_hook(self, lora_name: str, strength_model: float, strength_clip: floa if strength_model == 0 and strength_clip == 0: return (prev_hooks,) - lora_path = folder_paths.get_full_path("loras", lora_name) + lora_path = folder_paths.get_full_path(context, "loras", lora_name) lora = None if self.loaded_lora is not None: if self.loaded_lora[0] == lora_path: @@ -338,14 +343,17 @@ class CreateHookLoraModelOnly(CreateHookLora): NodeId = 'CreateHookLoraModelOnly' NodeName = 'Create Hook LoRA (MO)' @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { "required": { - "lora_name": (folder_paths.get_filename_list("loras"), ), + "lora_name": (folder_paths.get_filename_list(context, "loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }, "optional": { "prev_hooks": ("HOOKS",) + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } @@ -354,8 +362,8 @@ def INPUT_TYPES(s): 
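Both custom-sampler nodes now pass the context as the first argument to latent_preview.prepare_callback; latent_preview.py is patched by this commit but not shown in this excerpt. A stand-in matching the call shape these hunks assume, with an illustrative body modeled on the upstream progress-bar callback:

```python
import comfy.utils

def prepare_callback(context, model, steps, x0_output_dict=None):
    # signature inferred from the call sites above; what the patched
    # version does with `context` (e.g. per-user previews) is assumed
    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0
        pbar.update_absolute(step + 1, total_steps, None)

    return callback
```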
CATEGORY = "advanced/hooks/create" FUNCTION = "create_hook_model_only" - def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None): - return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks) + def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None, context: execution_context.ExecutionContext=None): + return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks, context=context) class CreateHookModelAsLora: NodeId = 'CreateHookModelAsLora' @@ -367,15 +375,18 @@ def __init__(self): self.loaded_weights = None @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }, "optional": { "prev_hooks": ("HOOKS",) + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } @@ -385,12 +396,13 @@ def INPUT_TYPES(s): FUNCTION = "create_hook" def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float, - prev_hooks: comfy.hooks.HookGroup=None): + prev_hooks: comfy.hooks.HookGroup=None, + context: execution_context.ExecutionContext=None): if prev_hooks is None: prev_hooks = comfy.hooks.HookGroup() prev_hooks.clone() - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + ckpt_path = folder_paths.get_full_path(context, "checkpoints", ckpt_name) weights_model = None weights_clip = None if self.loaded_weights is not None: @@ -416,14 +428,17 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora): NodeId = 'CreateHookModelAsLoraModelOnly' NodeName = 'Create Hook Model as LoRA (MO)' @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }, "optional": { "prev_hooks": ("HOOKS",) + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } @@ -433,8 +448,9 @@ def INPUT_TYPES(s): FUNCTION = "create_hook_model_only" def create_hook_model_only(self, ckpt_name: str, strength_model: float, - prev_hooks: comfy.hooks.HookGroup=None): - return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks) + prev_hooks: comfy.hooks.HookGroup=None, + context: execution_context.ExecutionContext=None): + return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks, context=context) #------------------------------------------ ########################################### diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 2bd295e2459..f491db27a9d 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -15,7 +15,7 @@ def encode(self, clip, bert, mt5xl): tokens = clip.tokenize(bert) tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return (clip.encode_from_tokens_scheduled(tokens, 
add_dict={"_origin_text_": bert + " " + mt5xl}), ) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 66563229278..8439d978872 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -1,4 +1,5 @@ import comfy.utils +import execution_context import folder_paths import torch import logging @@ -96,18 +97,19 @@ def to(self, device): class HypernetworkLoader: @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return {"required": { "model": ("MODEL",), - "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "hypernetwork_name": (folder_paths.get_filename_list(context, "hypernetworks"), ), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL",) FUNCTION = "load_hypernetwork" CATEGORY = "loaders" - def load_hypernetwork(self, model, hypernetwork_name, strength): - hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name) + def load_hypernetwork(self, model, hypernetwork_name, strength, context: execution_context.ExecutionContext): + hypernetwork_path = folder_paths.get_full_path_or_raise(context, "hypernetworks", hypernetwork_name) model_hypernetwork = model.clone() patch = load_hypernetwork_patch(hypernetwork_path, strength) if patch is not None: diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index af37666b29f..185adfff0de 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,3 +1,4 @@ +import execution_context import nodes import folder_paths from comfy.cli_args import args @@ -69,7 +70,6 @@ def frombatch(self, image, batch_index, length): class SaveAnimatedWEBP: def __init__(self): - self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" @@ -85,7 +85,7 @@ def INPUT_TYPES(s): "method": (list(s.methods.keys()),), # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}, } RETURN_TYPES = () @@ -95,10 +95,11 @@ def INPUT_TYPES(s): CATEGORY = "image/animation" - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): + def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None, context:execution_context.ExecutionContext=None): method = self.methods.get(method) filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + output_dir = folder_paths.get_output_directory(context.user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]) results = list() pil_images = [] for image in images: @@ -135,7 +136,6 @@ def save_images(self, images, fps, filename_prefix, lossless, quality, method, n class SaveAnimatedPNG: def __init__(self): - self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" @@ -147,7 +147,7 @@ def INPUT_TYPES(s): "fps": ("FLOAT", {"default": 6.0, "min": 0.01, 
"max": 1000.0, "step": 0.01}), "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "user_hash": "USER_HASH"}, } RETURN_TYPES = () @@ -157,9 +157,10 @@ def INPUT_TYPES(s): CATEGORY = "image/animation" - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, user_hash=''): filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + output_dir = folder_paths.get_output_directory(user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]) results = list() pil_images = [] for image in images: diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index dfd4fe9f4a5..81a0a13ba08 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import comfy.utils +import execution_context import folder_paths import os import logging @@ -73,7 +74,7 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora class LoraSave: def __init__(self): - self.output_dir = folder_paths.get_output_directory() + pass @classmethod def INPUT_TYPES(s): @@ -84,6 +85,7 @@ def INPUT_TYPES(s): }, "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}), "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})}, + "hidden": {"context": "EXECUTION_CONTEXT"} } RETURN_TYPES = () FUNCTION = "save" @@ -91,12 +93,13 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" - def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None): + def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None, context: execution_context.ExecutionContext=None): if model_diff is None and text_encoder_diff is None: return {} lora_type = LORA_TYPES.get(lora_type) - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + output_dir = folder_paths.get_output_directory(context.user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) output_sd = {} if model_diff is not None: @@ -111,7 +114,7 @@ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, tex return {} NODE_CLASS_MAPPINGS = { - "LoraSave": LoraSave + # "LoraSave": LoraSave } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index ccf601158d5..42dd5b6be7c 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -5,6 +5,8 @@ import comfy.model_sampling import torch + +import execution_context import folder_paths import json import os @@ -223,7 +225,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi class CheckpointSave: def __init__(self): - self.output_dir = 
folder_paths.get_output_directory() + pass @classmethod def INPUT_TYPES(s): @@ -231,33 +233,33 @@ def INPUT_TYPES(s): "clip": ("CLIP",), "vae": ("VAE",), "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" - def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext=None): + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(context.user_hash), prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: def __init__(self): - self.output_dir = folder_paths.get_output_directory() + pass @classmethod def INPUT_TYPES(s): return {"required": { "clip": ("CLIP",), "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" - def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): + def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext=None): prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) @@ -289,7 +291,7 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): replace_prefix[prefix] = "" replace_prefix["transformer."] = "" - full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, folder_paths.get_output_directory(context.user_hash)) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) @@ -301,21 +303,21 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): class VAESave: def __init__(self): - self.output_dir = folder_paths.get_output_directory() - + pass + @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"}} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" - def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(context.user_hash)) prompt_info = "" if prompt is not 
None: prompt_info = json.dumps(prompt) @@ -335,37 +337,38 @@ def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): class ModelSave: def __init__(self): - self.output_dir = folder_paths.get_output_directory() + pass @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT"},} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" - def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None): - save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext=None): + output_dir = folder_paths.get_output_directory(context.user_hash) + save_checkpoint(model, filename_prefix=filename_prefix, output_dir=output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} NODE_CLASS_MAPPINGS = { - "ModelMergeSimple": ModelMergeSimple, - "ModelMergeBlocks": ModelMergeBlocks, - "ModelMergeSubtract": ModelSubtract, - "ModelMergeAdd": ModelAdd, - "CheckpointSave": CheckpointSave, - "CLIPMergeSimple": CLIPMergeSimple, - "CLIPMergeSubtract": CLIPSubtract, - "CLIPMergeAdd": CLIPAdd, - "CLIPSave": CLIPSave, - "VAESave": VAESave, - "ModelSave": ModelSave, + # "ModelMergeSimple": ModelMergeSimple, + # "ModelMergeBlocks": ModelMergeBlocks, + # "ModelMergeSubtract": ModelSubtract, + # "ModelMergeAdd": ModelAdd, + # "CheckpointSave": CheckpointSave, + # "CLIPMergeSimple": CLIPMergeSimple, + # "CLIPMergeSubtract": CLIPSubtract, + # "CLIPMergeAdd": CLIPAdd, + # "CLIPSave": CLIPSave, + # "VAESave": VAESave, + # "ModelSave": ModelSave, } NODE_DISPLAY_NAME_MAPPINGS = { - "CheckpointSave": "Save Checkpoint", + # "CheckpointSave": "Save Checkpoint", } diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index 95d24dd221e..12d5f38082b 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -1,5 +1,7 @@ import torch import torch.nn as nn + +import execution_context import folder_paths import comfy.clip_model import comfy.clip_vision @@ -117,16 +119,17 @@ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): class PhotoMakerLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list(context, "photomaker"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("PHOTOMAKER",) FUNCTION = "load_photomaker_model" CATEGORY = "_for_testing/photomaker" - def load_photomaker_model(self, photomaker_model_name): - photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) + def load_photomaker_model(self, photomaker_model_name, context: execution_context.ExecutionContext): + photomaker_model_path = folder_paths.get_full_path_or_raise(context, "photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: @@ -177,7 +180,7 @@ def apply_photomaker(self, photomaker, image, clip, 
text): else: out = cond - return ([[out, {"pooled_output": pooled}]], ) + return ([[out, {"pooled_output": pooled, "_origin_text_": text}]], ) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index d75b29e606f..cce62da439f 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -1,3 +1,4 @@ +import execution_context import folder_paths import comfy.sd import comfy.model_management @@ -8,9 +9,10 @@ class TripleCLIPLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) - }} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "clip_name1": (folder_paths.get_filename_list(context, "text_encoders"), ), "clip_name2": (folder_paths.get_filename_list(context, "text_encoders"), ), "clip_name3": (folder_paths.get_filename_list(context, "text_encoders"), ) + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -18,10 +20,10 @@ def INPUT_TYPES(s): DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5" - def load_clip(self, clip_name1, clip_name2, clip_name3): - clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) - clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + def load_clip(self, clip_name1, clip_name2, clip_name3, context: execution_context.ExecutionContext): + clip_path1 = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name3) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) @@ -82,7 +84,7 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"_origin_text_": " ".join([clip_l, clip_g, t5xxl])}), ) class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 6ba3e404f2e..2c168c00d41 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,5 +1,6 @@ import os import logging +import execution_context from spandrel import ModelLoader, ImageModelDescriptor from comfy import model_management import torch @@ -16,16 +17,17 @@ class UpscaleModelLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), - }} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "model_name": (folder_paths.get_filename_list(context, "upscale_models"), ), + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" CATEGORY = "loaders" - def load_model(self, model_name): - model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) + def load_model(self, model_name, context: execution_context.ExecutionContext): + model_path 
= folder_paths.get_full_path_or_raise(context, "upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index e7a7ec181fc..960a34ccf9d 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -1,3 +1,4 @@ +import execution_context import nodes import torch import comfy.utils @@ -8,16 +9,23 @@ class ImageOnlyCheckpointLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), ), + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders/video_models" - def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + @classmethod + def VALIDATE_INPUTS(cls, ckpt_name, output_vae=True, output_clip=True, context: execution_context.ExecutionContext=None): + context.validate_model("checkpoints", ckpt_name) + return True + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, context: execution_context.ExecutionContext=None): + ckpt_path = folder_paths.get_full_path_or_raise(context, "checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return (out[0], out[3], out[2]) @@ -115,10 +123,10 @@ def INPUT_TYPES(s): "clip_vision": ("CLIP_VISION",), "vae": ("VAE",), "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "user_hash": "USER_HASH"}} - def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): - comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None, user_hash=''): + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(user_hash), prompt=prompt, extra_pnginfo=extra_pnginfo) return {} NODE_CLASS_MAPPINGS = { @@ -126,7 +134,7 @@ def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pngi "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, - "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, + # "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py index 32a0ba2f67b..ea2fe0a20e2 100644 --- a/comfy_extras/nodes_webcam.py +++ b/comfy_extras/nodes_webcam.py @@ -13,6 +13,9 @@ def INPUT_TYPES(s): "width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "height": ("INT", {"default": 0, 
"min": 0, "max": MAX_RESOLUTION, "step": 1}), "capture_on_queue": ("BOOLEAN", {"default": True}), + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } RETURN_TYPES = ("IMAGE",) @@ -21,7 +24,8 @@ def INPUT_TYPES(s): CATEGORY = "image" def load_capture(s, image, **kwargs): - return super().load_image(folder_paths.get_annotated_filepath(image)) + context = kwargs["context"] + return super().load_image(folder_paths.get_annotated_filepath(image, context.user_hash)) NODE_CLASS_MAPPINGS = { diff --git a/diffus/__init__.py b/diffus/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/diffus/database.py b/diffus/database.py new file mode 100644 index 00000000000..0013adbc73f --- /dev/null +++ b/diffus/database.py @@ -0,0 +1,41 @@ +import os +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, Session + +SQLALCHEMY_DATABASE_URL = os.getenv('SQL_DATABASE_URL') + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, + connect_args={ + }, + pool_pre_ping=True, + pool_recycle=3600 +) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + + +class Database: + + def __init__(self): + self._db: Session | None = None + + def __enter__(self) -> Session: + self._db = SessionLocal() + return self._db + + def __exit__(self, exc_type, exc_val, exc_tb): + self._db.commit() + self._db.close() + self._db = None + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.commit() + db.close() diff --git a/diffus/decoded_params.py b/diffus/decoded_params.py new file mode 100644 index 00000000000..dcc10c1db09 --- /dev/null +++ b/diffus/decoded_params.py @@ -0,0 +1,1831 @@ +import math + +import execution_context + + +def _sample_consumption_ratio( + context: execution_context.ExecutionContext, + model, +): + import comfy.model_base + import comfy.model_patcher + if model: + if isinstance(model, comfy.model_patcher.ModelPatcher): + model = model.model + if isinstance(model, comfy.model_base.SD3) or isinstance(model, comfy.model_base.Flux): + return 3 + + for model_info in context.loaded_checkpoints: + if model_info.base and model_info.base.lower() in ("sd3", "flux"): + return 3 + return 1 + + +def __update_context_checkpoints_model_base(context: execution_context.ExecutionContext, model): + import comfy.model_base + import comfy.model_patcher + if isinstance(model, comfy.model_patcher.ModelPatcher): + model = model.model + if isinstance(model, comfy.model_base.SD3): + context.checkpoints_model_base = "SD3" + elif isinstance(model, comfy.model_base.Flux): + context.checkpoints_model_base = "SD3" + elif isinstance(model, comfy.model_base.GenmoMochi): + context.checkpoints_model_base = "MOCHI" + + +def __get_image_size(model, latent_image): + import comfy.model_base + import comfy.model_patcher + latent = latent_image["samples"] + latent_size = latent.size() + batch_size = latent_size[0] + if isinstance(model, comfy.model_patcher.ModelPatcher) and isinstance(model.model, comfy.model_base.GenmoMochi): + image_height = latent_size[3] * 8 + image_width = latent_size[4] * 8 + n_iter = max(1, (latent_size[2] - 1) * 6 + 1) + else: + image_height = latent_size[2] * 8 + image_width = latent_size[3] * 8 + n_iter = 1 + return image_height, image_width, n_iter, batch_size + + +def __sample_opt_from_latent(context: execution_context.ExecutionContext, model, latent_image, steps): + __update_context_checkpoints_model_base(context, model) + image_height, 
image_width, n_iter, batch_size = __get_image_size(model, latent_image) + return { + 'opt_type': 'ksampler', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': n_iter, + 'batch_size': batch_size, + "ratio": _sample_consumption_ratio(context, model) + } + + +def _k_sampler_consumption(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=1.0, context=None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _reactor_restore_face_consumption(image, model, visibility, codeformer_weight, facedetection, + context: execution_context.ExecutionContext): + if model != 'none': + opts = [{ + 'opt_type': 'detect_face', + 'width': image.shape[2], + 'height': image.shape[1], + 'steps': 30, + 'n_iter': 1, + 'batch_size': image.shape[0] + }, { + 'opt_type': 'swap_face', + 'width': image.shape[2], + 'height': image.shape[1], + 'steps': 30, + 'n_iter': 1, + 'batch_size': image.shape[0] + }, { + 'opt_type': 'restore_face', + 'width': image.shape[2], + 'height': image.shape[1], + 'steps': 30, + 'n_iter': 1, + 'batch_size': image.shape[0] + }] + else: + opts = [] + return {'opts': opts} + + +def _reactor_face_swap_consumption(enabled, + input_image, + swap_model, + detect_gender_source, + detect_gender_input, + source_faces_index, + input_faces_index, + console_log_level, + face_restore_model, + face_restore_visibility, + codeformer_weight, + facedetection, + source_image=None, + face_model=None, + faces_order=None, + context: execution_context.ExecutionContext = None): + return { + 'opts': [{ + 'opt_type': 'detect_face', + 'width': input_image.shape[2], + 'height': input_image.shape[1], + 'batch_size': input_image.shape[0] + }, { + 'opt_type': 'restore_face', + 'width': input_image.shape[2], + 'height': input_image.shape[1], + 'batch_size': input_image.shape[0] + }] + } + + +def _reactor_face_swap_opt_consumption(enabled, input_image, swap_model, facedetection, face_restore_model, + face_restore_visibility, codeformer_weight, source_image=None, face_model=None, + options=None, context: execution_context.ExecutionContext = None): + return { + 'opts': [{ + 'opt_type': 'detect_face', + 'width': input_image.shape[2], + 'height': input_image.shape[1], + 'batch_size': input_image.shape[0] + }, { + 'opt_type': 'restore_face', + 'width': input_image.shape[2], + 'height': input_image.shape[1], + 'batch_size': input_image.shape[0] + }] + } + + +def _k_sampler_advanced_consumption(model, + add_noise, + noise_seed, + steps, + cfg, + sampler_name, + scheduler, + positive, + negative, + latent_image, + start_at_step, + end_at_step, + return_with_leftover_noise, + denoise=1.0, + context=None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=noise_seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _tsc_ksampler_advanced_consumption(model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, + negative, + latent_image, start_at_step, end_at_step, return_with_leftover_noise, + preview_method, vae_decode, + prompt=None, extra_pnginfo=None, my_unique_id=None, + context: execution_context.ExecutionContext = None, + optional_vae=(None,), script=None): + context.set_geninfo( + positive_prompt=positive, + 
negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=noise_seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _tsc_ksampler_sdxl_consumption(sdxl_tuple, noise_seed, steps, cfg, sampler_name, scheduler, latent_image, + start_at_step, refine_at_step, preview_method, vae_decode, prompt=None, + extra_pnginfo=None, + my_unique_id=None, context: execution_context.ExecutionContext = None, + optional_vae=(None,), refiner_extras=None, + script=None): + context.set_geninfo( + positive_prompt=prompt, + negative_prompt=None, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=noise_seed, + ) + model = sdxl_tuple[0] + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _tsc_k_sampler_consumption(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + preview_method, vae_decode, denoise=1.0, prompt=None, extra_pnginfo=None, + my_unique_id=None, + context: execution_context.ExecutionContext = None, + optional_vae=(None,), script=None, add_noise=None, start_at_step=None, end_at_step=None, + return_with_leftover_noise=None, sampler_type="regular"): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _xlabs_sampler_consumption(model, conditioning, neg_conditioning, + noise_seed, steps, timestep_to_start_cfg, true_gs, + image_to_image_strength, denoise_strength, + latent_image=None, controlnet_condition=None, + context=None): + context.set_geninfo( + positive_prompt=conditioning, + negative_prompt=neg_conditioning, + steps=steps, + sampler="", + cfg_scale=timestep_to_start_cfg, + seed=noise_seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _impact_k_sampler_basic_pipe_consumption(basic_pipe, seed, steps, cfg, sampler_name, scheduler, latent_image, + denoise=1.0, context=None): + model, clip, vae, positive, negative = basic_pipe + + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _tiled_k_sampler_consumption(model, seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, + scheduler, positive, negative, latent_image, denoise, context=None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _easy_full_k_sampler_consumption(pipe, steps, cfg, sampler_name, scheduler, denoise, image_output, link_id, + save_prefix, seed=None, model=None, positive=None, negative=None, latent=None, + vae=None, clip=None, xyPlot=None, tile_size=None, prompt=None, extra_pnginfo=None, + my_unique_id=None, context: execution_context.ExecutionContext = None, + force_full_denoise=False, disable_noise=False, downscale_options=None, image=None): + samp_samples = latent if latent is not None else pipe["samples"] + samp_vae = vae if vae is not None else pipe["vae"] + if image is not None and latent is None: + samp_samples = {"samples": samp_vae.encode(image[:, :, :, :3])} + samp_model = model if model is not None else 
pipe["model"] + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + return {'opts': [__sample_opt_from_latent(context, samp_model, samp_samples, steps, )]} + + +def _mochi_sampler_consumption(model, positive, negative, steps, cfg, seed, height, width, num_frames, + cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None, + context: execution_context.ExecutionContext = None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=samples, + cfg_scale=cfg, + seed=seed, + ) + return { + 'opts': [ + { + 'opt_type': 'mochi_sampler', + 'width': width, + 'height': height, + 'steps': steps, + 'n_iter': num_frames, + 'batch_size': 1, + "ratio": _sample_consumption_ratio(context, model) + } + ] + } + + +def _cog_video_sampler_consumption(model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, + denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, + tora_trajectory=None, fastercache=None, + context: execution_context.ExecutionContext = None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=samples, + cfg_scale=cfg, + seed=seed, + ) + + H = 768 + W = 768 + B = 1 + if samples is not None: + if len(samples["samples"].shape) == 5: + B, T, C, H, W = samples["samples"].shape + if len(samples["samples"].shape) == 4: + B, C, H, W = samples["samples"].shape + if image_cond_latents is not None: + B, T, C, H, W = image_cond_latents["samples"].shape + height = H * 8 + width = W * 8 + return { + 'opts': [ + { + 'opt_type': 'cog_video_sampler', + 'width': width, + 'height': height, + 'steps': steps, + 'n_iter': num_frames, + 'batch_size': B, + "ratio": _sample_consumption_ratio(context, model) + } + ] + } + + +def _tiled_k_sampler_advanced_consumption(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, + cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, + end_at_step, return_with_leftover_noise, preview, denoise=1.0, + context: execution_context.ExecutionContext = None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=noise_seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _sampler_custom_consumption(model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image, + context): + steps = len(sigmas) + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler, + cfg_scale=cfg, + seed=noise_seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, latent_image, steps, )]} + + +def _sampler_custom_advanced_consumption(noise, guider, sampler, sigmas, latent_image, context): + steps = len(sigmas) + return {'opts': [__sample_opt_from_latent(context, guider.model_patcher, latent_image, steps, )]} + + +def _k_sampler_inspire_consumption(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise, noise_mode, batch_seed_mode="comfy", variation_seed=None, + variation_strength=None, variation_method="linear", + context=None): + context.set_geninfo( + positive_prompt=positive, + negative_prompt=negative, + steps=steps, + sampler=sampler_name, + cfg_scale=cfg, + seed=seed, + ) + return {'opts': [__sample_opt_from_latent(context, model, 
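+
+
+# KSampler Cycle (WAS) bills one 'sample' opt per cycle; the working
+# resolution grows geometrically each cycle by
+# upscale_factor ** (1 / (division_factor - 1)), so the final cycle runs
+# at upscale_factor times the starting size. Model-based upscale passes
+# between cycles are billed separately as 'upscale' opts.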
+def _was_k_sampler_cycle_consumption(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
+                                     tiled_vae, latent_upscale, upscale_factor,
+                                     upscale_cycles, starting_denoise, cycle_denoise, scale_denoise, scale_sampling,
+                                     vae, secondary_model=None, secondary_start_cycle=None,
+                                     pos_additive=None, pos_add_mode=None, pos_add_strength=None,
+                                     pos_add_strength_scaling=None, pos_add_strength_cutoff=None,
+                                     neg_additive=None, neg_add_mode=None, neg_add_strength=None,
+                                     neg_add_strength_scaling=None, neg_add_strength_cutoff=None,
+                                     upscale_model=None, processor_model=None, sharpen_strength=0, sharpen_radius=2,
+                                     steps_scaling=None, steps_control=None,
+                                     steps_scaling_value=None, steps_cutoff=None, denoise_cutoff=0.25,
+                                     context=None):
+    context.set_geninfo(
+        positive_prompt=positive,
+        negative_prompt=negative,
+        steps=steps,
+        sampler=sampler_name,
+        cfg_scale=cfg,
+        seed=seed,
+    )
+
+    result = []
+    upscale_steps = upscale_cycles
+    division_factor = upscale_steps if steps >= upscale_steps else steps
+    # guard against a zero division when only a single cycle runs
+    current_upscale_factor = upscale_factor ** (1 / (division_factor - 1)) if division_factor > 1 else upscale_factor
+    n_iter = latent_image.get("batch_index", 1)
+
+    latent = latent_image["samples"]
+    latent_size = latent.size()
+    batch_size = latent_size[0]
+    latent_image_height = latent_size[2] * 8
+    latent_image_width = latent_size[3] * 8
+
+    for i in range(division_factor):
+        if steps_scaling and i > 0:
+            steps = (
+                steps + steps_scaling_value
+                if steps_control == 'increment'
+                else steps - steps_scaling_value
+            )
+            steps = (
+                (steps
+                 if steps <= steps_cutoff
+                 else steps_cutoff)
+                if steps_control == 'increment'
+                else (steps
+                      if steps >= steps_cutoff
+                      else steps_cutoff)
+            )
+        result.append({
+            'opt_type': 'sample',
+            'width': latent_image_width,
+            'height': latent_image_height,
+            'steps': steps,
+            'n_iter': n_iter,
+            'batch_size': batch_size,
+            "ratio": _sample_consumption_ratio(context, model)
+        })
+        if i < division_factor - 1 and latent_upscale == 'disable':
+            if processor_model:
+                scale_factor = _get_upscale_model_size(context, processor_model)
+                result.append({
+                    'opt_type': 'upscale',
+                    'width': latent_image_width * scale_factor,
+                    'height': latent_image_height * scale_factor,
+                })
+            if upscale_model:
+                scale_factor = _get_upscale_model_size(context, upscale_model)
+                result.append({
+                    'opt_type': 'upscale',
+                    'width': latent_image_width * scale_factor,
+                    'height': latent_image_height * scale_factor,
+                })
+            latent_image_width = int(round(round(latent_image_width * current_upscale_factor) / 32) * 32)
+            latent_image_height = int(round(round(latent_image_height * current_upscale_factor) / 32) * 32)
+        else:
+            latent_image_height *= current_upscale_factor
+            latent_image_width *= current_upscale_factor
+
+    return {'opts': result}
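+
+
+# SeargeSDXLImage2ImageSampler2 is billed as an optional model-based upscale
+# pass plus one sampling pass at the effective output resolution; a denoise
+# below 0.01 skips the sampling charge.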
+def _searge_sdxl_image2image_sampler2_consumption(base_model, base_positive, base_negative, refiner_model,
+                                                  refiner_positive, refiner_negative,
+                                                  image, vae, noise_seed, steps, cfg, sampler_name, scheduler,
+                                                  base_ratio, denoise, softness,
+                                                  upscale_model=None, scaled_width=None, scaled_height=None,
+                                                  noise_offset=None, refiner_strength=None,
+                                                  context: execution_context.ExecutionContext = None):
+    context.set_geninfo(
+        positive_prompt=base_positive,
+        negative_prompt=base_negative,
+        steps=steps,
+        sampler=sampler_name,
+        cfg_scale=cfg,
+    )
+
+    result = []
+
+    if steps < 1:
+        return {'opts': result}
+
+    if upscale_model is not None and softness < 0.9999:
+        use_upscale_model = True
+        model_scale = _get_upscale_model_size(context, upscale_model)
+    else:
+        use_upscale_model = False
+        model_scale = 1
+
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    if use_upscale_model:
+        result.append({
+            'opt_type': 'upscale',
+            'width': image_width * model_scale,
+            'height': image_height * model_scale,
+            'batch_size': batch_size,
+        })
+
+    if denoise < 0.01:
+        return {'opts': result}
+
+    n_iter = 1
+    if scaled_width is not None and scaled_height is not None:
+        sample_height = scaled_height
+        sample_width = scaled_width
+    elif use_upscale_model:
+        sample_height = image_height * model_scale
+        sample_width = image_width * model_scale
+    else:
+        sample_height = image_height
+        sample_width = image_width
+
+    result.append({
+        'opt_type': 'sample',
+        'width': sample_width,
+        'height': sample_height,
+        'steps': steps,
+        'n_iter': n_iter,
+        'batch_size': batch_size,
+        'ratio': _sample_consumption_ratio(context, base_model)
+    })
+    return {'opts': result}
+
+
+def _ultimate_sd_upscale_consumption(image, model, positive, negative, vae, upscale_by, seed,
+                                     steps, cfg, sampler_name, scheduler, denoise, upscale_model,
+                                     mode_type, tile_width, tile_height, mask_blur, tile_padding,
+                                     seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur,
+                                     seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode,
+                                     context: execution_context.ExecutionContext):
+    context.set_geninfo(
+        positive_prompt=positive,
+        negative_prompt=negative,
+        steps=steps,
+        sampler=sampler_name,
+        cfg_scale=cfg,
+        seed=seed,
+    )
+
+    batch_size = image.shape[0]
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+
+    if upscale_model is not None:
+        enable_hr = True
+        hr_width = image_width * upscale_model.scale
+        hr_height = image_height * upscale_model.scale
+    else:
+        enable_hr = False
+        hr_width = 0
+        hr_height = 0
+
+    redraw_width = math.ceil((image_width * upscale_by) / 64) * 64
+    redraw_height = math.ceil((image_height * upscale_by) / 64) * 64
+    result = [{
+        'opt_type': 'sample',
+        'width': redraw_width,
+        'height': redraw_height,
+        'steps': steps,
+        'n_iter': 1,
+        'batch_size': batch_size,
+        'ratio': _sample_consumption_ratio(context, model)
+    }]
+    if enable_hr:
+        result.append({
+            'opt_type': 'hires_fix',
+            'width': hr_width,
+            'height': hr_height,
+        })
+    return {'opts': result}
+
+
+def _image_upscale_with_model_consumption(upscale_model, image):
+    return {
+        'opts': [{
+            'opt_type': 'upscale',
+            'width': image.shape[2] * upscale_model.scale,
+            'height': image.shape[1] * upscale_model.scale,
+            'batch_size': image.shape[0],
+        }]
+    }
+
+
+model_upscale_cache = {
+    '16xPSNR.pth': 16,
+    '1x_NMKD-BrightenRedux_200k.pth': 1,
+    '1x_NMKD-YandereInpaint_375000_G.pth': 1,
+    '1x_NMKDDetoon_97500_G.pth': 1,
+    '1x_NoiseToner-Poisson-Detailed_108000_G.pth': 1,
+    '1x_NoiseToner-Uniform-Detailed_100000_G.pth': 1,
+    '4x-AnimeSharp.pth': 4,
+    '4x-UltraSharp.pth': 4,
+    '4xPSNR.pth': 4,
+    '4x_CountryRoads_377000_G.pth': 4,
+    '4x_Fatality_Comix_260000_G.pth': 4,
+    '4x_NMKD-Siax_200k.pth': 4,
+    '4x_NMKD-Superscale-Artisoftject_210000_G.pth': 4,
+    '4x_NMKD-Superscale-SP_178000_G.pth': 4,
+    '4x_NMKD-UltraYandere-Lite_280k.pth': 4,
+    '4x_NMKD-UltraYandere_300k.pth': 4,
+    '4x_NMKD-YandereNeoXL_200k.pth': 4,
+    '4x_NMKDSuperscale_Artisoft_120000_G.pth': 4,
+    '4x_NickelbackFS_72000_G.pth': 4,
+    '4x_Nickelback_70000G.pth': 4,
+    '4x_RealisticRescaler_100000_G.pth': 4,
+    '4x_UniversalUpscalerV2-Neutral_115000_swaG.pth': 4,
+    '4x_UniversalUpscalerV2-Sharp_101000_G.pth': 4,
+    '4x_UniversalUpscalerV2-Sharper_103000_G.pth': 4,
+    '4x_Valar_v1.pth': 4,
+    '4x_fatal_Anime_500000_G.pth': 4,
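+    # models missing from this table are probed once at runtime by
+    # _get_upscale_model_size() below and cached; if probing fails the
+    # scale defaults to 4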
+    '4x_foolhardy_Remacri.pth': 4,
+    '4x_foolhardy_Remacri_ExtraSmoother.pth': 4,
+    '8xPSNR.pth': 8,
+    '8x_NMKD-Superscale_150000_G.pth': 8,
+    '8x_NMKD-Typescale_175k.pth': 8,
+    "A_ESRGAN_Single.pth": 4,
+    "BSRGAN.pth": 4,
+    'BSRGANx2.pth': 2,
+    "BSRNet.pth": 4,
+    'ESRGAN_4x.pth': 4,
+    "LADDIER1_282500_G.pth": 4,
+    'RealESRGAN_x4plus.pth': 4,
+    'RealESRGAN_x4plus_anime_6B.pth': 4,
+    'SwinIR_4x.pth': 4,
+    "WaifuGAN_v3_30000.pth": 4,
+    "lollypop.pth": 4,
+}
+
+
+def _get_upscale_model_size(context, model_name):
+    if model_name not in model_upscale_cache:
+        try:
+            import folder_paths
+            import comfy
+            from comfy_extras.chainner_models import model_loading
+            model_path = folder_paths.get_full_path(context, "upscale_models", model_name)
+            sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+            upscale_model = model_loading.load_state_dict(sd).eval()
+            model_upscale_cache[model_name] = upscale_model.scale
+            del upscale_model
+        except Exception:
+            model_upscale_cache[model_name] = 4
+    return model_upscale_cache[model_name]
+
+
+def _easy_hires_fix_consumption(
+        model_name, rescale_after_model, rescale_method, rescale, percent, width, height,
+        longer_side, crop, image_output, link_id, save_prefix, pipe=None, image=None, vae=None, prompt=None,
+        extra_pnginfo=None, my_unique_id=None, context: execution_context.ExecutionContext = None):
+    model_scale = _get_upscale_model_size(context, model_name)
+
+    if pipe is not None:
+        image = image if image is not None else pipe["images"]
+    if image is not None:
+        # keep 'opts' a list, matching every other consumption payload
+        return {
+            'opts': [{
+                'opt_type': 'hires_fix',
+                'width': image.shape[2] * model_scale,
+                'height': image.shape[1] * model_scale,
+                'batch_size': image.shape[0],
+            }]
+        }
+    else:
+        return {
+        }
+
+
+def _cr_upscale_image_consumption(image, upscale_model, rounding_modulus=8, loops=1, mode="rescale", supersample='true',
+                                  resampling_method="lanczos", rescale_factor=2, resize_width=1024,
+                                  context: execution_context.ExecutionContext = None):
+    model_scale = _get_upscale_model_size(context, upscale_model)
+    if image is not None:
+        return {
+            'opts': [{
+                'opt_type': 'upscale',
+                'width': image.shape[2] * model_scale,
+                'height': image.shape[1] * model_scale,
+                'batch_size': image.shape[0],
+            }]
+        }
+    else:
+        return {
+        }
+
+
+def _vhs_video_combine_consumption(
+        images,
+        frame_rate: int,
+        loop_count: int,
+        filename_prefix="AnimateDiff",
+        format="image/gif",
+        pingpong=False,
+        save_output=True,
+        prompt=None,
+        extra_pnginfo=None,
+        audio=None,
+        unique_id=None,
+        manual_format_widgets=None,
+        meta_batch=None,
+        context: execution_context.ExecutionContext = None,
+):
+    return {
+        'opts': [{
+            'opt_type': 'generate',
+            'width': images.shape[2],
+            'height': images.shape[1],
+            'batch_size': images.shape[0],
+        }]
+    }
+
+
+def _face_detailer_pipe_consumption(image, detailer_pipe, guide_size, guide_size_for, max_size, seed, steps, cfg,
+                                    sampler_name, scheduler,
+                                    denoise, feather, noise_mask, force_inpaint, bbox_threshold, bbox_dilation,
+                                    bbox_crop_factor,
+                                    sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion,
+                                    sam_mask_hint_threshold, sam_mask_hint_use_negative, drop_size, refiner_ratio=None,
+                                    cycle=1, inpaint_model=False, noise_mask_feather=0,
+                                    context: execution_context.ExecutionContext = None, ):
+    model, clip, vae, positive, negative, wildcard, bbox_detector, segm_detector_opt, sam_model_opt, detailer_hook, \
+        refiner_model, refiner_clip, refiner_positive, refiner_negative = detailer_pipe
+
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size =
image.shape[0] + opts = [{ + 'opt_type': 'face_detector', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }] + if sam_model_opt is not None: + opts.append({ + 'opt_type': 'sam', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + elif segm_detector_opt is not None: + opts.append({ + 'opt_type': 'sam', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + opts.append({ + 'opt_type': 'face_enhance', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + + return {'opts': opts} + + +def _face_detailer_consumption(image, model, clip, vae, guide_size, guide_size_for, max_size, seed, steps, cfg, + sampler_name, scheduler, + positive, negative, denoise, feather, noise_mask, force_inpaint, + bbox_threshold, bbox_dilation, bbox_crop_factor, + sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion, + sam_mask_hint_threshold, + sam_mask_hint_use_negative, drop_size, bbox_detector, wildcard, cycle=1, + sam_model_opt=None, segm_detector_opt=None, detailer_hook=None, inpaint_model=False, + noise_mask_feather=0, + context: execution_context.ExecutionContext = None): + image_width = image.shape[2] + image_height = image.shape[1] + batch_size = image.shape[0] + opts = [{ + 'opt_type': 'face_detector', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }] + if sam_model_opt is not None: + opts.append({ + 'opt_type': 'sam', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + elif segm_detector_opt is not None: + opts.append({ + 'opt_type': 'sam', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + opts.append({ + 'opt_type': 'face_enhance', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': 1, + 'batch_size': batch_size, + }) + + return {'opts': opts} + + +def _detailer_for_each_consumption(image, segs, model, clip, vae, guide_size, guide_size_for, max_size, seed, steps, + cfg, sampler_name, + scheduler, positive, negative, denoise, feather, noise_mask, force_inpaint, wildcard, + cycle=1, + detailer_hook=None, inpaint_model=False, noise_mask_feather=0, + context: execution_context.ExecutionContext = None): + image_width = image.shape[2] + image_height = image.shape[1] + batch_size = image.shape[0] + opts = [{ + 'opt_type': 'enhance_detail', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': cycle, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _detailer_for_each_pipe_consumption(image, segs, guide_size, guide_size_for, max_size, seed, steps, cfg, + sampler_name, scheduler, + denoise, feather, noise_mask, force_inpaint, basic_pipe, wildcard, + refiner_ratio=None, detailer_hook=None, refiner_basic_pipe_opt=None, + cycle=1, inpaint_model=False, noise_mask_feather=0, + context: execution_context.ExecutionContext = None): + image_width = image.shape[2] + image_height = image.shape[1] + batch_size = image.shape[0] + opts = [{ + 'opt_type': 'enhance_detail', + 'width': image_width, + 'height': image_height, + 'steps': steps, + 'n_iter': cycle, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _impact_simple_detector_segs_for_ad_consumption(bbox_detector, image_frames, 
bbox_threshold, bbox_dilation,
+                                                    crop_factor, drop_size,
+                                                    sub_threshold, sub_dilation, sub_bbox_expansion,
+                                                    sam_mask_hint_threshold,
+                                                    masking_mode="Pivot SEGS", segs_pivot="Combined mask",
+                                                    sam_model_opt=None, segm_detector_opt=None):
+    image_width = image_frames.shape[2]
+    image_height = image_frames.shape[1]
+    batch_size = image_frames.shape[0]
+    opts = [{
+        'opt_type': 'detect_box',
+        'width': image_width,
+        'height': image_height,
+        'batch_size': batch_size,
+    }]
+    if sam_model_opt is not None:
+        opts.append({
+            'opt_type': 'sam',
+            'width': image_width,
+            'height': image_height,
+            'batch_size': batch_size,
+        })
+    elif segm_detector_opt is not None:
+        opts.append({
+            'opt_type': 'detect_seg',
+            'width': image_width,
+            'height': image_height,
+            'batch_size': batch_size,
+        })
+    return {'opts': opts}
+
+
+def _re_actor_build_face_model_consumption(image, det_size=(640, 640)):
+    return {
+        'opts': [{
+            'opt_type': 'build_face_model',
+            'width': det_size[0],
+            'height': det_size[1],
+            'steps': 1,
+            'n_iter': 1,
+            'batch_size': 1
+        }]
+    }
+
+
+def _supir_decode_consumption(SUPIR_VAE, latents, use_tiled_vae, decoder_tile_size):
+    # derive pixel dimensions directly from the latent (8 pixels per latent cell)
+    latent_size = latents["samples"].size()
+    opt = {
+        'opt_type': 'supir_decode',
+        'width': latent_size[3] * 8,
+        'height': latent_size[2] * 8,
+        'n_iter': 1,
+        'batch_size': latent_size[0],
+    }
+    return {'opts': [opt, ]}
+
+
+def _supir_encode_consumption(SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size):
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    opts = [{
+        'opt_type': 'supir_encode',
+        'width': image_width,
+        'height': image_height,
+        'batch_size': batch_size,
+    }]
+    return {'opts': opts}
+
+
+def _supir_sample_consumption(SUPIR_model, latents, steps, seed, cfg_scale_end, EDM_s_churn, s_noise, positive,
+                              negative,
+                              cfg_scale_start, control_scale_start, control_scale_end, restore_cfg, keep_model_loaded,
+                              DPMPP_eta,
+                              sampler, sampler_tile_size=1024, sampler_tile_stride=512):
+    # derive pixel dimensions directly from the latent (8 pixels per latent cell)
+    latent_size = latents["samples"].size()
+    return {
+        'opts': [{
+            'opt_type': 'ksampler',
+            'width': latent_size[3] * 8,
+            'height': latent_size[2] * 8,
+            'steps': steps,
+            'n_iter': 1,
+            'batch_size': latent_size[0],
+        }]
+    }
+
+
+def _supir_first_stage_consumption(SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size,
+                                   decoder_tile_size):
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    opts = [{
+        'opt_type': 'supir_first_stage',
+        'width': image_width,
+        'height': image_height,
+        'batch_size': batch_size,
+    }]
+    return {'opts': opts}
+
+
+def _bbox_detector_segs_consumption(bbox_detector, image, threshold, dilation, crop_factor, drop_size, labels=None,
+                                    detailer_hook=None):
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    opts = [{
+        'opt_type': 'bbox_detector',
+        'width': image_width,
+        'height': image_height,
+        'batch_size': batch_size,
+    }]
+    return {'opts': opts}
+
+
+def _segm_detector_for_each_consumption(segm_detector, image, threshold, dilation, crop_factor, drop_size, labels=None,
+                                        detailer_hook=None):
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    opts = [{
+        'opt_type': 'segm_detector',
+        'width': image_width,
+        'height': image_height,
+        'batch_size': batch_size,
+    }]
+    return {'opts': opts}
+
+
+def _impact_simple_detector_segs_consumption(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor,
+                                             drop_size,
+                                             sub_threshold, sub_dilation, sub_bbox_expansion,
+                                             sam_mask_hint_threshold, post_dilation=0, sam_model_opt=None,
+                                             segm_detector_opt=None):
+    image_width = image.shape[2]
+    image_height = image.shape[1]
+    batch_size = image.shape[0]
+    opts = [{
+        'opt_type': 'detector_segs',
+        'width':
image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _clip_seg_masking_consumption(image, text=None, clipseg_model=None): + image_width = image.shape[2] + image_height = image.shape[1] + batch_size = image.shape[0] + opts = [{ + 'opt_type': 'clip_seg_masking', + 'width': image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _layermask_person_mask_ultra_consumption(images, face, hair, body, clothes, + accessories, background, confidence, + detail_range, black_point, white_point, process_detail): + image_width = images.shape[2] + image_height = images.shape[1] + batch_size = images.shape[0] + opts = [{ + 'opt_type': 'layermask_person_mask_ultra', + 'width': image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _slice_dict(d, i): + d_new = dict() + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + +def _map_node_consumption_over_list(obj, input_data_all, func): + # check if node wants the lists + input_is_list = getattr(obj, "INPUT_IS_LIST", False) + + if len(input_data_all) == 0: + max_len_input = 0 + else: + max_len_input = max([len(x) for x in input_data_all.values()]) + + if input_is_list: + return func(**input_data_all) + elif max_len_input == 0: + return func() + else: + results = [] + for i in range(max_len_input): + results.append(func(**_slice_dict(input_data_all, i))) + if len(results) == 1: + return results[0] + else: + return results + + +def _ultimate_sd_upscale_no_upscale_consumption(upscaled_image, model, positive, negative, vae, seed, + steps, cfg, sampler_name, scheduler, denoise, + mode_type, tile_width, tile_height, mask_blur, tile_padding, + seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur, + seam_fix_width, seam_fix_padding, force_uniform_tiles, tiled_decode, + context: execution_context.ExecutionContext): + return { + 'opts': [{ + 'opt_type': 'upscale', + 'width': upscaled_image.shape[2], + 'height': upscaled_image.shape[1], + 'steps': steps, + 'n_iter': 1, + 'batch_size': upscaled_image.shape[0], + 'ratio': _sample_consumption_ratio(context, model) + }] + } + + +def image_rembg_consumption( + images, + transparency=True, + model="u2net", + alpha_matting=False, + alpha_matting_foreground_threshold=240, + alpha_matting_background_threshold=10, + alpha_matting_erode_size=10, + post_processing=False, + only_mask=False, + background_color="none", +): + image_width = images.shape[2] + image_height = images.shape[1] + batch_size = images.shape[0] + opts = [{ + 'opt_type': 'rembg', + 'width': image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def _sam_detector_combined_consumption( + sam_model, segs, image, detection_hint, dilation, + threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative): + image_width = image.shape[2] + image_height = image.shape[1] + batch_size = image.shape[0] + opts = [{ + 'opt_type': 'segm_detector_combined', + 'width': image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def was_remove_background_consumption(images, mode='background', threshold=127, threshold_tolerance=2): + image_width = images.shape[2] + image_height = images.shape[1] + batch_size = images.shape[0] + opts = [{ + 'opt_type': 'was_remove_background', + 'width': image_width, + 'height': image_height, + 'batch_size': batch_size, + }] + return {'opts': opts} + + +def 
_default_consumption_maker(*args, **kwargs): + return {} + + +def _none_consumption_maker(*args, **kwargs): + return None + + +_NODE_CONSUMPTION_MAPPING = { + 'KSampler': _k_sampler_consumption, + 'KSamplerAdvanced': _k_sampler_advanced_consumption, + 'KSampler (Efficient)': _tsc_k_sampler_consumption, + 'KSampler Adv. (Efficient)': _tsc_ksampler_advanced_consumption, + 'KSampler SDXL (Eff.)': _tsc_ksampler_sdxl_consumption, + 'XlabsSampler': _xlabs_sampler_consumption, + 'ImpactKSamplerBasicPipe': _impact_k_sampler_basic_pipe_consumption, + 'ReActorRestoreFace': _reactor_restore_face_consumption, + 'ReActorFaceSwap': _reactor_face_swap_consumption, + 'ReActorFaceSwapOpt': _reactor_face_swap_opt_consumption, + 'ReActorBuildFaceModel': _re_actor_build_face_model_consumption, + 'ImageUpscaleWithModel': _image_upscale_with_model_consumption, + 'UltimateSDUpscale': _ultimate_sd_upscale_consumption, + 'easy hiresFix': _easy_hires_fix_consumption, + 'VHS_VideoCombine': _vhs_video_combine_consumption, + 'FaceDetailer': _face_detailer_consumption, + 'FaceDetailerPipe': _face_detailer_pipe_consumption, + 'SamplerCustom': _sampler_custom_consumption, + 'SamplerCustomAdvanced': _sampler_custom_advanced_consumption, + 'SeargeSDXLImage2ImageSampler2': _searge_sdxl_image2image_sampler2_consumption, + 'BNK_TiledKSamplerAdvanced': _tiled_k_sampler_advanced_consumption, + 'BNK_TiledKSampler': _tiled_k_sampler_consumption, + 'easy fullkSampler': _easy_full_k_sampler_consumption, + "CogVideoSampler": _cog_video_sampler_consumption, + "MochiSampler": _mochi_sampler_consumption, + + 'UltimateSDUpscaleNoUpscale': _ultimate_sd_upscale_no_upscale_consumption, + 'CR Upscale Image': _cr_upscale_image_consumption, + 'KSampler //Inspire': _k_sampler_inspire_consumption, + 'KSampler Cycle': _was_k_sampler_cycle_consumption, + 'ImpactSimpleDetectorSEGS_for_AD': _impact_simple_detector_segs_for_ad_consumption, + 'DetailerForEach': _detailer_for_each_consumption, + 'DetailerForEachPipe': _detailer_for_each_pipe_consumption, + 'SUPIR_decode': _supir_decode_consumption, + 'SUPIR_encode': _supir_encode_consumption, + 'SUPIR_sample': _supir_sample_consumption, + 'SUPIR_first_stage': _supir_first_stage_consumption, + 'BboxDetectorSEGS': _bbox_detector_segs_consumption, + 'ONNXDetectorSEGS': _bbox_detector_segs_consumption, + 'SegmDetectorSEGS': _segm_detector_for_each_consumption, + 'ImpactSimpleDetectorSEGS': _impact_simple_detector_segs_consumption, + 'CLIPSeg Masking': _clip_seg_masking_consumption, + 'LayerMask: PersonMaskUltra': _layermask_person_mask_ultra_consumption, + 'Image Rembg (Remove Background)': image_rembg_consumption, + 'SAMDetectorCombined': _sam_detector_combined_consumption, + 'Image Remove Background (Alpha)': was_remove_background_consumption, + + 'ADE_UseEvolvedSampling': _none_consumption_maker, + 'ModelSamplingSD3': _none_consumption_maker, + 'HighRes-Fix Script': _none_consumption_maker, + 'ImageBatch': _none_consumption_maker, + 'ControlNetApplyAdvanced': _none_consumption_maker, + 'ReActorLoadFaceModel': _none_consumption_maker, + 'SVD_img2vid_Conditioning': _none_consumption_maker, + 'VideoLinearCFGGuidance': _none_consumption_maker, + 'ConstrainImage|pysssss': _none_consumption_maker, + 'MiDaS-DepthMapPreprocessor': _none_consumption_maker, + 'VHS_BatchManager': _none_consumption_maker, + 'easy ultralyticsDetectorPipe': _none_consumption_maker, + 'ColorPreprocessor': _none_consumption_maker, + 'DWPreprocessor': _none_consumption_maker, + 'FreeU_V2': _none_consumption_maker, + 'ImageInvert': 
_none_consumption_maker, + 'XY Plot': _none_consumption_maker, + 'ApplyInstantID': _none_consumption_maker, + 'CR Apply Multi-ControlNet': _none_consumption_maker, + 'CR Multi-ControlNet Stack': _none_consumption_maker, + 'AnimeLineArtPreprocessor': _none_consumption_maker, + 'InstantIDFaceAnalysis': _none_consumption_maker, + 'IPAdapterAdvanced': _none_consumption_maker, + 'IPAdapter': _none_consumption_maker, + 'RepeatLatentBatch': _none_consumption_maker, + 'OpenposePreprocessor': _none_consumption_maker, + 'ADE_AnimateDiffSamplingSettings': _none_consumption_maker, + 'ADE_StandardStaticContextOptions': _none_consumption_maker, + 'SaveImage': _none_consumption_maker, + 'VAEDecode': _none_consumption_maker, + 'CLIPTextEncode': _none_consumption_maker, + 'LoraLoader': _none_consumption_maker, + 'CheckpointLoaderSimple': _none_consumption_maker, + 'VAEEncode': _none_consumption_maker, + 'Image Resize': _none_consumption_maker, + 'EmptyLatentImage': _none_consumption_maker, + 'ImageScale': _none_consumption_maker, + 'CLIPSetLastLayer': _none_consumption_maker, + 'LoadImage': _none_consumption_maker, + 'easy promptReplace': _none_consumption_maker, + 'Text Multiline': _none_consumption_maker, + 'VAELoader': _none_consumption_maker, + 'ConditioningSetTimestepRange': _none_consumption_maker, + 'UpscaleModelLoader': _none_consumption_maker, + 'LoraLoader|pysssss': _none_consumption_maker, + 'RebatchLatents': _none_consumption_maker, + 'LatentBatchSeedBehavior': _none_consumption_maker, + 'Efficient Loader': _none_consumption_maker, + 'SDXLPromptStyler': _none_consumption_maker, + 'ConditioningCombine': _none_consumption_maker, + 'ConditioningZeroOut': _none_consumption_maker, + 'TripleCLIPLoader': _none_consumption_maker, + 'LatentUpscale': _none_consumption_maker, + 'easy stylesSelector': _none_consumption_maker, + 'ComfyUIStyler': _none_consumption_maker, + 'CLIPTextEncodeSDXLRefiner': _none_consumption_maker, + 'easy ipadapterApply': _none_consumption_maker, + 'ArtistStyler': _none_consumption_maker, + 'FantasyStyler': _none_consumption_maker, + 'ADE_AnimateDiffLoaderGen1': _none_consumption_maker, + 'AestheticStyler': _none_consumption_maker, + 'ControlNetLoader': _none_consumption_maker, + 'SaveAnimatedWEBP': _none_consumption_maker, + 'CLIPTextEncodeSDXL': _none_consumption_maker, + 'ImageScaleBy': _none_consumption_maker, + 'ImageOnlyCheckpointLoader': _none_consumption_maker, + 'IPAdapterUnifiedLoader': _none_consumption_maker, + 'EnvironmentStyler': _none_consumption_maker, + 'MilehighStyler': _none_consumption_maker, + 'AnimeStyler': _none_consumption_maker, + 'ADE_AnimateDiffLoaderWithContext': _none_consumption_maker, + 'VHS_LoadVideo': _none_consumption_maker, + 'MoodStyler': _none_consumption_maker, + 'ReActorSaveFaceModel': _none_consumption_maker, + 'Camera_AnglesStyler': _none_consumption_maker, + 'TimeofdayStyler': _none_consumption_maker, + 'FaceStyler': _none_consumption_maker, + 'Breast_StateStyler': _none_consumption_maker, + 'easy seed': _none_consumption_maker, + 'EmptySD3LatentImage': _none_consumption_maker, + 'UltralyticsDetectorProvider': _none_consumption_maker, + 'CR Apply LoRA Stack': _none_consumption_maker, + 'Upscale Model Loader': _none_consumption_maker, + 'PortraitMaster': _none_consumption_maker, + 'PlaySound|pysssss': _none_consumption_maker, + 'WD14Tagger|pysssss': _none_consumption_maker, + 'SAMLoader': _none_consumption_maker, + 'ADE_AnimateDiffUniformContextOptions': _none_consumption_maker, + 'ToBasicPipe': _none_consumption_maker, + 'easy pipeOut': 
_none_consumption_maker, + 'easy pipeIn': _none_consumption_maker, + 'CR LoRA Stack': _none_consumption_maker, + 'InstantIDModelLoader': _none_consumption_maker, + 'LatentUpscaleBy': _none_consumption_maker, + 'ToDetailerPipe': _none_consumption_maker, + 'easy ipadapterStyleComposition': _none_consumption_maker, + 'Canny': _none_consumption_maker, + 'BaseModel_Loader_local': _none_consumption_maker, + 'CR Load LoRA': _none_consumption_maker, + 'SAM Model Loader': _none_consumption_maker, + 'CLIPLoader': _none_consumption_maker, + 'VHS_LoadImages': _none_consumption_maker, + 'easy fullLoader': _none_consumption_maker, + 'XY Input: Checkpoint': _none_consumption_maker, + 'MaskToImage': _none_consumption_maker, + 'CR Text Concatenate': _none_consumption_maker, + 'CR Text': _none_consumption_maker, + 'easy loadImageBase64': _none_consumption_maker, + 'easy clearCacheAll': _none_consumption_maker, + 'ShowText|pysssss': _none_consumption_maker, + 'ADE_AnimateDiffLoRALoader': _none_consumption_maker, + 'easy showTensorShape': _none_consumption_maker, + 'ConditioningConcat': _none_consumption_maker, + 'ConditioningAverage': _none_consumption_maker, + 'KSamplerSelect': _none_consumption_maker, + 'AlignYourStepsScheduler': _none_consumption_maker, + 'BasicScheduler': _none_consumption_maker, + 'ModelMergeSimple': _none_consumption_maker, + 'CLIPMergeSimple': _none_consumption_maker, + 'ControlNetApply': _none_consumption_maker, + 'ADE_LoadAnimateDiffModel': _none_consumption_maker, + 'LatentSwitch': _none_consumption_maker, + 'ADE_ApplyAnimateDiffModelSimple': _none_consumption_maker, + 'IPAdapterModelLoader': _none_consumption_maker, + 'ImageCrop': _none_consumption_maker, + 'CLIPSegDetectorProvider': _none_consumption_maker, + 'PreviewImage': _none_consumption_maker, + 'Eff. 
Loader SDXL': _none_consumption_maker, + 'Unpack SDXL Tuple': _none_consumption_maker, + 'Automatic CFG': _none_consumption_maker, + 'easy textSwitch': _none_consumption_maker, + 'ApplyInstantIDAdvanced': _none_consumption_maker, + 'easy a1111Loader': _none_consumption_maker, + 'Text to Conditioning': _none_consumption_maker, + 'SeargeSamplerInputs': _none_consumption_maker, + 'LineArtPreprocessor': _none_consumption_maker, + 'IPAdapterFaceID': _none_consumption_maker, + 'unCLIPCheckpointLoader': _none_consumption_maker, + 'UNETLoader': _none_consumption_maker, + 'ControlNetLoaderAdvanced': _none_consumption_maker, + 'EmptyImage': _none_consumption_maker, + 'ACN_AdvancedControlNetApply': _none_consumption_maker, + 'LoraLoaderModelOnly': _none_consumption_maker, + 'LoadAnimateDiffModelNode': _none_consumption_maker, + 'ADE_AnimateDiffKeyframe': _none_consumption_maker, + 'DiffControlNetLoader': _none_consumption_maker, + 'DiffControlNetLoaderAdvanced': _none_consumption_maker, + 'SDTurboScheduler': _none_consumption_maker, + 'Image Crop Face': _none_consumption_maker, + 'IPAdapterTiled': _none_consumption_maker, + 'SeargeInput1': _none_consumption_maker, + 'SeargeInput2': _none_consumption_maker, + 'SeargeInput3': _none_consumption_maker, + 'SeargeInput4': _none_consumption_maker, + 'SeargeInput5': _none_consumption_maker, + 'SeargeInput6': _none_consumption_maker, + 'SeargeInput7': _none_consumption_maker, + 'SeargeOutput1': _none_consumption_maker, + 'SeargeOutput2': _none_consumption_maker, + 'SeargeOutput3': _none_consumption_maker, + 'SeargeOutput4': _none_consumption_maker, + 'SeargeOutput5': _none_consumption_maker, + 'SeargeOutput6': _none_consumption_maker, + 'SeargeOutput7': _none_consumption_maker, + 'SeargeGenerated1': _none_consumption_maker, + 'SeargeVAELoader': _none_consumption_maker, + 'TilePreprocessor': _none_consumption_maker, + 'TTPlanet_TileGF_Preprocessor': _none_consumption_maker, + 'TTPlanet_TileSimple_Preprocessor': _none_consumption_maker, + 'IPAdapterNoise': _none_consumption_maker, + 'SetLatentNoiseMask': _none_consumption_maker, + 'easy loraStack': _none_consumption_maker, + 'ScaledSoftControlNetWeights': _none_consumption_maker, + 'ScaledSoftMaskedUniversalWeights': _none_consumption_maker, + 'SoftControlNetWeights': _none_consumption_maker, + 'CustomControlNetWeights': _none_consumption_maker, + 'SoftT2IAdapterWeights': _none_consumption_maker, + 'CustomT2IAdapterWeights': _none_consumption_maker, + 'ACN_DefaultUniversalWeights': _none_consumption_maker, + 'ACN_ReferencePreprocessor': _none_consumption_maker, + 'ACN_ReferenceControlNet': _none_consumption_maker, + 'ACN_ReferenceControlNetFnetune': _none_consumption_maker, + 'easy controlnetStack': _none_consumption_maker, + 'Control Net Stacker': _none_consumption_maker, + 'easy globalSeed': _none_consumption_maker, + "easy positive": _none_consumption_maker, + "easy negative": _none_consumption_maker, + "easy wildcards": _none_consumption_maker, + "easy prompt": _none_consumption_maker, + "easy promptList": _none_consumption_maker, + "easy promptLine": _none_consumption_maker, + "easy promptConcat": _none_consumption_maker, + "easy portraitMaster": _none_consumption_maker, + 'AIO_Preprocessor': _none_consumption_maker, + 'ImageResizeKJ': _none_consumption_maker, + 'ImagePadForOutpaint': _none_consumption_maker, + 'CannyEdgePreprocessor': _none_consumption_maker, + 'DepthAnythingPreprocessor': _none_consumption_maker, + 'Zoe_DepthAnythingPreprocessor': _none_consumption_maker, + 'IPAdapterInsightFaceLoader': 
_none_consumption_maker, + 'ReActorMaskHelper': _none_consumption_maker, + 'easy imageColorMatch': _none_consumption_maker, + 'Checkpoint Selector': _none_consumption_maker, + 'Save Image w/Metadata': _none_consumption_maker, + 'Sampler Selector': _none_consumption_maker, + 'Scheduler Selector': _none_consumption_maker, + 'Seed Generator': _none_consumption_maker, + 'String Literal': _none_consumption_maker, + 'Width/Height Literal': _none_consumption_maker, + 'Cfg Literal': _none_consumption_maker, + 'Int Literal': _none_consumption_maker, + 'ImpactImageBatchToImageList': _none_consumption_maker, + 'ImpactMakeImageList': _none_consumption_maker, + 'ImpactMakeImageBatch': _none_consumption_maker, + 'PhotoMakerLoader': _none_consumption_maker, + 'PhotoMakerEncode': _none_consumption_maker, + 'ImpactSEGSToMaskList': _none_consumption_maker, + 'GroundingDinoModelLoader (segment anything)': _none_consumption_maker, + 'VAEEncodeForInpaint': _none_consumption_maker, + 'ImpactWildcardProcessor': _none_consumption_maker, + 'ImpactWildcardEncode': _none_consumption_maker, + 'StringFunction|pysssss': _none_consumption_maker, + 'SAMModelLoader (segment anything)': _none_consumption_maker, + "INTConstant": _none_consumption_maker, + "FloatConstant": _none_consumption_maker, + "StringConstant": _none_consumption_maker, + "StringConstantMultiline": _none_consumption_maker, + "CR Image Input Switch": _none_consumption_maker, + "CR Image Input Switch (4 way)": _none_consumption_maker, + "CR Latent Input Switch": _none_consumption_maker, + "CR Conditioning Input Switch": _none_consumption_maker, + "CR Clip Input Switch": _none_consumption_maker, + "CR Model Input Switch": _none_consumption_maker, + "CR ControlNet Input Switch": _none_consumption_maker, + "CR VAE Input Switch": _none_consumption_maker, + "CR Text Input Switch": _none_consumption_maker, + "CR Text Input Switch (4 way)": _none_consumption_maker, + "CR Switch Model and CLIP": _none_consumption_maker, + 'SeargeFloatConstant': _none_consumption_maker, + "SeargeFloatPair": _none_consumption_maker, + "SeargeFloatMath": _none_consumption_maker, + 'easy showAnything': _none_consumption_maker, + 'CLIPVisionEncode': _none_consumption_maker, + 'unCLIPConditioning': _none_consumption_maker, + 'SeargeDebugPrinter': _none_consumption_maker, + 'LayerUtility: SaveImagePlus': _none_consumption_maker, + 'DetailerForEachDebug': _none_consumption_maker, + 'GlobalSeed //Inspire': _none_consumption_maker, + 'VAEDecodeTiled': _none_consumption_maker, + 'VAEEncodeTiled': _none_consumption_maker, + 'DualCLIPLoader': _none_consumption_maker, + 'RandomNoise': _none_consumption_maker, + 'BasicGuider': _none_consumption_maker, + 'SDXLPromptStylerAdvanced': _none_consumption_maker, + 'PerturbedAttentionGuidance': _none_consumption_maker, + 'Anything Everywhere3': _none_consumption_maker, + "CR SD1.5 Aspect Ratio": _none_consumption_maker, + "CR SDXL Aspect Ratio": _none_consumption_maker, + "CR Aspect Ratio": _none_consumption_maker, + "CR Aspect Ratio Banners": _none_consumption_maker, + "CR Aspect Ratio Social Media": _none_consumption_maker, + "CR_Aspect Ratio For Print": _none_consumption_maker, + 'Text Concatenate': _none_consumption_maker, + 'ImageBatchMulti': _none_consumption_maker, + 'MeshGraphormer-DepthMapPreprocessor': _none_consumption_maker, + 'MeshGraphormer+ImpactDetector-DepthMapPreprocessor': _none_consumption_maker, + 'LeReS-DepthMapPreprocessor': _none_consumption_maker, + 'IPAdapterUnifiedLoaderFaceID': _none_consumption_maker, + "ImageBlend": 
_none_consumption_maker, + "ImageBlur": _none_consumption_maker, + "ImageQuantize": _none_consumption_maker, + "ImageSharpen": _none_consumption_maker, + "ImageScaleToTotalPixels": _none_consumption_maker, + 'SD_4XUpscale_Conditioning': _none_consumption_maker, + 'Latent Input Switch': _none_consumption_maker, + 'FluxGuidance': _none_consumption_maker, + 'VAE Input Switch': _none_consumption_maker, + "Logic Comparison OR": _none_consumption_maker, + "Logic Comparison AND": _none_consumption_maker, + "Logic Comparison XOR": _none_consumption_maker, + 'ImpactControlNetApplySEGS': _none_consumption_maker, + 'DWPreprocessor_Provider_for_SEGS //Inspire': _none_consumption_maker, + 'CLIPVisionLoader': _none_consumption_maker, + 'SUPIR_conditioner': _none_consumption_maker, + 'ImageResize+': _none_consumption_maker, + 'SUPIR_model_loader_v2': _none_consumption_maker, + "SUPIR_Upscale": _none_consumption_maker, + "SUPIR_model_loader": _none_consumption_maker, + "SUPIR_tiles": _none_consumption_maker, + "SUPIR_model_loader_v2_clip": _none_consumption_maker, + 'ColorMatch': _none_consumption_maker, + 'EmptySegs': _none_consumption_maker, + 'VHS_VideoInfo': _none_consumption_maker, + 'ModelSamplingFlux': _none_consumption_maker, + 'SplitSigmas': _none_consumption_maker, + 'SegsToCombinedMask': _none_consumption_maker, + 'VHS_DuplicateImages': _none_consumption_maker, + 'CLIPTextEncodeFlux': _none_consumption_maker, + 'GetImageSize+': _none_consumption_maker, + 'MathExpression|pysssss': _none_consumption_maker, + 'SeargeIntegerConstant': _none_consumption_maker, + "SeargeIntegerPair": _none_consumption_maker, + "SeargeIntegerMath": _none_consumption_maker, + "SeargeIntegerScaler": _none_consumption_maker, + 'Seed': _none_consumption_maker, + 'FeatherMask': _none_consumption_maker, + 'CheckpointLoader|pysssss': _none_consumption_maker, + 'ImageColorMatch+': _none_consumption_maker, + 'PreviewDetailerHookProvider': _none_consumption_maker, + 'Image Analyze': _none_consumption_maker, + 'easy comfyLoader': _none_consumption_maker, + 'easy controlnetLoaderADV': _none_consumption_maker, + 'LoRALoader': _none_consumption_maker, + 'LoadCLIPSegModels+': _none_consumption_maker, + 'ApplyCLIPSeg+': _none_consumption_maker, + 'LayerMask: MaskPreview': _none_consumption_maker, + 'CLIPSeg Model Loader': _none_consumption_maker, + 'Text to Console': _none_consumption_maker, + 'FromBasicPipe': _none_consumption_maker, + 'Lora Loader': _none_consumption_maker, + 'SDXLEmptyLatentSizePicker+': _none_consumption_maker, + 'LayerColor: AutoAdjust': _none_consumption_maker, + 'LayerUtility: PurgeVRAM': _none_consumption_maker, + 'LoRA Stacker': _none_consumption_maker, + 'Text Parse A1111 Embeddings': _none_consumption_maker, + 'easy showSpentTime': _none_consumption_maker, + 'Latent Upscale by Factor (WAS)': _none_consumption_maker, + 'ImpactControlBridge': _none_consumption_maker, + 'LoRA Stack to String converter': _none_consumption_maker, + 'LatentFromBatch': _none_consumption_maker, + 'AnimateDiffModuleLoader': _none_consumption_maker, + 'Image Size to Number': _none_consumption_maker, + 'Constant Number': _none_consumption_maker, + 'Number Operation': _none_consumption_maker, + 'Image Bounds': _none_consumption_maker, + 'Inset Image Bounds': _none_consumption_maker, + 'Bounded Image Crop': _none_consumption_maker, + 'PreviewBridge': _none_consumption_maker, + 'ImageToMask': _none_consumption_maker, + 'ToBinaryMask': _none_consumption_maker, + 'Mask Smooth Region': _none_consumption_maker, + 'Mask Erode Region': 
_none_consumption_maker, + 'Mask Dilate Region': _none_consumption_maker, + 'MaskToSEGS': _none_consumption_maker, + 'ImpactSEGSOrderedFilter': _none_consumption_maker, + 'Image Blank': _none_consumption_maker, + 'Image Blend by Mask': _none_consumption_maker, + 'SubtractMask': _none_consumption_maker, + 'InvertMask': _none_consumption_maker, + 'ImpactDilateMask': _none_consumption_maker, + 'BitwiseAndMask': _none_consumption_maker, + 'Mask Crop Region': _none_consumption_maker, + 'Image Crop Location': _none_consumption_maker, + 'Image Select Channel': _none_consumption_maker, + 'Image Levels Adjustment': _none_consumption_maker, + 'Image Mix RGB Channels': _none_consumption_maker, + 'Image Filter Adjustments': _none_consumption_maker, + 'Image to Noise': _none_consumption_maker, + 'AnimateDiffSlidingWindowOptions': _none_consumption_maker, + 'CR Seed': _none_consumption_maker, + 'LayerFilter: SoftLight': _none_consumption_maker, + 'LayerColor: Color of Shadow & Highlight': _none_consumption_maker, + 'easy imageSize': _none_consumption_maker, + 'easy imageScaleDownBy': _none_consumption_maker, + 'CFGGuider': _none_consumption_maker, + 'SolidMask': _none_consumption_maker, + 'MaskComposite': _none_consumption_maker, + 'ConditioningSetMask': _none_consumption_maker, + 'Image Blending Mode': _none_consumption_maker, + 'MaskBlur+': _none_consumption_maker, + 'SamplerDPMPP_3M_SDE': _none_consumption_maker, + 'KarrasScheduler': _none_consumption_maker, + 'SamplerEulerCFGpp': _none_consumption_maker, + 'ADE_AnimateDiffCombine': _none_consumption_maker, + 'comfy.Seed (rgthree)': _none_consumption_maker, + + "CogVideoDecode": _none_consumption_maker, + "CogVideoTextEncode": _none_consumption_maker, + "CogVideoImageEncode": _none_consumption_maker, + "CogVideoTextEncodeCombine": _none_consumption_maker, + "CogVideoTransformerEdit": _none_consumption_maker, + "CogVideoContextOptions": _none_consumption_maker, + "CogVideoControlNet": _none_consumption_maker, + "ToraEncodeTrajectory": _none_consumption_maker, + "ToraEncodeOpticalFlow": _none_consumption_maker, + "CogVideoXFasterCache": _none_consumption_maker, + "CogVideoXFunResizeToClosestBucket": _none_consumption_maker, + "CogVideoLatentPreview": _none_consumption_maker, + "CogVideoXTorchCompileSettings": _none_consumption_maker, + "CogVideoImageEncodeFunInP": _none_consumption_maker, + + "DownloadAndLoadMochiModel": _none_consumption_maker, + "MochiDecode": _none_consumption_maker, + "MochiTextEncode": _none_consumption_maker, + "MochiModelLoader": _none_consumption_maker, + "MochiVAELoader": _none_consumption_maker, + "MochiVAEEncoderLoader": _none_consumption_maker, + "MochiDecodeSpatialTiling": _none_consumption_maker, + "MochiTorchCompileSettings": _none_consumption_maker, + "MochiImageEncode": _none_consumption_maker, + "MochiLatentPreview": _none_consumption_maker, + "MochiSigmaSchedule": _none_consumption_maker, + "MochiFasterCache": _none_consumption_maker, + + "Lora Loader (JPS)": _none_consumption_maker, + "SDXL Resolutions (JPS)": _none_consumption_maker, + "SDXL Basic Settings (JPS)": _none_consumption_maker, + "SDXL Settings (JPS)": _none_consumption_maker, + "Generation TXT IMG Settings (JPS)": _none_consumption_maker, + "Crop Image Settings (JPS)": _none_consumption_maker, + "ImageToImage Settings (JPS)": _none_consumption_maker, + "CtrlNet CannyEdge Settings (JPS)": _none_consumption_maker, + "CtrlNet ZoeDepth Settings (JPS)": _none_consumption_maker, + "CtrlNet MiDaS Settings (JPS)": _none_consumption_maker, + "CtrlNet OpenPose 
Settings (JPS)": _none_consumption_maker, + "Revision Settings (JPS)": _none_consumption_maker, + "IP Adapter Settings (JPS)": _none_consumption_maker, + "IP Adapter Tiled Settings (JPS)": _none_consumption_maker, + "InstantID Settings (JPS)": _none_consumption_maker, + "Image Prepare Settings (JPS)": _none_consumption_maker, + "InstantID Source Prepare Settings (JPS)": _none_consumption_maker, + "InstantID Pose Prepare Settings (JPS)": _none_consumption_maker, + "InstantID Mask Prepare Settings (JPS)": _none_consumption_maker, + "Sampler Scheduler Settings (JPS)": _none_consumption_maker, + "Integer Switch (JPS)": _none_consumption_maker, + "Image Switch (JPS)": _none_consumption_maker, + "Latent Switch (JPS)": _none_consumption_maker, + "Conditioning Switch (JPS)": _none_consumption_maker, + "Model Switch (JPS)": _none_consumption_maker, + "IPA Switch (JPS)": _none_consumption_maker, + "VAE Switch (JPS)": _none_consumption_maker, + "Mask Switch (JPS)": _none_consumption_maker, + "ControlNet Switch (JPS)": _none_consumption_maker, + "Disable Enable Switch (JPS)": _none_consumption_maker, + "Enable Disable Switch (JPS)": _none_consumption_maker, + "SDXL Basic Settings Pipe (JPS)": _none_consumption_maker, + "SDXL Settings Pipe (JPS)": _none_consumption_maker, + "Crop Image Pipe (JPS)": _none_consumption_maker, + "ImageToImage Pipe (JPS)": _none_consumption_maker, + "CtrlNet CannyEdge Pipe (JPS)": _none_consumption_maker, + "CtrlNet ZoeDepth Pipe (JPS)": _none_consumption_maker, + "CtrlNet MiDaS Pipe (JPS)": _none_consumption_maker, + "CtrlNet OpenPose Pipe (JPS)": _none_consumption_maker, + "IP Adapter Settings Pipe (JPS)": _none_consumption_maker, + "IP Adapter Tiled Settings Pipe (JPS)": _none_consumption_maker, + "InstantID Pipe (JPS)": _none_consumption_maker, + "Image Prepare Pipe (JPS)": _none_consumption_maker, + "InstantID Source Prepare Pipe (JPS)": _none_consumption_maker, + "InstantID Pose Prepare Pipe (JPS)": _none_consumption_maker, + "InstantID Mask Prepare Pipe (JPS)": _none_consumption_maker, + "Revision Settings Pipe (JPS)": _none_consumption_maker, + "SDXL Fundamentals MultiPipe (JPS)": _none_consumption_maker, + "Images Masks MultiPipe (JPS)": _none_consumption_maker, + "SDXL Recommended Resolution Calc (JPS)": _none_consumption_maker, + "Resolution Multiply (JPS)": _none_consumption_maker, + "Largest Int (JPS)": _none_consumption_maker, + "Multiply Int Int (JPS)": _none_consumption_maker, + "Multiply Int Float (JPS)": _none_consumption_maker, + "Multiply Float Float (JPS)": _none_consumption_maker, + "Substract Int Int (JPS)": _none_consumption_maker, + "Text Concatenate (JPS)": _none_consumption_maker, + "Get Date Time String (JPS)": _none_consumption_maker, + "Get Image Size (JPS)": _none_consumption_maker, + "Crop Image Square (JPS)": _none_consumption_maker, + "Crop Image TargetSize (JPS)": _none_consumption_maker, + "Prepare Image (JPS)": _none_consumption_maker, + "Prepare Image Plus (JPS)": _none_consumption_maker, + "Prepare Image Tiled IPA (JPS)": _none_consumption_maker, + "SDXL Prompt Styler (JPS)": _none_consumption_maker, + "SDXL Prompt Handling (JPS)": _none_consumption_maker, + "SDXL Prompt Handling Plus (JPS)": _none_consumption_maker, + "Text Prompt (JPS)": _none_consumption_maker, + "Text Prompt Combo (JPS)": _none_consumption_maker, + "Save Images Plus (JPS)": _none_consumption_maker, + "CLIPTextEncode SDXL Plus (JPS)": _none_consumption_maker, + "Time Seed (JPS)": _none_consumption_maker, + + "Sine Curve [Dream]": _none_consumption_maker, + "Linear 
Curve [Dream]": _none_consumption_maker, + "CSV Curve [Dream]": _none_consumption_maker, + "Beat Curve [Dream]": _none_consumption_maker, + "Common Frame Dimensions [Dream]": _none_consumption_maker, + "Image Motion [Dream]": _none_consumption_maker, + "Noise from Palette [Dream]": _none_consumption_maker, + "Analyze Palette [Dream]": _none_consumption_maker, + "Palette Color Shift [Dream]": _none_consumption_maker, + "File Count [Dream]": _none_consumption_maker, + "Frame Counter Offset [Dream]": _none_consumption_maker, + "Frame Counter (Directory) [Dream]": _none_consumption_maker, + "Frame Counter (Simple) [Dream]": _none_consumption_maker, + "Image Sequence Loader [Dream]": _none_consumption_maker, + "Image Sequence Saver [Dream]": _none_consumption_maker, + "CSV Generator [Dream]": _none_consumption_maker, + "Sample Image Area as Palette [Dream]": _none_consumption_maker, + "FFMPEG Video Encoder [Dream]": _none_consumption_maker, + "Image Sequence Tweening [Dream]": _none_consumption_maker, + "Image Sequence Blend [Dream]": _none_consumption_maker, + "Palette Color Align [Dream]": _none_consumption_maker, + "Sample Image as Palette [Dream]": _none_consumption_maker, + "Noise from Area Palettes [Dream]": _none_consumption_maker, + "String Input [Dream]": _none_consumption_maker, + "Float Input [Dream]": _none_consumption_maker, + "Int Input [Dream]": _none_consumption_maker, + "Text Input [Dream]": _none_consumption_maker, + "Big Latent Switch [Dream]": _none_consumption_maker, + "Frame Count Calculator [Dream]": _none_consumption_maker, + "Big Image Switch [Dream]": _none_consumption_maker, + "Big Text Switch [Dream]": _none_consumption_maker, + "Big Float Switch [Dream]": _none_consumption_maker, + "Big Int Switch [Dream]": _none_consumption_maker, + "Big Palette Switch [Dream]": _none_consumption_maker, + "Build Prompt [Dream]": _none_consumption_maker, + "Finalize Prompt [Dream]": _none_consumption_maker, + "Frame Counter Info [Dream]": _none_consumption_maker, + "Boolean To Float [Dream]": _none_consumption_maker, + "Boolean To Int [Dream]": _none_consumption_maker, + "Saw Curve [Dream]": _none_consumption_maker, + "Triangle Curve [Dream]": _none_consumption_maker, + "Triangle Event Curve [Dream]": _none_consumption_maker, + "Smooth Event Curve [Dream]": _none_consumption_maker, + "Calculation [Dream]": _none_consumption_maker, + "Image Color Shift [Dream]": _none_consumption_maker, + "Compare Palettes [Dream]": _none_consumption_maker, + "Image Contrast Adjustment [Dream]": _none_consumption_maker, + "Image Brightness Adjustment [Dream]": _none_consumption_maker, + "Laboratory [Dream]": _none_consumption_maker, + "String to Log Entry [Dream]": _none_consumption_maker, + "Int to Log Entry [Dream]": _none_consumption_maker, + "Float to Log Entry [Dream]": _none_consumption_maker, + "Log Entry Joiner [Dream]": _none_consumption_maker, + "String Tokenizer [Dream]": _none_consumption_maker, + "WAV Curve [Dream]": _none_consumption_maker, + "Frame Counter Time Offset [Dream]": _none_consumption_maker, + "Lora Loader Stack (rgthree)": _none_consumption_maker, + "PatchModelAddDownscale": _none_consumption_maker, + "Power Prompt (rgthree)": _none_consumption_maker, + "EmptyMochiLatentVideo": _none_consumption_maker, + "Seed (rgthree)": _none_consumption_maker, + "Image Comparer (rgthree)": _none_consumption_maker, + "CR Text Replace": _none_consumption_maker, + "VHS_MergeImages": _none_consumption_maker, + "Ref_Image_Preprocessing": _none_consumption_maker, + 
"ADE_AttachLoraHookToConditioning": _none_consumption_maker, + "ADE_CombineLoraHooksFour": _none_consumption_maker, + "ADE_RegisterModelAsLoraHook": _none_consumption_maker, + "LoadImageMask": _none_consumption_maker, + "JoinImageWithAlpha": _none_consumption_maker, + "SplitImageWithAlpha": _none_consumption_maker, + "DownloadAndLoadCogVideoModel": _none_consumption_maker, + "easy loraStackApply": _none_consumption_maker, + "DifferentialDiffusion": _none_consumption_maker, + "InpaintModelConditioning": _none_consumption_maker, + "ImpactGaussianBlurMask": _none_consumption_maker, + "MaskPreview+": _none_consumption_maker, + "ImageSender": _none_consumption_maker, + "ReActorFaceBoost": _none_consumption_maker, + "PrepImageForClipVision": _none_consumption_maker, + "ModelMergeSDXL": _none_consumption_maker, + "easy int": _none_consumption_maker, + "Display Int (rgthree)": _none_consumption_maker, + "Text Random Line": _none_consumption_maker, + "ImpactFloat": _none_consumption_maker, + "SeargeSDXLRefinerPromptEncoder": _none_consumption_maker, + "Text Find and Replace": _none_consumption_maker, + "Mask Invert": _none_consumption_maker, + "MaskFlip+": _none_consumption_maker, + "EmptyLatentImagePresets": _none_consumption_maker, + "ModelSamplingDiscrete": _none_consumption_maker, + "TripleCLIPLoaderGGUF": _none_consumption_maker, + "SeargePromptText": _none_consumption_maker, + "Prompts Everywhere": _none_consumption_maker, + "Anything Everywhere": _none_consumption_maker, + "IPAdapterStyleComposition": _none_consumption_maker, + "RescaleCFG": _none_consumption_maker, + "BAE-NormalMapPreprocessor": _none_consumption_maker, + "InpaintPreprocessor": _none_consumption_maker, + "ScribblePreprocessor": _none_consumption_maker, + "CLIPTextEncodeSD3": _none_consumption_maker, + "Image Dragan Photography Filter": _none_consumption_maker, + "CreateFadeMaskAdvanced": _none_consumption_maker, + "Power Lora Loader (rgthree)": _none_consumption_maker, + "CLIPLoaderGGUF": _none_consumption_maker, + "ConditioningUpscale //Inspire": _none_consumption_maker, + "Paste By Mask": _none_consumption_maker, + "Text Find and Replace Input": _none_consumption_maker, + "Text String": _none_consumption_maker, + "ReActorOptions": _none_consumption_maker, + "ModelMergeAdd": _none_consumption_maker, + "easy controlnetNames": _none_consumption_maker, + "LayerColor: Color of Shadow & Highligh": _none_consumption_maker, + "CreateHookLora": _none_consumption_maker, + "CreateHookLoraModelOnly": _none_consumption_maker, + "CreateHookModelAsLora": _none_consumption_maker, + "CreateHookModelAsLoraModelOnly": _none_consumption_maker, + "SetHookKeyframes": _none_consumption_maker, + "CreateHookKeyframe": _none_consumption_maker, + "CreateHookKeyframesInterpolated": _none_consumption_maker, + "CreateHookKeyframesFromFloats": _none_consumption_maker, + "CombineHooks2": _none_consumption_maker, + "CombineHooks4": _none_consumption_maker, + "CombineHooks8": _none_consumption_maker, + "ConditioningSetProperties": _none_consumption_maker, + "ConditioningSetPropertiesAndCombine": _none_consumption_maker, + "PairConditioningCombine": _none_consumption_maker, + "PairConditioningSetDefaultCombine": _none_consumption_maker, + "ConditioningSetDefaultCombine": _none_consumption_maker, + "SetClipHooks": _none_consumption_maker, + "ConditioningTimestepsRange": _none_consumption_maker, +} + + +def get_monitor_params(obj, obj_type, input_data_all): + func = _NODE_CONSUMPTION_MAPPING.get(obj_type, _default_consumption_maker) + consumption = 
_map_node_consumption_over_list(obj, input_data_all, func)
+    if consumption and 'opts' in consumption:
+        for opt in consumption['opts']:
+            opt["ratio"] = opt.get("ratio", 1) * 1.8
+
+    return consumption
diff --git a/diffus/image_gallery.py b/diffus/image_gallery.py
new file mode 100644
index 00000000000..d9883a8045a
--- /dev/null
+++ b/diffus/image_gallery.py
@@ -0,0 +1,174 @@
+import json
+import logging
+import os
+import uuid
+from typing import Iterable
+
+import requests
+
+import execution_context
+from diffus.service_registrar import get_service_node
+import folder_paths
+
+logger = logging.getLogger(__name__)
+
+
+def _do_post_image_to_gallery(
+        post_url,
+        task_id,
+        user_id,
+        user_hash,
+        image_type,
+        image_subfolder,
+        image_filename,
+        positive_prompt: str,
+        pnginfo,
+        model_base: str,
+        model_ids: list[int],
+):
+    if image_type != "output":
+        return
+    post_json = {
+        "task_id": task_id,
+        "path": os.path.join(
+            folder_paths.get_relative_output_directory(user_hash),
+            image_subfolder,
+            image_filename,
+        ),
+        "feature": "COMFYUI",
+        "pnginfo": json.dumps(pnginfo),
+        "created_by": user_id,
+        "base": model_base,
+        "prompt": positive_prompt,  # positive prompt
+        "model_ids": model_ids,  # checkpoint, loras
+        "is_public": False,
+    }
+    resp = requests.post(
+        url=post_url,
+        json=post_json
+    )
+
+    if 199 < resp.status_code < 300:
+        logger.debug(
+            f"posted image to gallery, {resp.status_code} {resp.text}, url={post_url}, post_json={post_json}"
+        )
+    else:
+        logger.error(
+            f"failed to post image to gallery, {resp.status_code} {resp.text}, url={post_url}, post_json={post_json}"
+        )
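For reference, a hedged sketch (not part of the patch) of the request that _do_post_image_to_gallery ends up sending; the endpoint host is resolved from the registered "gallery" service node at runtime, and every value below is a placeholder:

    import requests

    requests.post(
        url="http://10.0.1.5:9000/gallery-api/v1/images",  # assumed gallery host_url
        json={
            "task_id": "task-uuid",                  # x-task-id header, or uuid4()
            "path": "userhash/subfolder/image.png",  # relative output dir + subfolder + filename
            "feature": "COMFYUI",
            "pnginfo": "{}",                         # json.dumps of the node's extra_pnginfo
            "created_by": "12345",                   # user-id header
            "base": "SDXL",                          # base of the loaded checkpoint, if known
            "prompt": "a cat in the snow",           # accumulated positive prompt
            "model_ids": [101, 202],                 # checkpoint and lora ids
            "is_public": False,
        },
    )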
+def post_output_to_image_gallery(redis_client, node_obj, header_dict, input_data, output_data):
+    if not output_data:
+        return
+
+    user_hash = _find_user_hash_from_input_data(input_data)
+    if not user_hash:
+        return
+
+    # header keys are lower-cased by _make_headers, so a single lookup suffices
+    user_id = header_dict.get('user-id', None)
+    if not user_id:
+        return
+
+    result_data, ui_data, _ = output_data
+
+    if not isinstance(ui_data, dict):
+        return
+
+    proceeded_files = set()
+
+    task_id = header_dict.get('x-task-id', str(uuid.uuid4()))
+    gallery_service_node = get_service_node(redis_client, "gallery")
+    if not gallery_service_node:
+        logger.warning("no gallery service node is found")
+        return
+    image_server_endpoint = f"{gallery_service_node.host_url}/gallery-api/v1/images"
+
+    exec_context = _find_execution_context_from_input_data(input_data)
+    for images_key in ("images", "gifs", "video"):
+        if images_key not in ui_data:
+            continue
+
+        if not isinstance(ui_data[images_key], Iterable):
+            continue
+        for image in ui_data[images_key]:
+            if not isinstance(image, dict):
+                logger.error("image is not a dict, do nothing")
+                continue
+            image_type = image["type"]
+            image_subfolder = image["subfolder"]
+            image_filename = image["filename"]
+            pnginfo = _find_extra_pnginfo_from_input_data(exec_context, input_data=input_data)
+
+            if image_filename in proceeded_files:
+                continue
+            _do_post_image_to_gallery(
+                image_server_endpoint,
+                task_id,
+                user_id,
+                user_hash,
+                image_type,
+                image_subfolder,
+                image_filename,
+                exec_context.positive_prompt,
+                pnginfo,
+                exec_context.checkpoints_model_base,
+                exec_context.loaded_model_ids
+            )
+            proceeded_files.add(image_filename)
+
+    if hasattr(node_obj, "RETURN_TYPES") and "VHS_FILENAMES" in node_obj.RETURN_TYPES:
+        for node_result in result_data[node_obj.RETURN_TYPES.index("VHS_FILENAMES")]:
+            if node_result[0]:
+                for filepath in node_result[1]:
+                    output_directory_len = len(folder_paths.get_output_directory(user_hash))
+                    filename = filepath[output_directory_len + 1:]
+                    if filename in proceeded_files:
+                        continue
+                    _do_post_image_to_gallery(
+                        image_server_endpoint,
+                        task_id,
+                        user_id,
+                        user_hash,
+                        "output",
+                        "",
+                        filename,
+                        exec_context.positive_prompt,
+                        {},
+                        exec_context.checkpoints_model_base,
+                        exec_context.loaded_model_ids
+                    )
+                    proceeded_files.add(filename)
+
+
+def _find_user_hash_from_input_data(input_data):
+    if not isinstance(input_data, dict):
+        return ""
+    for key, value in input_data.items():
+        if key == "user_hash":
+            return value[0]
+        elif key == "context":
+            return value[0].user_hash
+    return ""
+
+
+def _find_execution_context_from_input_data(input_data):
+    if not isinstance(input_data, dict):
+        return None
+    for key, value in input_data.items():
+        if key == "context":
+            return value[0]
+    return execution_context.ExecutionContext({})
+
+
+def _find_extra_pnginfo_from_input_data(context: execution_context.ExecutionContext, input_data):
+    if not isinstance(input_data, dict):
+        return {}
+    for key, value in input_data.items():
+        if key == "extra_pnginfo" and value:
+            pnginfo = value[0]
+            pnginfo["parameters"] = context.geninfo if context else {}
+            return pnginfo
+    return {}
+
diff --git a/diffus/message.py b/diffus/message.py
new file mode 100644
index 00000000000..85b406eac8b
--- /dev/null
+++ b/diffus/message.py
@@ -0,0 +1,23 @@
+import logging
+import diffus.redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class MessageQueue:
+    def __init__(self):
+        self._redis_client = diffus.redis_client.get_redis_client()
+
+    def send_message(self, sid: str, message: bytes | str, retry=1):
+        if not self._redis_client:
+            self._redis_client = diffus.redis_client.get_redis_client()
+        try:
+            if sid:
+                self._redis_client.rpush(f'COMFYUI_MESSAGE_{sid}', message)
+            else:
+                self._redis_client.publish('COMFYUI_MESSAGE_anonymous', message)
+        except Exception as e:
+            logger.exception(e)
+            if retry > 0:
+                self._redis_client = None
+                self.send_message(sid, message, retry=retry - 1)
diff --git a/diffus/models.py b/diffus/models.py
new file mode 100644
index 00000000000..5203e444c99
--- /dev/null
+++ b/diffus/models.py
@@ -0,0 +1,31 @@
+from sqlalchemy import Column, Integer, String
+
+from diffus import database
+
+FAVORITE_MODEL_TYPES = {
+    'checkpoints': 'CHECKPOINT',
+    'loras': 'LORA',
+    'lycoris': 'LYCORIS',
+}
+
+
+class Model(database.Base):
+    __tablename__ = "models"
+
+    id = Column(Integer, primary_key=True, index=True)
+    model_type = Column(String)
+    base = Column(String)
+
+    stem = Column(String)
+    extension = Column(String, index=True)
+
+    sha256 = Column(String, index=True)
+    config_sha256 = Column(String)
+
+
+class FavoriteModel(database.Base):
+    __tablename__ = "favorite_models"
+
+    id = Column(Integer, primary_key=True)
+    favorited_by = Column(String, index=True)
+    model_id = Column(Integer, index=True)
diff --git a/diffus/redis_client.py b/diffus/redis_client.py
new file mode 100644
index 00000000000..7ca4659c955
--- /dev/null
+++ b/diffus/redis_client.py
@@ -0,0 +1,21 @@
+import os
+import time
+import redis
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_redis_client() -> redis.Redis | None:
+    redis_address = os.getenv('REDIS_ADDRESS')
+    while True:
+        try:
+            if redis_address:
+                redis_client = redis.Redis.from_url(url=redis_address)
+                redis_client.ping()
+            else:
+                redis_client = None
+            return redis_client
+        except Exception as e:
+            logger.exception(f'failed to create redis client from 
{redis_address}: {e.__str__()}') + time.sleep(3) diff --git a/diffus/repository.py b/diffus/repository.py new file mode 100644 index 00000000000..eaede737b0c --- /dev/null +++ b/diffus/repository.py @@ -0,0 +1,119 @@ +import os +import pathlib +from typing import Literal + +from pydantic import BaseModel +from sqlalchemy.orm import Session, Query + +from diffus import models, database + +MODEL_BINARY_CONTAINER = os.getenv('MODEL_BINARY_CONTAINER') +MODEL_CONFIG_CONTAINER = os.getenv('MODEL_CONFIG_CONTAINER') + + +def get_binary_path(sha256: str): + sha256 = sha256.lower() + return pathlib.Path(MODEL_BINARY_CONTAINER, sha256[0:2], sha256[2:4], sha256[4:6], sha256) + + +def get_config_path(sha256: str) -> pathlib.Path: + sha256 = sha256.lower() + return pathlib.Path(MODEL_CONFIG_CONTAINER, sha256) + + +class ModelInfo(BaseModel): + id: int + model_type: Literal["checkpoint", "embedding", "hypernetwork", "lora", "lycoris"] + base: str | None + stem: str + extension: str + sha256: str + config_sha256: str | None + + @property + def name(self): + return f'{self.stem}.{self.extension}' + + @property + def is_safetensors(self) -> bool: + return self.extension in ("safetensors", "sft") + + @property + def filename(self) -> str: + return str(get_binary_path(self.sha256)) + + def __str__(self): + return self.filename + + +def create_model_info(record: models.Model) -> ModelInfo: + return ModelInfo( + id=record.id, + model_type=str(record.model_type).lower(), + base=record.base, + stem=record.stem, + extension=record.extension, + sha256=record.sha256, + config_sha256=record.config_sha256, + ) + + +def list_favorite_model_by_model_type(user_id: str, folder_name: str): + if folder_name not in models.FAVORITE_MODEL_TYPES: + return [] + model_type = models.FAVORITE_MODEL_TYPES[folder_name] + with database.Database() as session: + query = _make_favorite_model_query(session) + query = _filter_favorite_model_by_model_type(query, user_id, model_type) + return [create_model_info(ckpt).name for ckpt in session.scalars(query)] + + +def get_favorite_model_full_path(user_id: str, folder_name: str, filename: str) -> ModelInfo | None: + if folder_name not in models.FAVORITE_MODEL_TYPES: + return None + model_type = models.FAVORITE_MODEL_TYPES[folder_name] + with database.Database() as session: + query = _make_favorite_model_query(session) + query = _filter_favorite_model_by_model_type(query, user_id, model_type) + query = _filter_model_by_name(query, filename) + record = session.scalar(query) + if not record: + raise Exception(f"model is not found, [{folder_name}]/{filename}") + return create_model_info(record) + + +def _make_favorite_model_query(session: Session) -> Query: + return session.query( + models.Model + ).join( + models.FavoriteModel, + models.Model.id == models.FavoriteModel.model_id, + isouter=True + ) + + +def _filter_model_by_sha256(query: Query, sha256: str) -> Query: + return query.filter(models.Model.sha256 == sha256) + + +def _filter_model_by_name(query: Query, name: str) -> Query: + filename = pathlib.Path(name) + suffix = filename.suffix + if suffix: + suffix = suffix[1:] + return query.filter( + models.Model.stem == filename.stem, + models.Model.extension == suffix, + ) + + +def _filter_favorite_model_by_model_type(query: Query, user_id: str, model_type: str) -> Query: + query = query.filter( + models.FavoriteModel.favorited_by == user_id + ) + if model_type in {"LORA", "LYCORIS"}: + query = query.filter(models.Model.model_type.in_(["LORA", "LYCORIS"])) + else: + query = 
query.filter(models.Model.model_type == model_type) + + return query diff --git a/diffus/service_registrar.py b/diffus/service_registrar.py new file mode 100644 index 00000000000..aaed9c4d0e8 --- /dev/null +++ b/diffus/service_registrar.py @@ -0,0 +1,63 @@ +import json +import random + +from pydantic import BaseModel + +import logging + +_logger = logging.getLogger(__name__) + + +class ServiceNode(BaseModel): + service: str + status: str = "UP" + ip: str + port: str + schema: str = "http" + health_check: str = "/health" + + def __key(self): + return self.service, self.ip, self.port, self.schema + + def __hash__(self): + return hash(self.__key()) + + def __eq__(self, other): + if isinstance(other, ServiceNode): + return self.__key() == other.__key() + return NotImplemented + + @property + def alive(self): + return self.status.lower() == "up" + + @property + def host_url(self): + return f"{self.schema}://{self.ip}:{self.port}" + + @staticmethod + def from_json_str(json_str: str, service_name: str = ''): + node_dict = json.loads(json_str) + if service_name: + node_dict["service"] = service_name + node_dict["port"] = str(node_dict["port"]) + return ServiceNode(**node_dict) + + +def get_service_node(redis_client, service_name: str) -> ServiceNode | None: + service_pattern = f"service:{service_name}_*" + alive_nodes = [] + for instance_id in redis_client.scan_iter(service_pattern): + if not instance_id: + continue + status_str = redis_client.get(instance_id) + try: + node = ServiceNode.from_json_str(status_str, service_name) + if node.alive: + alive_nodes.append(node) + except Exception as e: + _logger.warning(f"make node status for '{instance_id}' from {status_str} failed: {e.__str__()}") + if alive_nodes: + return random.choice(alive_nodes) + else: + return None diff --git a/diffus/system_monitor.py b/diffus/system_monitor.py new file mode 100644 index 00000000000..64558b79e01 --- /dev/null +++ b/diffus/system_monitor.py @@ -0,0 +1,324 @@ +import json +import logging +import os +import time +import uuid +from contextlib import contextmanager +from typing import Optional, Tuple + +import requests + +import diffus.decoded_params +import diffus.task_queue +from diffus.image_gallery import post_output_to_image_gallery +from diffus.redis_client import get_redis_client + +logger = logging.getLogger(__name__) + + +class MonitorException(Exception): + def __init__(self, status_code: int, code: str, message: str): + self.status_code = status_code + self.code = code + self.message = message + + def __repr__(self) -> str: + return f"{self.status_code} {self.code} {self.message}" + + +class MonitorTierMismatchedException(Exception): + def __init__(self, msg, current_tier, allowed_tiers): + self._msg = msg + self.current_tier = current_tier + self.allowed_tiers = allowed_tiers + + def __repr__(self) -> str: + return self._msg + + +def _get_system_monitor_config(headers: dict) -> Tuple[str, str]: + # take per-task config as priority instead of global config + monitor_addr = headers.get( + 'x-diffus-system-monitor-url', "" + ) or headers.get( + 'X-Diffus-System-Monitor-Url', "" + ) + system_monitor_api_secret = headers.get( + 'x-diffus-system-monitor-api-secret', "" + ) or headers.get( + 'X-Diffus-System-Monitor-Api-Secret', "" + ) + return monitor_addr, system_monitor_api_secret + + +def _make_headers(extra_data: dict): + headers = extra_data.get('diffus-request-headers', {}) + result = {} + for key, value in headers.items(): + key = key.lower() + if isinstance(value, list): + if len(value) > 0: + 
result[key] = value[0]
+            else:
+                result[key] = ''
+        else:
+            result[key] = value
+    return result
+
+
+def _before_task_started(
+        header_dict: dict,
+        api_name: str,
+        function_name: str,
+        job_id: Optional[str] = None,
+        decoded_params: Optional[dict] = None,
+        is_intermediate: bool = False,
+        refund_if_task_failed: bool = True,
+        only_available_for: Optional[list[str]] = None) -> Optional[str]:
+    if decoded_params is None and is_intermediate:
+        return ''
+
+    if job_id is None:
+        job_id = str(uuid.uuid4())
+    monitor_addr, system_monitor_api_secret = _get_system_monitor_config(header_dict)
+    if not monitor_addr or not system_monitor_api_secret:
+        logger.error(f'{job_id}: system_monitor_addr or system_monitor_api_secret is not present')
+        return None
+
+    session_hash = header_dict.get('x-session-hash', "")
+    if not session_hash:
+        logger.error(f'{job_id}: x-session-hash is not present in headers')
+        return None
+    task_id = header_dict.get('x-task-id', "")
+    if not task_id:
+        logger.error(f'{job_id}: x-task-id is not present in headers')
+        return None
+    if not is_intermediate and task_id != job_id:
+        logger.error(f'x-task-id ({task_id}) and job_id ({job_id}) are not equal')
+    deduct_flag = header_dict.get('x-deduct-credits', "")
+    deduct_flag = deduct_flag != 'false'
+    if only_available_for:
+        # 'user-tire' looks like a legacy misspelling of the header name, presumably kept for compatibility
+        user_tier = header_dict.get('user-tire', '') or header_dict.get('user-tier', '')
+        if not user_tier or user_tier.lower() not in [item.lower() for item in only_available_for]:
+            raise MonitorTierMismatchedException(
+                f'This feature is available for {only_available_for} only. The current user tier is {user_tier}.',
+                user_tier,
+                only_available_for)
+
+    # header keys are lower-cased by _make_headers, so a single lookup suffices
+    user_id = header_dict.get('user-id', None)
+    request_data = {
+        'api': api_name,
+        'initiator': function_name,
+        'user': user_id,
+        'started_at': time.time(),
+        'session_hash': session_hash,
+        'skip_charge': not deduct_flag,
+        'refund_if_task_failed': refund_if_task_failed,
+        'node': os.getenv('HOST_IP', default=''),
+    }
+    if is_intermediate:
+        request_data['step_id'] = job_id
+        request_data['task_id'] = task_id
+    else:
+        request_data['task_id'] = job_id
+    request_data['decoded_params'] = decoded_params if decoded_params is not None else {}
+    resp = requests.post(
+        monitor_addr,
+        headers={
+            'Api-Secret': system_monitor_api_secret,
+        },
+        json=request_data
+    )
+    logger.info(json.dumps(request_data, ensure_ascii=False, sort_keys=True))
+
+    # check response, raise exception if status code is not 2xx
+    if 199 < resp.status_code < 300:
+        return job_id
+
+    content = resp.json()
+    # log the response if request failed
+    logger.error(f'create monitor log failed, status: {resp.status_code}, content: {content}')
+    raise MonitorException(resp.status_code, content["code"], content["message"])
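For illustration (not part of the patch), the start-of-task record an intermediate node call posts to the monitor URL carried by the x-diffus-system-monitor-url header; all values are placeholders:

    request_data = {
        'api': 'comfy.KSampler',          # assumed node class name
        'initiator': 'comfyui',
        'user': '12345',
        'started_at': 1733721600.0,
        'session_hash': 'abc123',
        'skip_charge': False,
        'refund_if_task_failed': True,
        'node': '10.0.0.7',               # HOST_IP of this worker
        'step_id': 'step-uuid',           # per-node job id (intermediate calls only)
        'task_id': 'task-uuid',           # x-task-id of the enclosing prompt
        'decoded_params': {},
    }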
+def _after_task_finished(
+        header_dict: dict,
+        job_id: Optional[str],
+        status: str,
+        message: Optional[str] = None,
+        is_intermediate: bool = False,
+        refund_if_failed: bool = False,
+        decoded_params=None,
+):
+    if decoded_params is None and is_intermediate:
+        return {}
+    if job_id is None:
+        logger.error(
+            'task_id is not present in after_task_finished, an error may have occurred in before_task_started.')
+        return {}
+    monitor_addr, system_monitor_api_secret = _get_system_monitor_config(header_dict)
+    if not monitor_addr or not system_monitor_api_secret:
+        logger.error(f'{job_id}: system_monitor_addr or system_monitor_api_secret is not present')
+        return {}
+
+    session_hash = header_dict.get('x-session-hash', "")
+    if not session_hash:
+        logger.error(f'{job_id}: x-session-hash is not present in headers')
+        return {}
+    task_id = header_dict.get('x-task-id', "")
+    if not task_id:
+        logger.error(f'{job_id}: x-task-id is not present in headers')
+        return {}
+
+    request_url = f'{monitor_addr}/{job_id}'
+    request_body = {
+        'status': status,
+        'result': message if message else "{}",
+        'finished_at': time.time(),
+        'session_hash': session_hash,
+        'refund_if_failed': refund_if_failed,
+    }
+    if is_intermediate:
+        request_body['step_id'] = job_id
+        request_body['task_id'] = task_id
+    else:
+        request_body['task_id'] = job_id
+    resp = requests.post(
+        request_url,
+        headers={
+            'Api-Secret': system_monitor_api_secret,
+        },
+        json=request_body
+    )
+
+    # log the response if request failed
+    if resp.status_code < 200 or resp.status_code > 299:
+        logger.error(f'update monitor log failed, monitor_log_id: {job_id}, status: {resp.status_code}, '
+                     f'message: {resp.text[:1000]}')
+    return resp.json()
+
+
+@contextmanager
+def monitor_call_context(
+        queue_dispatcher: diffus.task_queue.TaskDispatcher | None,
+        extra_data: dict,
+        api_name: str,
+        function_name: str,
+        task_id: Optional[str] = None,
+        decoded_params: Optional[dict] = None,
+        is_intermediate: bool = True,
+        refund_if_task_failed: bool = True,
+        refund_if_failed: bool = False,
+        only_available_for: Optional[list[str]] = None):
+    status = 'unknown'
+    message = ''
+    task_is_failed = False
+    header_dict = _make_headers(extra_data)
+
+    def result_encoder(success, result):
+        try:
+            nonlocal message
+            nonlocal task_is_failed
+            message = json.dumps(result, ensure_ascii=False, sort_keys=True)
+            task_is_failed = not success
+        except Exception as ex:
+            logger.error(f'{task_id}: Json encode result failed {ex}.')
+
+    try:
+        if not is_intermediate and queue_dispatcher:
+            queue_dispatcher.on_task_started(task_id)
+        task_id = _before_task_started(
+            header_dict,
+            api_name,
+            function_name,
+            task_id,
+            decoded_params,
+            is_intermediate,
+            refund_if_task_failed,
+            only_available_for)
+        yield result_encoder
+        if task_is_failed:
+            status = 'failed'
+        else:
+            status = 'finished'
+    except Exception as e:
+        status = 'failed'
+        message = f'{type(e).__name__}: {str(e)}'
+        raise
+    finally:
+        monitor_result = _after_task_finished(
+            header_dict,
+            task_id,
+            status,
+            message,
+            is_intermediate,
+            refund_if_failed,
+            decoded_params,
+        )
+        if not is_intermediate:
+            logger.info(f'monitor_result: {monitor_result}')
+            extra_data['subscription_consumption'] = monitor_result.get('consumptions', {})
+        if queue_dispatcher:
+            queue_dispatcher.on_task_finished(task_id, not task_is_failed, message)
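A hedged usage sketch of monitor_call_context; the workload function and parameter values are assumptions, and node_execution_monitor below wires it up the same way for real node executions:

    with monitor_call_context(
            None,                       # no queue dispatcher for intermediate steps
            extra_data,                 # must carry 'diffus-request-headers'
            'comfy.KSampler',           # assumed api name
            'comfyui',
            decoded_params={},          # placeholder consumption hints
            is_intermediate=True,
    ) as result_encoder:
        output = run_node()             # assumed workload
        result_encoder(True, None)      # record success; raising instead marks 'failed'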
+def node_execution_monitor(get_output_data):
+    import nodes
+
+    redis_client = get_redis_client()
+
+    def wrapper(obj, input_data_all, extra_data, execution_block_cb=None, pre_execute_cb=None):
+        node_class_name = type(obj).__name__
+        for k, v in nodes.NODE_CLASS_MAPPINGS.items():
+            if type(obj) is v:
+                node_class_name = k
+                break
+
+        with monitor_call_context(
+                None,
+                extra_data,
+                f'comfy.{node_class_name}',
+                'comfyui',
+                decoded_params=diffus.decoded_params.get_monitor_params(obj, node_class_name, input_data_all),
+                is_intermediate=True,
+        ) as result_encoder:
+            try:
+                output_data = get_output_data(obj, input_data_all, extra_data, execution_block_cb, pre_execute_cb)
+                result_encoder(True, None)
+                post_output_to_image_gallery(redis_client, obj, _make_headers(extra_data), input_data_all, output_data)
+                return output_data
+            except Exception as ex:
+                result_encoder(False, ex)
+                raise
+
+    return wrapper
+
+
+def make_monitor_error_message(ex):
+    if isinstance(ex, MonitorException):
+        match (ex.status_code, ex.code):
+            case (402, "WEBUIFE-01010001"):
+                upgrade_info = {
+                    "need_upgrade": True,
+                    "reason": "INSUFFICIENT_CREDITS",
+                }
+            case (402, "WEBUIFE-01010003"):
+                upgrade_info = {
+                    "need_upgrade": True,
+                    "reason": "INSUFFICIENT_DAILY_CREDITS",
+                }
+            case (429, "WEBUIFE-01010004"):
+                upgrade_info = {
+                    "need_upgrade": True,
+                    "reason": "REACH_CONCURRENCY_LIMIT",
+                }
+            case _:
+                logger.error(f"mismatched status_code({ex.status_code}) and code({ex.code}) in 'MonitorException'")
+                upgrade_info = {"need_upgrade": False}
+    elif isinstance(ex, MonitorTierMismatchedException):
+        upgrade_info = {
+            "need_upgrade": True,
+            "reason": "TIER_MISSMATCH",
+        }
+    else:
+        upgrade_info = {"need_upgrade": False}
+    return upgrade_info
diff --git a/diffus/task_queue.py b/diffus/task_queue.py
new file mode 100644
index 00000000000..e260f73f351
--- /dev/null
+++ b/diffus/task_queue.py
@@ -0,0 +1,244 @@
+import json
+import logging
+import os
+import threading
+import time
+import uuid
+
+import aiohttp.web_routedef
+from aiohttp import web
+import requests
+from pydantic import BaseModel, Field
+import version
+import diffus.redis_client
+
+_logger = logging.getLogger(__name__)
+
+
+class ServiceStatusResponse(BaseModel):
+    status: str = Field(title="Status", default='startup')
+    release_version: str = Field(default='')
+
+    node_type: str = Field(title="NodeType", default='FOCUS')
+    instance_name: str = Field(default='')
+
+    accepted_tiers: list[str] = Field(default=[])
+    accepted_type_types: list[str] = Field(default=[])
+
+    current_task: str = Field(title="CurrentTask", default='')
+    queued_tasks: list = Field(title='QueuedTasks', default=[])
+    finished_task_count: int = Field(title='FinishedTaskCount', default=0)
+    failed_task_count: int = Field(title='FailedTaskCount', default=0)
+    consecutive_failed_task_count: int = Field(title='ConsecutiveFailedTaskCount', default=0)
+    # declared so that get_status() can actually report it; without this field
+    # pydantic silently drops the pending_task_count keyword argument
+    pending_task_count: int = Field(title='PendingTaskCount', default=0)
+
+
+class ServiceStatusRequest(BaseModel):
+    status: str = Field(title="Status", default='startup')
+    node_type: str = Field(title="NodeType", default='FOCUS')
+    accepted_tiers: list[str] = Field(default=[])
+    accepted_type_types: list[str] = Field(default=[])
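For reference, an illustrative GET /daemon/v1/status response mirroring ServiceStatusResponse; every value is a placeholder:

    {
        "status": "up",
        "release_version": "v1.0.0",
        "node_type": "FOCUS",
        "instance_name": "comfy-node-1",
        "accepted_tiers": ["basic", "plus"],
        "accepted_type_types": ["COMFYUI"],
        "current_task": "",
        "queued_tasks": [],
        "finished_task_count": 42,
        "failed_task_count": 1,
        "consecutive_failed_task_count": 0,
        "pending_task_count": 0
    }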
+class _State:
+    def __init__(self) -> None:
+        self.host_ip = os.getenv('HOST_IP', '127.0.0.1')
+        self.service_port = os.getenv('SERVER_PORT')
+
+        self.node_name = os.getenv('NODE_NAME')
+        self.node_type = os.getenv("NODE_TYPE")
+        self.accepted_tiers: list[str] = os.getenv('ACCEPTED_TIERS').split(',')
+        self.accepted_type_types: list[str] = os.getenv('ACCEPTED_TASK_TYPES').split(',')
+
+        self.service_ready = False
+        self._redis_client = None
+        self.starting_flag = True
+        self.finished_task_count = 0
+        self.failed_task_count = 0
+        self.consecutive_failed_task_count = 0
+        self.last_error_message = ''
+        self.service_interrupted = False
+
+        self.task_thread = None
+
+        self.service_status = 'startup'
+
+        self.current_task = ''
+        self.remaining_tasks = 0
+
+    @property
+    def redis_client(self):
+        if self._redis_client is None:
+            self._redis_client = diffus.redis_client.get_redis_client()
+        return self._redis_client
+
+    def reset_redis_client(self):
+        self._redis_client = None
+
+
+def _setup_daemon_api(_task_state: _State, routes: aiohttp.web_routedef.RouteTableDef):
+    @routes.get("/daemon/v1/status")
+    async def get_status(request):
+        resp = ServiceStatusResponse(
+            release_version=version.version,
+            node_type=_task_state.node_type,
+            instance_name=_task_state.node_name,
+            accepted_tiers=_task_state.accepted_tiers,
+            accepted_type_types=_task_state.accepted_type_types,
+
+            status=_task_state.service_status,
+            current_task=_task_state.current_task or '',
+            finished_task_count=_task_state.finished_task_count,
+            failed_task_count=_task_state.failed_task_count,
+            pending_task_count=_task_state.remaining_tasks,
+            consecutive_failed_task_count=_task_state.consecutive_failed_task_count,
+        ).model_dump()
+
+        return web.json_response(resp)
+
+    @routes.put("/daemon/v1/status")
+    async def update_status(request):
+        request_data = await request.json()
+        req = ServiceStatusRequest(**request_data)
+        if req.status:
+            _task_state.service_status = req.status
+        if req.accepted_tiers:
+            _task_state.accepted_tiers = req.accepted_tiers
+        _logger.info(f'update_status: service status was set to {_task_state.service_status}')
+        resp = await get_status(request)
+        return resp
+
+
+def _service_is_alive(_task_state: _State):
+    try:
+        request_url = f'http://localhost:{_task_state.service_port}/daemon/v1/status'
+        resp = requests.get(request_url, json={})
+        if resp.status_code == 200:
+            return True
+        else:
+            _logger.warning(f'failed to check service status on {request_url}: {resp.status_code} {resp.text}')
+            return False
+    except Exception as e:
+        _logger.warning(f'failed to check service status: {e}')
+        return False
+
+
+def _post_task(_task_state: _State, request_obj, retry=1):
+    task_id = request_obj['task_id']
+    if not task_id:
+        task_id = str(uuid.uuid4())
+    timeout = request_obj.get('timeout', 15 * 60)
+    headers = request_obj['headers']
+    path = request_obj['path']
+
+    try:
+        encoded_headers = {}
+        for k, v in headers.items():
+            encoded_headers[k] = ';'.join(v)
+        encoded_headers['X-Predict-Timeout'] = f'{timeout}'
+        encoded_headers['X-Task-Timeout'] = f'{timeout}'
+        encoded_headers['X-Task-Id'] = task_id
+        # pack request headers to extra_data
+        request_data = json.loads(request_obj['body'])
+        extra_data = request_data.get('extra_data', {})
+        extra_data['diffus-request-headers'] = encoded_headers
+        request_data['extra_data'] = extra_data
+
+        request_url = f'http://localhost:{_task_state.service_port}{path}'
+        request_timeout = timeout + 3
+
+        resp = requests.post(request_url,
+                             headers=encoded_headers,
+                             json=request_data,
+                             timeout=request_timeout)
+        # note: the success and failure branches were swapped here; a 2xx
+        # response must report success, anything else is an error
+        if 199 < resp.status_code < 300:
+            return True
+        else:
+            _logger.error(f"failed to post task to server: {resp.status_code} {resp.text}")
+            return False
+    except Exception as e:
+        _logger.exception(f'failed to post task to service: {e}')
+        _task_state.reset_redis_client()
+        if retry > 0:
+            return _post_task(_task_state, request_obj, retry=retry - 1)
+        else:
+            raise
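A sketch of the task object that _fetch_task pops from Redis and _post_task replays against the local server; the shape is inferred from _post_task and the values are placeholders:

    task = {
        'task_id': 'task-uuid',                      # becomes the X-Task-Id header
        'timeout': 900,                              # seconds, defaults to 15 * 60
        'path': '/prompt',                           # local endpoint to replay against
        'headers': {                                 # each value is a list of strings
            'x-session-hash': ['abc123'],
            'user-id': ['12345'],
        },
        'body': '{"prompt": {}, "extra_data": {}}',  # original request body as JSON text
    }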
+def _fetch_task(_task_state: _State, fetch_task_timeout=5):
+    if not _task_state.current_task and _task_state.remaining_tasks == 0:
+        queue_name_list = []
+        for task_type in _task_state.accepted_type_types:
+            queue_name_list += [f"SD-{task_type}-TASKS-{tier}" for tier in _task_state.accepted_tiers]
+        _logger.debug(
+            f"begin to fetch pending requests from {queue_name_list}, current task: '{_task_state.current_task}'")
+        queued_task = _task_state.redis_client.blpop(queue_name_list, timeout=fetch_task_timeout)
+        if not queued_task:
+            _logger.debug(f'no pending request fetched in {fetch_task_timeout} seconds from {queue_name_list}')
+            # no task fetched; check service status, then fetch again
+            return
+
+        queue_name, packed_request = queued_task[0], queued_task[1]
+        if not packed_request or not queue_name:
+            # no task fetched; check service status, then fetch again
+            return
+
+        _logger.info(f'popped a task from {queue_name}: {packed_request}')
+        return json.loads(packed_request.decode('utf-8'))
+    else:
+        time.sleep(fetch_task_timeout)
+
+
+class TaskDispatcher:
+    def __init__(self, prompt_queue, routes: aiohttp.web_routedef.RouteTableDef):
+        self._task_state = _State()
+        self._prompt_queue = prompt_queue
+        self._t = threading.Thread(target=self._task_loop, name='comfy-task-dispatcher-thread')
+        _setup_daemon_api(self._task_state, routes)
+
+    def start(self):
+        self._task_state.service_status = 'up'
+        self._task_state.service_ready = True
+        self._t.start()
+
+    def stop(self):
+        self._task_state.service_interrupted = True
+        self._t.join()
+
+    def _task_loop(self):
+        while not self._task_state.service_interrupted \
+                and not _service_is_alive(self._task_state):
+            time.sleep(1)
+
+        while not self._task_state.service_interrupted \
+                and self._task_state.service_ready \
+                and self._task_state.service_status == 'up':
+            try:
+                # 1. update current state
+                self._task_state.remaining_tasks = self._prompt_queue.get_tasks_remaining()
+                # 2. fetch a task from remote queue
+                task = _fetch_task(self._task_state, fetch_task_timeout=5)
+                # 3. post task to service
+                if task:
+                    _post_task(self._task_state, task)
+            except Exception as e:
+                _logger.exception(f'failed to fetch task from queue: {e}')
+                time.sleep(3)
+        _logger.info(f'daemon_main: service is {self._task_state.service_status}, exiting task loop now')
+
+    def on_task_started(self, task_id):
+        _logger.info(f'on_task_started, {task_id}')
+        self._task_state.current_task = task_id
+
+    def on_task_finished(self, task_id, success, messages):
+        _logger.info(f'on_task_finished, {task_id} {success} {messages}')
+        if self._task_state.current_task != task_id:
+            _logger.warning(
+                f'on_task_finished, task_state.current_task({self._task_state.current_task}) and task_id({task_id}) are mismatched')
+        self._task_state.current_task = ''
+        if success:
+            self._task_state.finished_task_count += 1
+            self._task_state.consecutive_failed_task_count = 0
+        else:
+            self._task_state.failed_task_count += 1
+            self._task_state.consecutive_failed_task_count += 1
+            self._task_state.last_error_message = messages
diff --git a/execution.py b/execution.py
index 929ef85fac4..0734330b080 100644
--- a/execution.py
+++ b/execution.py
@@ -10,6 +10,10 @@
 from typing import List, Literal, NamedTuple, Optional
 import torch
+
+import diffus.system_monitor
+import execution_context
+import node_helpers
 import nodes
 import comfy.model_management
@@ -33,7 +37,7 @@ def __init__(self, dynprompt, outputs_cache):
         self.outputs_cache = outputs_cache
         self.is_changed = {}
-    def get(self, node_id):
+    def get(self, context: execution_context.ExecutionContext, node_id):
         if node_id in self.is_changed:
             return self.is_changed[node_id]
@@ -49,7 +53,7 @@ def get(self, node_id):
             return self.is_changed[node_id]
         # Intentionally do not use cached outputs here. 
We only want constants in IS_CHANGED - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _ = get_input_data(context, node["inputs"], class_def, node_id, None) try: is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] @@ -88,13 +92,13 @@ def recursive_debug_dump(self): } return result -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): - valid_inputs = class_def.INPUT_TYPES() +def get_input_data(context: execution_context.ExecutionContext, inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): + valid_inputs = node_helpers.get_node_input_types(context, class_def) input_data_all = {} missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x) + input_type, input_category, input_info = get_input_info(context, class_def, x) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -127,6 +131,10 @@ def mark_missing(): input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "USER_HASH": + input_data_all[x] = [context.user_hash] + if h[x] == "EXECUTION_CONTEXT": + input_data_all[x] = [context] return input_data_all, missing_keys map_node_over_list = None #Don't hook this please @@ -191,8 +199,9 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): - +@diffus.system_monitor.node_execution_monitor +def get_output_data(obj, input_data_all, extra_data, execution_block_cb=None, pre_execute_cb=None): + results = [] uis = [] subgraph_results = [] @@ -242,7 +251,7 @@ def format_value(x): else: return str(x) -def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): +def execute(server, context: execution_context.ExecutionContext, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -280,7 +289,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: - input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys = get_input_data(context, inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -298,7 +307,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp )] if len(required_inputs) > 0: for i in required_inputs: - execution_list.make_input_strong_link(unique_id, i) + execution_list.make_input_strong_link(context, unique_id, i) return (ExecutionResult.PENDING, None, None) def execution_block_cb(block): @@ -321,7 +330,7 @@ def execution_block_cb(block): return block def pre_execute_cb(call_index): GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, 
execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, extra_data, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: caches.ui.set(unique_id, { "meta": { @@ -364,14 +373,16 @@ def pre_execute_cb(call_index): cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: - cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + cache.ensure_subcache_for(context, unique_id, new_node_ids).clean_unused() for node_id in new_output_ids: - execution_list.add_node(node_id) + execution_list.add_node(context, node_id) for link in new_output_links: - execution_list.add_strong_link(link[0], link[1], unique_id) + execution_list.add_strong_link(context, link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) caches.outputs.set(unique_id, output_data) + except diffus.system_monitor.MonitorTierMismatchedException: + raise except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -381,7 +392,10 @@ def pre_execute_cb(call_index): } return (ExecutionResult.FAILURE, error_details, iex) - except Exception as ex: + except (diffus.system_monitor.MonitorException, Exception) as ex: + upgrade_info = diffus.system_monitor.make_monitor_error_message(ex) + if upgrade_info['need_upgrade']: + raise typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) input_data_formatted = {} @@ -414,12 +428,14 @@ class PromptExecutor: def __init__(self, server, lru_size=None): self.lru_size = lru_size self.server = server + self.history_result = {} self.reset() def reset(self): self.caches = CacheSet(self.lru_size) self.status_messages = [] self.success = True + self.history_result = {} def add_message(self, event, data: dict, broadcast: bool): data = { @@ -458,7 +474,7 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e } self.add_message("execution_error", mes, broadcast=False) - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + def execute(self, context: execution_context.ExecutionContext, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -467,13 +483,14 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): self.server.client_id = None self.status_messages = [] + self.history_result = {} self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) for cache in self.caches.all: - cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.set_prompt(context, dynamic_prompt, prompt.keys(), is_changed_cache) cache.clean_unused() cached_nodes = [] @@ -490,7 +507,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): - execution_list.add_node(node_id) + execution_list.add_node(context, node_id) while not execution_list.is_empty(): node_id, error, ex = execution_list.stage_node_execution() @@ -498,7 +515,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): self.handle_execution_error(prompt_id, 
dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + result, error, ex = execute(self.server, context, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -528,7 +545,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): comfy.model_management.unload_all_models() -def validate_inputs(prompt, item, validated): +def validate_inputs(context: execution_context.ExecutionContext, prompt, item, validated): unique_id = item if unique_id in validated: return validated[unique_id] @@ -537,7 +554,7 @@ def validate_inputs(prompt, item, validated): class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() + class_inputs = node_helpers.get_node_input_types(context, obj_class) valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] @@ -552,7 +569,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x) + type_input, input_category, extra_info = get_input_info(context, obj_class, x) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -605,7 +622,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = validate_inputs(context, prompt, o_id, validated) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -713,7 +730,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _ = get_input_data(inputs, obj_class, unique_id) + input_data_all, _ = get_input_data(context, inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -755,7 +772,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' 
+ klass.__qualname__
-def validate_prompt(prompt):
+def validate_prompt(context: execution_context.ExecutionContext, prompt):
     outputs = set()
     for x in prompt:
         if 'class_type' not in prompt[x]:
@@ -798,7 +815,7 @@ def validate_prompt(prompt):
         valid = False
         reasons = []
         try:
-            m = validate_inputs(prompt, o, validated)
+            m = validate_inputs(context, prompt, o, validated)
             valid = m[0]
             reasons = m[1]
         except Exception as ex:
diff --git a/execution_context.py b/execution_context.py
new file mode 100644
index 00000000000..cdc2edd1af5
--- /dev/null
+++ b/execution_context.py
@@ -0,0 +1,152 @@
+import uuid
+
+import diffus.models
+import diffus.repository
+
+
+class Geninfo:
+    def __init__(self, task_id):
+        self.positive_prompt = ''
+        self.negative_prompt = ''
+        self.steps = 0
+        self.sampler = ''
+        self.cfg_scale = 0
+        self.seed = 0
+        self.task_id = task_id
+
+    def dump(self):
+        return {
+            "Prompt": self.positive_prompt,
+            "Negative prompt": self.negative_prompt,
+            "Steps": self.steps,
+            "Sampler": self.sampler,
+            "CFG scale": self.cfg_scale,
+            "Seed": self.seed,
+            "Diffus task ID": self.task_id,
+        }
+
+
+class ExecutionContext:
+    def __init__(self, request, extra_data=None):
+        # request may be an aiohttp request or a plain mapping (e.g. the {}
+        # fallback in diffus/image_gallery.py), so tolerate both
+        self._headers = dict(request.headers) if hasattr(request, 'headers') else dict(request or {})
+        self._extra_data = extra_data
+        self._used_models: dict[str, dict] = {}
+        self._checkpoints_model_base = ""
+        self._task_id = self._headers.get('x-task-id', str(uuid.uuid4()))
+
+        self._geninfo = Geninfo(self._task_id)
+
+    def validate_model(self, model_type, model_name, model_info=None):
+        if model_type not in diffus.models.FAVORITE_MODEL_TYPES:
+            return
+        if model_type not in self._used_models:
+            self._used_models[model_type] = {}
+        if model_name not in self._used_models[model_type]:
+            if not model_info:
+                model_info = diffus.repository.get_favorite_model_full_path(self.user_id, model_type, model_name)
+            self._used_models[model_type][model_name] = model_info
+
+    def get_model(self, model_type, model_name):
+        return self._used_models.get(model_type, {}).get(model_name, None)
+
+    @property
+    def task_id(self):
+        return self._task_id
+
+    @property
+    def geninfo(self):
+        return self._geninfo.dump()
+
+    @property
+    def positive_prompt(self):
+        return self._geninfo.positive_prompt
+
+    @property
+    def negative_prompt(self):
+        return self._geninfo.negative_prompt
+
+    @property
+    def loaded_model_ids(self):
+        result = []
+        for model_type in diffus.models.FAVORITE_MODEL_TYPES:
+            result += [model_info.id for model_info in self._used_models.get(model_type, {}).values()]
+        return result
+
+    @property
+    def checkpoints_model_base(self):
+        for model_info in self.loaded_checkpoints:
+            if model_info.base:
+                return model_info.base
+        # return self._checkpoints_model_base
+        return None
+
+    @checkpoints_model_base.setter
+    def checkpoints_model_base(self, model_base):
+        self._checkpoints_model_base = model_base
+
+    @property
+    def loaded_checkpoints(self):
+        return [model_info for model_info in self._used_models.get('checkpoints', {}).values() if model_info]
+
+    @property
+    def loaded_loras(self):
+        return [
+            model_info for model_info in self._used_models.get('loras', {}).values() if model_info
+        ] + [
+            model_info for model_info in self._used_models.get('lycoris', {}).values() if model_info
+        ]
+
+    @property
+    def user_hash(self):
+        if self._headers:
+            return self._headers.get('X-Diffus-User-Hash', None) or self._headers.get('x-diffus-user-hash', '')
+        else:
+            return ''
+
+    @property
+    def user_id(self):
+        if self._headers:
+            return self._headers.get('User-Id', None) or 
self._headers.get('user-id', '') + else: + return '' + + @property + def extra_data(self): + return self._extra_data or {} + + @staticmethod + def _get_origin_text_from_tokens(tokens): + return [ + t[1]["_origin_text_"] for t in tokens if + len(t) > 1 and isinstance(t[1], dict) and "_origin_text_" in t[1] and t[1]["_origin_text_"] + ] + + def set_geninfo( + self, + positive_prompt={}, + negative_prompt={}, + steps=0, + sampler='', + cfg_scale=0, + seed=0, + ): + self._geninfo.positive_prompt = self._concat_prompt( + self._geninfo.positive_prompt, + self._get_origin_text_from_tokens(positive_prompt) + ) + self._geninfo.negative_prompt = self._concat_prompt( + self._geninfo.negative_prompt, + self._get_origin_text_from_tokens(negative_prompt) + ) + self._geninfo.steps = steps + self._geninfo.sampler = sampler + self._geninfo.cfg_scale = cfg_scale + self._geninfo.seed = seed + + @staticmethod + def _concat_prompt(prompt_1: str, prompt_2: list[str]): + if prompt_2: + return prompt_1 + " " + " ".join(prompt_2) + else: + return prompt_1 diff --git a/folder_paths.py b/folder_paths.py index 0c9e9f15dae..dd80821dd2b 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,22 @@ from __future__ import annotations - +import datetime import os +import shutil import time import mimetypes import logging from typing import Set, List, Dict, Tuple, Literal from collections.abc import Collection + +import diffus.repository +import diffus.models +import execution_context + + supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} + folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {} base_path = os.path.dirname(os.path.realpath(__file__)) @@ -27,11 +35,12 @@ folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) +folder_names_and_paths["ipadapter"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) -folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set()) +folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes"), os.path.join(base_path, "custom_nodes_builtin")], set()) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) @@ -40,9 +49,9 @@ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") -temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") -input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") -user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") +# temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") +# input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +# user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} @@ -85,52 +94,95 @@ def map_legacy(folder_name: str) -> str: "clip": 
"text_encoders"} return legacy.get(folder_name, folder_name) -if not os.path.exists(input_directory): - try: - os.makedirs(input_directory) - except: - logging.error("Failed to create input directory") +# if not os.path.exists(input_directory): +# try: +# os.makedirs(input_directory) +# except: +# logging.error("Failed to create input directory") + +def get_models_dir(user_hash): + return models_dir def set_output_directory(output_dir: str) -> None: global output_directory output_directory = output_dir + raise Exception("unsupported") def set_temp_directory(temp_dir: str) -> None: global temp_directory temp_directory = temp_dir + raise Exception("unsupported") def set_input_directory(input_dir: str) -> None: global input_directory input_directory = input_dir + raise Exception("unsupported") -def get_output_directory() -> str: - global output_directory - return output_directory -def get_temp_directory() -> str: - global temp_directory - return temp_directory +def get_output_directory(user_hash): + if not user_hash: + import traceback + import sys + traceback.print_stack(file=sys.stdout) + raise Exception("missed user_hash from get_output_directory") + return os.path.join(output_directory, get_relative_output_directory(user_hash)) + + +def get_relative_output_directory(user_hash): + return os.path.join(user_hash, "outputs", "comfyui", datetime.datetime.now().strftime("%Y-%m-%d")) + + +def _get_comfyui_user_data_base(user_hash): + return os.path.join(output_directory, user_hash, "comfyui") -def get_input_directory() -> str: - global input_directory - return input_directory -def get_user_directory() -> str: - return user_directory +def get_temp_directory(user_hash): + if not user_hash: + import traceback + import sys + traceback.print_stack(file=sys.stdout) + raise Exception("missed user_hash from get_temp_directory") + return os.path.join(_get_comfyui_user_data_base(user_hash), "temp") + + +def get_input_directory(user_hash): + if not user_hash: + import traceback + import sys + traceback.print_stack(file=sys.stdout) + raise Exception("missed user_hash from get_input_directory") + d = os.path.join(_get_comfyui_user_data_base(user_hash), "input") + if not os.path.exists(d): + os.makedirs(d, exist_ok=True) + return d + + +def clear_input_directory(user_hash): + d = get_input_directory(user_hash) + if os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + return d + +def get_user_directory(user_hash) -> str: + d = os.path.join(_get_comfyui_user_data_base(user_hash), "user") + if not os.path.exists(d): + os.makedirs(d, exist_ok=True) + return d def set_user_directory(user_dir: str) -> None: global user_directory user_directory = user_dir + raise Exception("unsupported") #NOTE: used in http server so don't put folders that should not be accessed remotely -def get_directory_by_type(type_name: str) -> str | None: +def get_directory_by_type(type_name: str, user_hash: str) -> str | None: if type_name == "output": - return get_output_directory() + return get_output_directory(user_hash) if type_name == "temp": - return get_temp_directory() + return get_temp_directory(user_hash) if type_name == "input": - return get_input_directory() + return get_input_directory(user_hash) return None def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]: @@ -158,15 +210,15 @@ def filter_files_content_types(files: List[str], content_types: Literal["image", # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use 
default_path as base_dir -def annotated_filepath(name: str) -> tuple[str, str | None]: +def annotated_filepath(name: str, user_hash) -> tuple[str, str | None]: if name.endswith("[output]"): - base_dir = get_output_directory() + base_dir = get_output_directory(user_hash) name = name[:-9] elif name.endswith("[input]"): - base_dir = get_input_directory() + base_dir = get_input_directory(user_hash) name = name[:-8] elif name.endswith("[temp]"): - base_dir = get_temp_directory() + base_dir = get_temp_directory(user_hash) name = name[:-7] else: return name, None @@ -174,23 +226,22 @@ def annotated_filepath(name: str) -> tuple[str, str | None]: return name, base_dir -def get_annotated_filepath(name: str, default_dir: str | None=None) -> str: - name, base_dir = annotated_filepath(name) - +def get_annotated_filepath(name: str, user_hash, default_dir: str | None=None) -> str: + name, base_dir = annotated_filepath(name, user_hash) if base_dir is None: if default_dir is not None: base_dir = default_dir else: - base_dir = get_input_directory() # fallback path + base_dir = get_input_directory(user_hash) # fallback path return os.path.join(base_dir, name) -def exists_annotated_filepath(name) -> bool: - name, base_dir = annotated_filepath(name) +def exists_annotated_filepath(name, user_hash) -> bool: + name, base_dir = annotated_filepath(name, user_hash) if base_dir is None: - base_dir = get_input_directory() # fallback path + base_dir = get_input_directory(user_hash) # fallback path filepath = os.path.join(base_dir, name) return os.path.exists(filepath) @@ -257,9 +308,16 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -def get_full_path(folder_name: str, filename: str) -> str | None: - global folder_names_and_paths - folder_name = map_legacy(folder_name) +def get_full_path(context: execution_context.ExecutionContext, folder_name: str, filename: str) -> str | None: + if folder_name in diffus.models.FAVORITE_MODEL_TYPES: + model_info = context.get_model(folder_name, filename) + if not model_info: + model_info = diffus.repository.get_favorite_model_full_path(context.user_id, folder_name, filename) + context.validate_model(folder_name, filename, model_info) + return model_info + else: + global folder_names_and_paths + folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: return None folders = folder_names_and_paths[folder_name] @@ -274,8 +332,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None: return None -def get_full_path_or_raise(folder_name: str, filename: str) -> str: - full_path = get_full_path(folder_name, filename) +def get_full_path_or_raise(context: execution_context.ExecutionContext, folder_name: str, filename: str) -> str: + full_path = get_full_path(context, folder_name, filename) if full_path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") return full_path @@ -298,7 +356,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float] strong_cache = cache_helper.get(folder_name) if strong_cache is not None: return strong_cache - + global filename_list_cache global folder_names_and_paths folder_name = map_legacy(folder_name) @@ -320,8 +378,11 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float] return out -def get_filename_list(folder_name: str) -> list[str]: - folder_name = map_legacy(folder_name) +def get_filename_list(context: execution_context.ExecutionContext, folder_name: str) -> list[str]: + if 
folder_name in diffus.models.FAVORITE_MODEL_TYPES: + return diffus.repository.list_favorite_model_by_model_type(context.user_id, folder_name) + else: + folder_name = map_legacy(folder_name) out = cached_filename_list_(folder_name) if out is None: out = get_filename_list_(folder_name) @@ -330,6 +391,7 @@ def get_filename_list(folder_name: str) -> list[str]: cache_helper.set(folder_name, out) return list(out[0]) + def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]: def map_filename(filename: str) -> tuple[int, str]: prefix_len = len(os.path.basename(filename_prefix)) diff --git a/latent_preview.py b/latent_preview.py index d60e68d5512..165301004fa 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -58,7 +58,7 @@ def decode_latent_to_preview(self, x0): return preview_to_image(latent_image) -def get_previewer(device, latent_format): +def get_previewer(context, device, latent_format): previewer = None method = args.preview_method if method != LatentPreviewMethod.NoPreviews: @@ -66,11 +66,11 @@ def get_previewer(device, latent_format): taesd_decoder_path = None if latent_format.taesd_decoder_name is not None: taesd_decoder_path = next( - (fn for fn in folder_paths.get_filename_list("vae_approx") + (fn for fn in folder_paths.get_filename_list(context, "vae_approx") if fn.startswith(latent_format.taesd_decoder_name)), "" ) - taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path) + taesd_decoder_path = folder_paths.get_full_path(context, "vae_approx", taesd_decoder_path) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -87,12 +87,12 @@ def get_previewer(device, latent_format): previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) return previewer -def prepare_callback(model, steps, x0_output_dict=None): +def prepare_callback(context, model, steps, x0_output_dict=None): preview_format = "JPEG" if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = get_previewer(model.load_device, model.model.latent_format) + previewer = get_previewer(context, model.load_device, model.model.latent_format) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): diff --git a/main.py b/main.py index c2c2ff8c878..5675c444305 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,8 @@ import comfy.options +import diffus.message +import diffus.task_queue +import diffus.system_monitor + comfy.options.enable_args_parsing() import os @@ -105,7 +109,7 @@ def cuda_malloc_warning(): if cuda_malloc_warning: logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") -def prompt_worker(q, server): +def prompt_worker(q, server, task_dispatcher): e = execution.PromptExecutor(server, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False @@ -122,8 +126,30 @@ def prompt_worker(q, server): execution_start_time = time.perf_counter() prompt_id = item[1] server.last_prompt_id = prompt_id - - e.execute(item[2], prompt_id, item[3], item[4]) + context = item[-1] + extra_data = item[3] + + begin = time.time() + monitor_error = None + try: + with diffus.system_monitor.monitor_call_context( + task_dispatcher, + extra_data, + 'comfy', + 'comfyui', + prompt_id, + is_intermediate=False, + only_available_for=['basic', 'plus', 'pro', 'api'], + ) as result_encoder: + e.execute(context, item[2], prompt_id, extra_data, 
item[4]) + result_encoder(e.success, e.status_messages) + except diffus.system_monitor.MonitorException as ex: + monitor_error = ex + except diffus.system_monitor.MonitorTierMismatchedException as ex: + monitor_error = ex + except Exception as ex: + logging.exception(ex) + end = time.time() need_gc = True q.task_done(item_id, e.history_result, @@ -133,10 +159,14 @@ def prompt_worker(q, server): messages=e.status_messages)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) + if monitor_error is not None: + server.send_sync("monitor_error", { "node": None, 'prompt_id': prompt_id, 'used_time': end - begin, 'message': diffus.system_monitor.make_monitor_error_message(monitor_error) }, server.client_id) + else: + server.send_sync("finished", { "node": None, 'prompt_id': prompt_id, 'used_time': end - begin, 'subscription_consumption': extra_data['subscription_consumption'] }, server.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time - logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) + logging.info(f"Prompt executed in {execution_time:.2f} seconds, client_id: {server.client_id}") flags = q.get_flags() free_memory = flags.get("free_memory", False) @@ -188,7 +218,7 @@ def cleanup_temp(): temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") logging.info(f"Setting temp directory to: {temp_dir}") folder_paths.set_temp_directory(temp_dir) - cleanup_temp() + # cleanup_temp() if args.windows_standalone_build: try: @@ -201,6 +231,7 @@ def cleanup_temp(): asyncio.set_event_loop(loop) server = server.PromptServer(loop) q = execution.PromptQueue(server) + task_dispatcher = diffus.task_queue.TaskDispatcher(q, server.routes) extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): @@ -217,7 +248,7 @@ def cleanup_temp(): server.add_routes() hijack_progress(server) - threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(q, server, task_dispatcher,)).start() if args.output_directory: output_dir = os.path.abspath(args.output_directory) @@ -225,17 +256,17 @@ def cleanup_temp(): folder_paths.set_output_directory(output_dir) #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. 
nodes - folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) - folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) - folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) - folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) - folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) + # folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) + # folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) + # folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) + # folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) + # folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) if args.input_directory: input_dir = os.path.abspath(args.input_directory) logging.info(f"Setting input directory to: {input_dir}") folder_paths.set_input_directory(input_dir) - + if args.user_directory: user_dir = os.path.abspath(args.user_directory) logging.info(f"Setting user directory to: {user_dir}") @@ -244,7 +275,7 @@ def cleanup_temp(): if args.quick_test_for_ci: exit(0) - os.makedirs(folder_paths.get_temp_directory(), exist_ok=True) + # os.makedirs(folder_paths.get_temp_directory(), exist_ok=True) call_on_start = None if args.auto_launch: def startup_server(scheme, address, port): @@ -256,10 +287,16 @@ def startup_server(scheme, address, port): webbrowser.open(f"{scheme}://{address}:{port}") call_on_start = startup_server + def on_startup(scheme, address, port): + if args.auto_launch: + startup_server(scheme, address, port) + task_dispatcher.start() + try: loop.run_until_complete(server.setup()) - loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=on_startup)) except KeyboardInterrupt: logging.info("\nStopped server") + task_dispatcher.stop() cleanup_temp() diff --git a/node_helpers.py b/node_helpers.py index 4b38bfff809..ee717ce016a 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,9 +1,14 @@ +import inspect import hashlib from comfy.cli_args import args + from PIL import ImageFile, UnidentifiedImageError +import execution_context + + def conditioning_set_values(conditioning, values={}): c = [] for t in conditioning: @@ -25,7 +30,31 @@ def pillow(fn, arg): finally: if prev_value is not None: ImageFile.LOAD_TRUNCATED_IMAGES = prev_value - return x + return x + + +def get_node_input_types(context: execution_context.ExecutionContext, node_class): + signature = inspect.signature(node_class.INPUT_TYPES) + positional_args = [] + inputs = [] + for i, param in enumerate(signature.parameters.values()): + if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + break + positional_args.append(param) + for i, param in enumerate(positional_args): + if (param.annotation == str or param.annotation == "str") and param.name == 'user_hash': + inputs.insert(i, context.user_hash) + if param.annotation == execution_context.ExecutionContext or 
param.annotation == "execution_context.ExecutionContext": + inputs.insert(i, context) + while len(inputs) < len(positional_args): + i = len(inputs) + param = positional_args[i] + if param.default == param.empty: + inputs.append(None) + else: + inputs.append(param.default) + return node_class.INPUT_TYPES(*inputs) + def hasher(): hashfuncs = { @@ -35,3 +64,4 @@ def hasher(): "sha512": hashlib.sha512 } return hashfuncs[args.default_hashing_function] + diff --git a/nodes.py b/nodes.py index 1cb4b5a5af0..6f713aaaa73 100644 --- a/nodes.py +++ b/nodes.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +import inspect import os import sys import json @@ -17,6 +18,8 @@ import numpy as np import safetensors.torch +import execution_context + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) import comfy.diffusers_load @@ -44,6 +47,7 @@ def before_node_execution(): def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) + MAX_RESOLUTION=16384 class CLIPTextEncode(ComfyNodeABC): @@ -51,7 +55,7 @@ class CLIPTextEncode(ComfyNodeABC): def INPUT_TYPES(s) -> InputTypeDict: return { "required": { - "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), + "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) } } @@ -64,8 +68,8 @@ def INPUT_TYPES(s) -> InputTypeDict: def encode(self, clip, text): tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens), ) - + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"_origin_text_": text}), ) + class ConditioningCombine: @classmethod @@ -269,8 +273,8 @@ class VAEDecode: @classmethod def INPUT_TYPES(s): return { - "required": { - "samples": ("LATENT", {"tooltip": "The latent to be decoded."}), + "required": { + "samples": ("LATENT", {"tooltip": "The latent to be decoded."}), "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}) } } @@ -430,13 +434,13 @@ def encode(self, positive, negative, pixels, vae, mask, noise_mask=True): class SaveLatent: def __init__(self): - self.output_dir = folder_paths.get_output_directory() + pass @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "user_hash": "USER_HASH"}, } RETURN_TYPES = () FUNCTION = "save" @@ -445,8 +449,9 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" - def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, user_hash=''): + output_dir = folder_paths.get_output_directory(user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) # support save metadata for latent sharing prompt_info = "" @@ -481,18 +486,19 @@ def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=No class LoadLatent: @classmethod - def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() + def INPUT_TYPES(s, user_hash: str = ''): + input_dir = 
folder_paths.get_input_directory(user_hash) files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] - return {"required": {"latent": [sorted(files), ]}, } + return {"required": {"latent": [sorted(files), ]}, + "hidden": {"user_hash": "USER_HASH"}} CATEGORY = "_for_testing" RETURN_TYPES = ("LATENT", ) FUNCTION = "load" - def load(self, latent): - latent_path = folder_paths.get_annotated_filepath(latent) + def load(self, latent, user_hash): + latent_path = folder_paths.get_annotated_filepath(latent, user_hash) latent = safetensors.torch.load_file(latent_path, device="cpu") multiplier = 1.0 if "latent_format_version_0" not in latent: @@ -501,55 +507,69 @@ def load(self, latent): return (samples, ) @classmethod - def IS_CHANGED(s, latent): - image_path = folder_paths.get_annotated_filepath(latent) + def IS_CHANGED(s, latent, user_hash): + image_path = folder_paths.get_annotated_filepath(latent, user_hash) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, latent): - if not folder_paths.exists_annotated_filepath(latent): + def VALIDATE_INPUTS(s, latent, user_hash): + if not folder_paths.exists_annotated_filepath(latent, user_hash): return "Invalid latent file: {}".format(latent) return True class CheckpointLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "config_name": (folder_paths.get_filename_list(context, "configs"), ), + "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "advanced/loaders" DEPRECATED = True - def load_checkpoint(self, config_name, ckpt_name): - config_path = folder_paths.get_full_path("configs", config_name) - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + @classmethod + def VALIDATE_INPUTS(cls, config_name, ckpt_name, context: execution_context.ExecutionContext=None): + context.validate_model("checkpoints", ckpt_name) + return True + + def load_checkpoint(self, config_name, ckpt_name, context=None): + config_path = folder_paths.get_full_path(context, "configs", config_name) + ckpt_path = folder_paths.get_full_path_or_raise(context, "checkpoints", ckpt_name) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CheckpointLoaderSimple: @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { - "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), + "required": { + "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } RETURN_TYPES = ("MODEL", "CLIP", "VAE") - OUTPUT_TOOLTIPS = ("The model used for denoising latents.", - "The CLIP model used for encoding text prompts.", + OUTPUT_TOOLTIPS = ("The model used for denoising latents.", + "The CLIP model used for encoding text prompts.", "The VAE model used for encoding and decoding images to and from latent 
space.") FUNCTION = "load_checkpoint" CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." - def load_checkpoint(self, ckpt_name): - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + @classmethod + def VALIDATE_INPUTS(cls, ckpt_name, context: execution_context.ExecutionContext = None): + context.validate_model("checkpoints", ckpt_name) + return True + + def load_checkpoint(self, ckpt_name, context: execution_context.ExecutionContext = None): + ckpt_path = folder_paths.get_full_path_or_raise(context, "checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out[:3] @@ -582,16 +602,22 @@ def load_checkpoint(self, model_path, output_vae=True, output_clip=True): class unCLIPCheckpointLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "ckpt_name": (folder_paths.get_filename_list(context, "checkpoints"), ), + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" CATEGORY = "loaders" - def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + @classmethod + def VALIDATE_INPUTS(cls, ckpt_name, output_vae=True, output_clip=True, context: execution_context.ExecutionContext=None): + context.validate_model("checkpoints", ckpt_name) + return True + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, context: execution_context.ExecutionContext=None): + ckpt_path = folder_paths.get_full_path_or_raise(context, "checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out @@ -616,17 +642,20 @@ def __init__(self): self.loaded_lora = None @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return { - "required": { + "required": { "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), - "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "lora_name": (folder_paths.get_filename_list(context, "loras"), {"tooltip": "The name of the LoRA."}), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + }, + "hidden": { + "context": "EXECUTION_CONTEXT" } } - + RETURN_TYPES = ("MODEL", "CLIP") OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") FUNCTION = "load_lora" @@ -634,11 +663,16 @@ def INPUT_TYPES(s): CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." 
- def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + @classmethod + def VALIDATE_INPUTS(cls, model, clip, lora_name, strength_model, strength_clip, context: execution_context.ExecutionContext): + context.validate_model("loras", lora_name) + return True + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip, context: execution_context.ExecutionContext): if strength_model == 0 and strength_clip == 0: return (model, clip) - lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora_path = folder_paths.get_full_path_or_raise(context, "loras", lora_name) lora = None if self.loaded_lora is not None: if self.loaded_lora[0] == lora_path: @@ -657,22 +691,28 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip): class LoraLoaderModelOnly(LoraLoader): @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return {"required": { "model": ("MODEL",), - "lora_name": (folder_paths.get_filename_list("loras"), ), + "lora_name": (folder_paths.get_filename_list(context, "loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), - }} + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL",) FUNCTION = "load_lora_model_only" - def load_lora_model_only(self, model, lora_name, strength_model): - return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + @classmethod + def VALIDATE_INPUTS(cls, model, lora_name, strength_model, context: execution_context.ExecutionContext): + context.validate_model("loras", lora_name) + return True + + def load_lora_model_only(self, model, lora_name, strength_model, context: execution_context.ExecutionContext): + return (self.load_lora(model, None, lora_name, strength_model, 0, context)[0],) class VAELoader: @staticmethod - def vae_list(): - vaes = folder_paths.get_filename_list("vae") - approx_vaes = folder_paths.get_filename_list("vae_approx") + def vae_list(context: execution_context.ExecutionContext): + vaes = folder_paths.get_filename_list(context, "vae") + approx_vaes = folder_paths.get_filename_list(context, "vae_approx") sdxl_taesd_enc = False sdxl_taesd_dec = False sd1_taesd_enc = False @@ -710,18 +750,18 @@ def vae_list(): return vaes @staticmethod - def load_taesd(name): + def load_taesd(context, name): sd = {} - approx_vaes = folder_paths.get_filename_list("vae_approx") + approx_vaes = folder_paths.get_filename_list(context, "vae_approx") encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) - enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder)) + enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(context, "vae_approx", encoder)) for k in enc: sd["taesd_encoder.{}".format(k)] = enc[k] - dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder)) + dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(context, "vae_approx", decoder)) for k in dec: sd["taesd_decoder.{}".format(k)] = dec[k] @@ -740,51 +780,54 @@ def load_taesd(name): return sd @classmethod - def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "vae_name": (s.vae_list(context), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("VAE",) FUNCTION 
= "load_vae" CATEGORY = "loaders" #TODO: scale factor? - def load_vae(self, vae_name): + def load_vae(self, vae_name, context: execution_context.ExecutionContext): if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: - sd = self.load_taesd(vae_name) + sd = self.load_taesd(context, vae_name) else: - vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) + vae_path = folder_paths.get_full_path_or_raise(context, "vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) return (vae,) class ControlNetLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "control_net_name": (folder_paths.get_filename_list(context, "controlnet"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" CATEGORY = "loaders" - def load_controlnet(self, control_net_name): - controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) + def load_controlnet(self, control_net_name, context: execution_context.ExecutionContext): + controlnet_path = folder_paths.get_full_path_or_raise(context, "controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s, context: execution_context.ExecutionContext): return {"required": { "model": ("MODEL",), - "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} + "control_net_name": (folder_paths.get_filename_list(context, "controlnet"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" CATEGORY = "loaders" - def load_controlnet(self, model, control_net_name): - controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) + def load_controlnet(self, model, control_net_name, context: execution_context.ExecutionContext): + controlnet_path = folder_paths.get_full_path_or_raise(context, "controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) return (controlnet,) @@ -872,16 +915,22 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta class UNETLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), - "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],) - }} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "unet_name": (folder_paths.get_filename_list(context, "diffusion_models"), ), + "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],) + }, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "advanced/loaders" - def load_unet(self, unet_name, weight_dtype): + @classmethod + def VALIDATE_INPUTS(cls, unet_name, weight_dtype, context: execution_context.ExecutionContext): + context.validate_model(context, "diffusion_models", unet_name) + return True + + def load_unet(self, unet_name, weight_dtype, context: execution_context.ExecutionContext): model_options = {} if weight_dtype == "fp8_e4m3fn": model_options["dtype"] = torch.float8_e4m3fn @@ -891,16 +940,17 @@ def load_unet(self, unet_name, weight_dtype): elif weight_dtype == "fp8_e5m2": model_options["dtype"] = 
 class CLIPLoader:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
+    def INPUT_TYPES(s, context: execution_context.ExecutionContext):
+        return {"required": { "clip_name": (folder_paths.get_filename_list(context, "text_encoders"), ),
                               "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
-                              }}
+                              },
+                "hidden": {"context": "EXECUTION_CONTEXT"}}

     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -908,7 +958,7 @@ def INPUT_TYPES(s):

     DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"

-    def load_clip(self, clip_name, type="stable_diffusion"):
+    def load_clip(self, clip_name, type="stable_diffusion", context=None):
         if type == "stable_cascade":
             clip_type = comfy.sd.CLIPType.STABLE_CASCADE
         elif type == "sd3":
@@ -922,17 +972,19 @@ def load_clip(self, clip_name, type="stable_diffusion"):
         else:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

-        clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
+        clip_path = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)

 class DualCLIPLoader:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
-                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
+    def INPUT_TYPES(s, context: execution_context.ExecutionContext):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list(context, "text_encoders"), ),
+                              "clip_name2": (folder_paths.get_filename_list(context, "text_encoders"), ),
                               "type": (["sdxl", "sd3", "flux"], ),
-                              }}
+                              },
+                "hidden": {"context": "EXECUTION_CONTEXT"}}
+
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -940,9 +992,9 @@ def INPUT_TYPES(s):

     DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"

-    def load_clip(self, clip_name1, clip_name2, type):
-        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
-        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
+    def load_clip(self, clip_name1, clip_name2, type, context: execution_context.ExecutionContext):
+        clip_path1 = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name1)
+        clip_path2 = folder_paths.get_full_path_or_raise(context, "text_encoders", clip_name2)
         if type == "sdxl":
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
         elif type == "sd3":
@@ -955,16 +1007,17 @@ def load_clip(self, clip_name1, clip_name2, type):

 class CLIPVisionLoader:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
-                             }}
+    def INPUT_TYPES(s, context: execution_context.ExecutionContext):
+        return {"required": { "clip_name": (folder_paths.get_filename_list(context, "clip_vision"), ),
+                             },
+                "hidden": {"context": "EXECUTION_CONTEXT"}}

     RETURN_TYPES = ("CLIP_VISION",)
     FUNCTION = "load_clip"

     CATEGORY = "loaders"

-    def load_clip(self, clip_name):
-        clip_path =
folder_paths.get_full_path_or_raise("clip_vision", clip_name) + def load_clip(self, clip_name, context: execution_context.ExecutionContext): + clip_path = folder_paths.get_full_path_or_raise(context, "clip_vision", clip_name) clip_vision = comfy.clip_vision.load(clip_path) return (clip_vision,) @@ -989,16 +1042,17 @@ def encode(self, clip_vision, image, crop): class StyleModelLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "style_model_name": (folder_paths.get_filename_list(context, "style_models"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" CATEGORY = "loaders" - def load_style_model(self, style_model_name): - style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) + def load_style_model(self, style_model_name, context): + style_model_path = folder_paths.get_full_path_or_raise(context, "style_models", style_model_name) style_model = comfy.sd.load_style_model(style_model_path) return (style_model,) @@ -1059,16 +1113,17 @@ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentati class GLIGENLoader: @classmethod - def INPUT_TYPES(s): - return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}} + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + return {"required": { "gligen_name": (folder_paths.get_filename_list(context, "gligen"), )}, + "hidden": {"context": "EXECUTION_CONTEXT"}} RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" CATEGORY = "loaders" - def load_gligen(self, gligen_name): - gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name) + def load_gligen(self, gligen_name, context): + gligen_path = folder_paths.get_full_path_or_raise(context, "gligen", gligen_name) gligen = comfy.sd.load_gligen(gligen_path) return (gligen,) @@ -1110,7 +1165,7 @@ def __init__(self): @classmethod def INPUT_TYPES(s): return { - "required": { + "required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}), "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}) @@ -1415,7 +1470,7 @@ def set_mask(self, samples, mask): s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): +def common_ksampler(context, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) @@ -1429,7 +1484,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] - callback = latent_preview.prepare_callback(model, steps) + callback = latent_preview.prepare_callback(context, model, steps) disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = 
comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, @@ -1453,8 +1508,10 @@ def INPUT_TYPES(s): "negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}), "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}), + }, + "hidden": { + "context": "EXECUTION_CONTEXT"} } - } RETURN_TYPES = ("LATENT",) OUTPUT_TOOLTIPS = ("The denoised latent.",) @@ -1463,8 +1520,8 @@ def INPUT_TYPES(s): CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." - def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, context=None): + return common_ksampler(context, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) class KSamplerAdvanced: @classmethod @@ -1483,7 +1540,8 @@ def INPUT_TYPES(s): "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "return_with_leftover_noise": (["disable", "enable"], ), - } + }, + "hidden": {"context": "EXECUTION_CONTEXT"} } RETURN_TYPES = ("LATENT",) @@ -1491,18 +1549,17 @@ def INPUT_TYPES(s): CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, context=None): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) + return common_ksampler(context, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) class SaveImage: def __init__(self): - self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" self.compress_level = 4 @@ -1515,7 +1572,7 @@ def INPUT_TYPES(s): "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) }, "hidden": { - "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT", "user_hash": "USER_HASH" }, } @@ -1527,10 +1584,17 @@ def INPUT_TYPES(s): CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." - def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, context: execution_context.ExecutionContext = None, user_hash: str = ''): filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + if not user_hash: + user_hash = context.user_hash + if self.type == "temp": + output_dir = folder_paths.get_temp_directory(user_hash) + else: + output_dir = folder_paths.get_output_directory(user_hash) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]) results = list() + ts = time.time() for (batch_number, image) in enumerate(images): i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) @@ -1544,7 +1608,7 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi metadata.add_text(x, json.dumps(extra_pnginfo[x])) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.png" + file = f"{filename_with_batch_num}_{counter:05}_{ts}.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, @@ -1557,7 +1621,6 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi class PreviewImage(SaveImage): def __init__(self): - self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 @@ -1566,24 +1629,27 @@ def __init__(self): def INPUT_TYPES(s): return {"required": {"images": ("IMAGE", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": + {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "context": "EXECUTION_CONTEXT", "user_hash": "USER_HASH"}, } class LoadImage: @classmethod - def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + input_dir = folder_paths.get_input_directory(context.user_hash) files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (sorted(files), {"image_upload": True})}, + "hidden": { + "context": "EXECUTION_CONTEXT"}, } CATEGORY = "image" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - def load_image(self, image): - image_path = folder_paths.get_annotated_filepath(image) + def load_image(self, image, context: execution_context.ExecutionContext): + image_path = folder_paths.get_annotated_filepath(image, context.user_hash) img = node_helpers.pillow(Image.open, image_path) @@ -1627,16 +1693,16 @@ def load_image(self, image): return (output_image, output_mask) @classmethod - def 
IS_CHANGED(s, image): - image_path = folder_paths.get_annotated_filepath(image) + def IS_CHANGED(s, image, context: execution_context.ExecutionContext): + image_path = folder_paths.get_annotated_filepath(image, context.user_hash) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): + def VALIDATE_INPUTS(s, image, context: execution_context.ExecutionContext): + if not folder_paths.exists_annotated_filepath(image, context.user_hash): return "Invalid image file: {}".format(image) return True @@ -1644,20 +1710,22 @@ def VALIDATE_INPUTS(s, image): class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] @classmethod - def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() + def INPUT_TYPES(s, context: execution_context.ExecutionContext): + input_dir = folder_paths.get_input_directory(context.user_hash) files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } + "channel": (s._color_channels, ), }, + "hidden": + {"user_hash": "USER_HASH"}, } CATEGORY = "mask" RETURN_TYPES = ("MASK",) FUNCTION = "load_image" - def load_image(self, image, channel): - image_path = folder_paths.get_annotated_filepath(image) + def load_image(self, image, channel, user_hash): + image_path = folder_paths.get_annotated_filepath(image, user_hash) i = node_helpers.pillow(Image.open, image_path) i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.getbands() != ("R", "G", "B", "A"): @@ -1676,16 +1744,16 @@ def load_image(self, image, channel): return (mask.unsqueeze(0),) @classmethod - def IS_CHANGED(s, image, channel): - image_path = folder_paths.get_annotated_filepath(image) + def IS_CHANGED(s, image, channel, user_hash): + image_path = folder_paths.get_annotated_filepath(image, user_hash) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): + def VALIDATE_INPUTS(s, image, user_hash): + if not folder_paths.exists_annotated_filepath(image, user_hash): return "Invalid image file: {}".format(image) return True @@ -2179,5 +2247,5 @@ def init_extra_nodes(init_custom_nodes=True): else: logging.warning("Please do a: pip install -r requirements.txt") logging.warning("") - + return import_failed diff --git a/requirements.txt b/requirements.txt index 4c2c0b2b221..2187cd2da7b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,8 @@ psutil kornia>=0.7.1 spandrel soundfile +PyMySQL +SQLAlchemy +pydantic +requests +redis diff --git a/server.py b/server.py index 5e86558d4a7..39e970f3fea 100644 --- a/server.py +++ b/server.py @@ -3,6 +3,9 @@ import asyncio import traceback +import diffus.message +import execution_context +import node_helpers import nodes import folder_paths import execution @@ -20,6 +23,7 @@ import aiohttp from aiohttp import web +from aiohttp import hdrs import logging import mimetypes @@ -64,6 +68,28 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response + +@web.middleware +async def compress_middleware( + request: web.Request, handler +) -> web.StreamResponse: + + accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() + + if 
web.ContentCoding.gzip.value in accept_encoding: + compressor = web.ContentCoding.gzip.value + elif web.ContentCoding.deflate.value in accept_encoding: + compressor = web.ContentCoding.deflate.value + else: + return await handler(request) + + resp = await handler(request) + if resp.content_length is not None and resp.content_length > 0: + resp.headers[hdrs.CONTENT_ENCODING] = compressor + resp.enable_compression() + return resp + + def create_cors_middleware(allowed_origin: str): @web.middleware async def cors_middleware(request: web.Request, handler): @@ -144,6 +170,8 @@ async def origin_only_middleware(request: web.Request, handler): return origin_only_middleware class PromptServer(): + instance = None + def __init__(self, loop): PromptServer.instance = self @@ -159,7 +187,7 @@ def __init__(self, loop): self.client_session:Optional[aiohttp.ClientSession] = None self.number = 0 - middlewares = [cache_control] + middlewares = [cache_control, compress_middleware] if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) else: @@ -181,6 +209,8 @@ def __init__(self, loop): self.on_prompt_handlers = [] + self.dffis_message_queue = diffus.message.MessageQueue() + @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse() @@ -208,6 +238,14 @@ async def websocket_handler(request): self.sockets.pop(sid, None) return ws + @routes.get("/comfy") + async def get_root(request): + return web.FileResponse(os.path.join(self.web_root, "index.html")) + + @routes.get("/comfy/") + async def get_root_slash(request): + raise web.HTTPFound('../comfy') + @routes.get("/") async def get_root(request): response = web.FileResponse(os.path.join(self.web_root, "index.html")) @@ -217,10 +255,10 @@ async def get_root(request): return response @routes.get("/embeddings") - def get_embeddings(self): - embeddings = folder_paths.get_filename_list("embeddings") + async def get_embeddings(request): + embeddings = folder_paths.get_filename_list(execution_context.ExecutionContext(request), "embeddings") return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) - + @routes.get("/models") def list_model_types(request): model_types = list(folder_paths.folder_names_and_paths.keys()) @@ -230,14 +268,17 @@ def list_model_types(request): @routes.get("/models/{folder}") async def get_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) - files = folder_paths.get_filename_list(folder) + context = execution_context.ExecutionContext(request) + files = folder_paths.get_filename_list(context, folder) return web.json_response(files) @routes.get("/extensions") async def get_extensions(request): files = glob.glob(os.path.join( + glob.escape(self.web_root), 'extensions_builtin/**/*.js'), recursive=True) + files += glob.glob(os.path.join( glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) @@ -249,22 +290,27 @@ async def get_extensions(request): return web.json_response(extensions) - def get_dir_by_type(dir_type): + @routes.get("/health/check") + async def health_check(request): + return web.json_response({'message': 'OK'}) + + + def get_dir_by_type(dir_type, user_hash): if dir_type is None: dir_type = "input" if dir_type == "input": - type_dir = folder_paths.get_input_directory() + type_dir = 
             elif dir_type == "temp":
-                type_dir = folder_paths.get_temp_directory()
+                type_dir = folder_paths.get_temp_directory(user_hash)
             elif dir_type == "output":
-                type_dir = folder_paths.get_output_directory()
+                type_dir = folder_paths.get_output_directory(user_hash)

             return type_dir, dir_type

         def compare_image_hash(filepath, image):
             hasher = node_helpers.hasher()
-
+
             # function to compare hashes of two images to see if it already exists, fix to #3465
             if os.path.exists(filepath):
                 a = hasher()
@@ -277,13 +323,13 @@ def compare_image_hash(filepath, image):
                 return a.hexdigest() == b.hexdigest()
             return False

-        def image_upload(post, image_save_function=None):
+        def image_upload(context: execution_context.ExecutionContext, post, image_save_function=None):
             image = post.get("image")
             overwrite = post.get("overwrite")
             image_is_duplicate = False

             image_upload_type = post.get("type")
-            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
+            upload_dir, image_upload_type = get_dir_by_type(image_upload_type, context.user_hash)

             if image and image.file:
                 filename = image.filename
@@ -328,16 +374,17 @@ def image_upload(post, image_save_function=None):

         @routes.post("/upload/image")
         async def upload_image(request):
             post = await request.post()
-            return image_upload(post)
+            context = execution_context.ExecutionContext(request=request)
+            return image_upload(context, post)

         @routes.post("/upload/mask")
         async def upload_mask(request):
             post = await request.post()
-
+            context = execution_context.ExecutionContext(request=request)
             def image_save_function(image, post, filepath):
                 original_ref = json.loads(post.get("original_ref"))
-                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
+                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'], context.user_hash)

                 # validation for security: prevent accessing arbitrary path
                 if filename[0] == '/' or '..' in filename:
@@ -345,7 +392,7 @@ def image_save_function(image, post, filepath):

                 if output_dir is None:
                     type = original_ref.get("type", "output")
-                    output_dir = folder_paths.get_directory_by_type(type)
+                    output_dir = folder_paths.get_directory_by_type(type, context.user_hash)

                 if output_dir is None:
                     return web.Response(status=400)
@@ -372,13 +419,14 @@ def image_save_function(image, post, filepath):
                     original_pil.putalpha(new_alpha)
                     original_pil.save(filepath, compress_level=4, pnginfo=metadata)

-            return image_upload(post, image_save_function)
+            return image_upload(context, post, image_save_function)

         @routes.get("/view")
         async def view_image(request):
+            context = execution_context.ExecutionContext(request=request)
             if "filename" in request.rel_url.query:
                 filename = request.rel_url.query["filename"]
-                filename,output_dir = folder_paths.annotated_filepath(filename)
+                filename, output_dir = folder_paths.annotated_filepath(filename, context.user_hash)

                 # validation for security: prevent accessing arbitrary path
                 if filename[0] == '/' or '..' in filename:
@@ -386,8 +434,7 @@ async def view_image(request):

                 if output_dir is None:
                     type = request.rel_url.query.get("type", "output")
-                    output_dir = folder_paths.get_directory_by_type(type)
-
+                    output_dir = folder_paths.get_directory_by_type(type, context.user_hash)
                 if output_dir is None:
                     return web.Response(status=400)

@@ -474,7 +521,8 @@ async def view_metadata(request):
             if not filename.endswith(".safetensors"):
                 return web.Response(status=404)

-            safetensors_path = folder_paths.get_full_path(folder_name, filename)
+            context = execution_context.ExecutionContext(request)
+            safetensors_path = folder_paths.get_full_path(context, folder_name, filename)
             if safetensors_path is None:
                 return web.Response(status=404)
             out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
@@ -524,13 +572,19 @@ async def system_stats(request):
         async def get_prompt(request):
             return web.json_response(self.get_queue_info())

-        def node_info(node_class):
+        def node_info(context: execution_context.ExecutionContext, node_class):
             obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
+            if callable(obj_class.RETURN_TYPES):
+                return_types = obj_class.RETURN_TYPES(context)
+            else:
+                return_types = obj_class.RETURN_TYPES
+
+            node_input_types = node_helpers.get_node_input_types(context, obj_class)
             info = {}
-            info['input'] = obj_class.INPUT_TYPES()
-            info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
-            info['output'] = obj_class.RETURN_TYPES
-            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
+            info['input'] = node_input_types
+            info['input_order'] = {key: list(value.keys()) for (key, value) in node_input_types.items()}
+            info['output'] = return_types
+            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(return_types)
             info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
             info['name'] = node_class
             info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
@@ -558,9 +612,10 @@ def node_info(node_class):
         async def get_object_info(request):
             with folder_paths.cache_helper:
                 out = {}
+                context = execution_context.ExecutionContext(request)
                 for x in nodes.NODE_CLASS_MAPPINGS:
                     try:
-                        out[x] = node_info(x)
+                        out[x] = node_info(context, x)
                     except Exception as e:
                         logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                         logging.error(traceback.format_exc())
@@ -569,31 +624,71 @@ async def get_object_info(request):

         @routes.get("/object_info/{node_class}")
         async def get_object_info_node(request):
             node_class = request.match_info.get("node_class", None)
+            context = execution_context.ExecutionContext(request=request)
             out = {}
             if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
-                out[node_class] = node_info(node_class)
+                out[node_class] = node_info(context, node_class)
             return web.json_response(out)

         @routes.get("/history")
         async def get_history(request):
-            max_items = request.rel_url.query.get("max_items", None)
-            if max_items is not None:
-                max_items = int(max_items)
-            return web.json_response(self.prompt_queue.get_history(max_items=max_items))
+            # max_items = request.rel_url.query.get("max_items", None)
+            # if max_items is not None:
+            #     max_items = int(max_items)
+            # return web.json_response(self.prompt_queue.get_history(max_items=max_items))
+            return web.json_response([])

         @routes.get("/history/{prompt_id}")
         async def get_history(request):
-            prompt_id = request.match_info.get("prompt_id", None)
-            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
+            # prompt_id = request.match_info.get("prompt_id", None)
+            # return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
+            return web.json_response({})
+
+        @routes.delete("/inputs")
+        async def clear_input(request):
+            context = execution_context.ExecutionContext(request=request)
+            folder_paths.clear_input_directory(context.user_hash)
+            await self.send("input_cleared", { "node": None, "user_id": context.user_id }, context.user_id)
+            return web.json_response({
+                "type": "input_cleared",
+                "data": {
+                    "user_id": context.user_id
+                }
+            })

         @routes.get("/queue")
         async def get_queue(request):
             queue_info = {}
-            current_queue = self.prompt_queue.get_current_queue()
-            queue_info['queue_running'] = current_queue[0]
-            queue_info['queue_pending'] = current_queue[1]
+            # current_queue = self.prompt_queue.get_current_queue()
+            queue_info['queue_running'] = []
+            queue_info['queue_pending'] = []
             return web.json_response(queue_info)

+        @routes.post("/prompt/valid")
+        async def post_prompt_valid(request):
+            logging.info("got prompt validation request")
+            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)
+            context = execution_context.ExecutionContext(request)
+            if "prompt" in json_data:
+                prompt = json_data["prompt"]
+                extra_data = {}
+                if "extra_data" in json_data:
+                    extra_data = json_data["extra_data"]
+
+                extra_data["client_id"] = context.user_id
+
+                context = execution_context.ExecutionContext(request=request, extra_data=extra_data)
+
+                valid = execution.validate_prompt(context, prompt)
+                if valid[0]:
+                    return web.json_response({}, status=200)
+                else:
+                    logging.warning("invalid prompt: {}".format(valid[1]))
+                    return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
+            else:
+                return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
+
         @routes.post("/prompt")
         async def post_prompt(request):
             logging.info("got prompt")
@@ -601,30 +696,39 @@ async def post_prompt(request):
             out_string = ""
             json_data = await request.json()
             json_data = self.trigger_on_prompt(json_data)
+            context = execution_context.ExecutionContext(request)

             if "number" in json_data:
                 number = float(json_data['number'])
             else:
                 number = self.number
-                if "front" in json_data:
-                    if json_data['front']:
-                        number = -number
+                # if "front" in json_data:
+                #     if json_data['front']:
+                #         number = -number
                 self.number += 1

             if "prompt" in json_data:
                 prompt = json_data["prompt"]
-                valid = execution.validate_prompt(prompt)
                 extra_data = {}
                 if "extra_data" in json_data:
                     extra_data = json_data["extra_data"]

-                if "client_id" in json_data:
-                    extra_data["client_id"] = json_data["client_id"]
+                # if "client_id" in json_data:
+                #     extra_data["client_id"] = json_data["client_id"]
+                extra_data["client_id"] = context.user_id
+
+                context = execution_context.ExecutionContext(request=request, extra_data=extra_data)
+
+                valid = execution.validate_prompt(context, prompt)
                 if valid[0]:
-                    prompt_id = str(uuid.uuid4())
+                    if 'x-task-id' in request.headers:
+                        prompt_id = request.headers['x-task-id']
+                    else:
+                        prompt_id = str(uuid.uuid4())
                     outputs_to_execute = valid[2]
-                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
+                    extra_data["prompt_id"] = prompt_id
+                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, context))
                     response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
{"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: @@ -635,22 +739,22 @@ async def post_prompt(request): @routes.post("/queue") async def post_queue(request): - json_data = await request.json() - if "clear" in json_data: - if json_data["clear"]: - self.prompt_queue.wipe_queue() - if "delete" in json_data: - to_delete = json_data['delete'] - for id_to_delete in to_delete: - delete_func = lambda a: a[1] == id_to_delete - self.prompt_queue.delete_queue_item(delete_func) - - return web.Response(status=200) + # json_data = await request.json() + # if "clear" in json_data: + # if json_data["clear"]: + # self.prompt_queue.wipe_queue() + # if "delete" in json_data: + # to_delete = json_data['delete'] + # for id_to_delete in to_delete: + # delete_func = lambda a: a[1] == id_to_delete + # self.prompt_queue.delete_queue_item(delete_func) + + return web.Response(status=403) @routes.post("/interrupt") async def post_interrupt(request): - nodes.interrupt_processing() - return web.Response(status=200) + # nodes.interrupt_processing() + return web.Response(status=403) @routes.post("/free") async def post_free(request): @@ -665,24 +769,23 @@ async def post_free(request): @routes.post("/history") async def post_history(request): - json_data = await request.json() - if "clear" in json_data: - if json_data["clear"]: - self.prompt_queue.wipe_history() - if "delete" in json_data: - to_delete = json_data['delete'] - for id_to_delete in to_delete: - self.prompt_queue.delete_history_item(id_to_delete) + # json_data = await request.json() + # if "clear" in json_data: + # if json_data["clear"]: + # self.prompt_queue.wipe_history() + # if "delete" in json_data: + # to_delete = json_data['delete'] + # for id_to_delete in to_delete: + # self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout) - def add_routes(self): self.user_manager.add_routes(self.routes) - self.app.add_subapp('/internal', self.internal_routes.get_app()) + # self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. # This is very useful for frontend dev server, which need to forward @@ -701,6 +804,7 @@ def add_routes(self): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([ web.static('/extensions/' + urllib.parse.quote(name), dir), + web.static('/extensions_builtin/' + urllib.parse.quote(name), dir), ]) self.app.add_routes([ @@ -757,6 +861,7 @@ async def send_image(self, image_data, sid=None): async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) + self.dffis_message_queue.send_message(sid, bytes(message)) if sid is None: sockets = list(self.sockets.values()) @@ -767,6 +872,7 @@ async def send_bytes(self, event, data, sid=None): async def send_json(self, event, data, sid=None): message = {"type": event, "data": data} + self.dffis_message_queue.send_message(sid, json.dumps(message)) if sid is None: sockets = list(self.sockets.values())