Skip to content

Commit

Permalink
Add Parameter Extractor node #24
Browse files Browse the repository at this point in the history
  • Loading branch information
receyuki committed Dec 11, 2023
1 parent 52882ee commit 97c6aa9
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 0 deletions.
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"parameterDisplay.js",
"seedGen.js",
"loaderDisplay.js",
"extractorDisplay.js",
]

for file in files_to_copy:
Expand Down
38 changes: 38 additions & 0 deletions js/extractorDisplay.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import {app} from "../../scripts/app.js";
import {ComfyWidgets} from "../../scripts/widgets.js";

// Create a read-only string widget
function createWidget(app, node, widgetName, type) {
const widget = ComfyWidgets[type](node, widgetName, ["STRING", {multiline: true}], app).widget;
widget.inputEl.readOnly = true;
widget.inputEl.style.textAlign = "center";
widget.inputEl.style.fontSize = "0.75rem";
return widget;
}

app.registerExtension({
name: "sd_prompt_reader.extractorDisplay",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "SDParameterExtractor") {
const onNodeCreated = nodeType.prototype.onNodeCreated;

nodeType.prototype.onNodeCreated = function () {
const result = onNodeCreated?.apply(this, arguments);

// Create widgets
const value_display = createWidget(app, this, "value_display", "STRING");
};

// Update widgets
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function (message) {
onExecuted?.apply(this, arguments);
this.widgets.find(obj => obj.name === "value_display").value = message.text[1]
this.widgets.find(obj => obj.name === "parameter").options.values = message.text[0]
if (this.widgets.find(obj => obj.name === "parameter").value === "parameters not loaded") {
this.widgets.find(obj => obj.name === "parameter").value = message.text[0][0]
}
};
}
},
});
103 changes: 103 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
import json
import re
import numpy as np
from pathlib import Path
from PIL import Image, ImageOps
Expand Down Expand Up @@ -936,13 +937,114 @@ def VALIDATE_INPUTS(
return True


class SDParameterExtractor:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"settings": (
"STRING",
{"default": "", "multiline": True, "forceInput": True},
)
},
"optional": {
"parameter": (
["parameters not loaded"],
{"default": "parameters not loaded"},
),
"value_type": (["STRING", "INT", "FLOAT"], {"default": "STRING"}),
"parameter_index": (
"INT",
{"default": 0, "min": 0, "max": 255, "step": 1},
),
},
}

RETURN_TYPES = (any_type,)

RETURN_NAMES = ("VALUE",)
FUNCTION = "extract_param"
CATEGORY = "SD Prompt Reader"

def extract_param(
self,
settings: str = "",
parameter: str = "",
value_type: str = "STRING",
parameter_index: int = 0,
):
setting_dict = self.parse_setting(settings)
if not settings or not parameter or parameter == "parameters not loaded":
return {
"ui": {
"text": (list(setting_dict.keys()), ""),
},
"result": ("",),
}

result = setting_dict.get(parameter)

try:
if isinstance(result, tuple):
result = result[parameter_index]
if value_type == "INT":
result = int(result)
elif value_type == "FLOAT":
result = float(result)
except IndexError:
return {
"ui": {
"text": (list(setting_dict.keys()), "Parameter index out of range"),
},
"result": ("",),
}
except (ValueError, TypeError):
return {
"ui": {
"text": (
list(setting_dict.keys()),
f"{parameter}: {result}\n"
f"{result} is not a valid number; it will be output as STRING",
),
},
"result": (result,),
}
return {
"ui": {
"text": (list(setting_dict.keys()), f"{parameter}: {result}"),
},
"result": (result,),
}

@staticmethod
def parse_setting(settings):
pattern = re.compile(r"([^:,]+):\s*\(([^)]+)\)|([^:,]+):\s*([^,]+)")

matches = pattern.findall(settings)

result = {}
for match in matches:
key, value_paren, key_nonparen, value_nonparen = match
if key:
key = key.strip()
value = value_paren.strip()
value = tuple(v.strip() for v in value.split(","))
else:
key = key_nonparen.strip()
value = value_nonparen.strip()
result[key] = value

return result


NODE_CLASS_MAPPINGS = {
"SDPromptReader": SDPromptReader,
"SDPromptSaver": SDPromptSaver,
"SDParameterGenerator": SDParameterGenerator,
"SDPromptMerger": SDPromptMerger,
"SDTypeConverter": SDTypeConverter,
"SDBatchLoader": SDBatchLoader,
"SDParameterExtractor": SDParameterExtractor,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand All @@ -952,4 +1054,5 @@ def VALIDATE_INPUTS(
"SDPromptMerger": "SD Prompt Merger",
"SDTypeConverter": "SD Type Converter",
"SDBatchLoader": "SD Batch Loader",
"SDParameterExtractor": "SD Parameter Extractor",
}

0 comments on commit 97c6aa9

Please sign in to comment.