From bf9a90a145bffdeced29343d0d606278108be735 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 3 Dec 2024 10:06:34 -0800 Subject: [PATCH] Revert "Add union link connection type support (#5806)" (#5889) This reverts commit 8d4e06324fb6c477c1f7f409c857c33d0b3b0ce2. --- comfy_execution/validation.py | 32 -------- execution.py | 6 +- .../validate_node_input_test.py | 75 ------------------- 3 files changed, 3 insertions(+), 110 deletions(-) delete mode 100644 comfy_execution/validation.py delete mode 100644 tests-unit/execution_test/validate_node_input_test.py diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py deleted file mode 100644 index 43fb6426d7a..00000000000 --- a/comfy_execution/validation.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - - -def validate_node_input( - received_type: str, input_type: str, strict: bool = False -) -> bool: - """ - received_type and input_type are both strings of the form "T1,T2,...". - - If strict is True, the input_type must contain the received_type. - For example, if received_type is "STRING" and input_type is "STRING,INT", - this will return True. But if received_type is "STRING,INT" and input_type is - "INT", this will return False. - - If strict is False, the input_type must have overlap with the received_type. - For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT", - this will return True. - """ - # If the types are exactly the same, we can return immediately - if received_type == input_type: - return True - - # Split the type strings into sets for comparison - received_types = set(t.strip() for t in received_type.split(",")) - input_types = set(t.strip() for t in input_type.split(",")) - - if strict: - # In strict mode, all received types must be in the input types - return received_types.issubset(input_types) - else: - # In non-strict mode, there must be at least one type in common - return len(received_types.intersection(input_types)) > 0 diff --git a/execution.py b/execution.py index 929ef85fac4..768e35abced 100644 --- a/execution.py +++ b/execution.py @@ -16,7 +16,6 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID -from comfy_execution.validation import validate_node_input from comfy.cli_args import args class ExecutionResult(Enum): @@ -528,6 +527,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): comfy.model_management.unload_all_models() + def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and received_type != type_input: + details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", diff --git a/tests-unit/execution_test/validate_node_input_test.py b/tests-unit/execution_test/validate_node_input_test.py deleted file mode 100644 index d6605e97f3a..00000000000 --- a/tests-unit/execution_test/validate_node_input_test.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -from comfy_execution.validation import validate_node_input - - -def test_exact_match(): - """Test cases where types match exactly""" - assert validate_node_input("STRING", "STRING") - assert validate_node_input("STRING,INT", "STRING,INT") - assert ( - validate_node_input("INT,STRING", "STRING,INT") - ) # Order shouldn't matter - - -def test_strict_mode(): - """Test strict mode validation""" - # Should pass - received type is subset of input type - assert validate_node_input("STRING", "STRING,INT", strict=True) - assert validate_node_input("INT", "STRING,INT", strict=True) - assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True) - - # Should fail - received type is not subset of input type - assert not validate_node_input("STRING,INT", "STRING", strict=True) - assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True) - assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True) - - -def test_non_strict_mode(): - """Test non-strict mode validation (default behavior)""" - # Should pass - types have overlap - assert validate_node_input("STRING,BOOLEAN", "STRING,INT") - assert validate_node_input("STRING,INT", "INT,BOOLEAN") - assert validate_node_input("STRING", "STRING,INT") - - # Should fail - no overlap in types - assert not validate_node_input("BOOLEAN", "STRING,INT") - assert not validate_node_input("FLOAT", "STRING,INT") - assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT") - - -def test_whitespace_handling(): - """Test that whitespace is handled correctly""" - assert validate_node_input("STRING, INT", "STRING,INT") - assert validate_node_input("STRING,INT", "STRING, INT") - assert validate_node_input(" STRING , INT ", "STRING,INT") - assert validate_node_input("STRING,INT", " STRING , INT ") - - -def test_empty_strings(): - """Test behavior with empty strings""" - assert validate_node_input("", "") - assert not validate_node_input("STRING", "") - assert not validate_node_input("", "STRING") - - -def test_single_vs_multiple(): - """Test single type against multiple types""" - assert validate_node_input("STRING", "STRING,INT,BOOLEAN") - assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False) - assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True) - - -@pytest.mark.parametrize( - "received,input_type,strict,expected", - [ - ("STRING", "STRING", False, True), - ("STRING,INT", "STRING,INT", False, True), - ("STRING", "STRING,INT", True, True), - ("STRING,INT", "STRING", True, False), - ("BOOLEAN", "STRING,INT", False, False), - ("STRING,BOOLEAN", "STRING,INT", False, True), - ], -) -def test_parametrized_cases(received, input_type, strict, expected): - """Parametrized test cases for various scenarios""" - assert validate_node_input(received, input_type, strict) == expected