Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow passing a Component to the set method #3597

Merged
merged 7 commits into from
Aug 28, 2024
31 changes: 31 additions & 0 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,40 @@ def _method_is_valid_output(self, method: Callable):
)
return method_is_output

def _build_error_string_from_matching_pairs(self, matching_pairs: list[tuple[Output, Input]]):
text = ""
for output, input_ in matching_pairs:
text += f"{output.name}[{','.join(output.types)}]->{input_.name}[{','.join(input_.input_types or [])}]\n"
return text

def _find_matching_output_method(self, value: "Component"):
# get all outputs of the value component
outputs = value.outputs
# check if the any of the types in the output.types matches ONLY one input in the current component
matching_pairs = []
for output in outputs:
for input_ in self.inputs:
for output_type in output.types:
if input_.input_types and output_type in input_.input_types:
matching_pairs.append((output, input_))
if len(matching_pairs) > 1:
matching_pairs_str = self._build_error_string_from_matching_pairs(matching_pairs)
raise ValueError(
f"There are multiple outputs from {value.__class__.__name__} that can connect to inputs in {self.__class__.__name__}: {matching_pairs_str}"
)
output, input_ = matching_pairs[0]
if not isinstance(output.method, str):
raise ValueError(f"Method {output.method} is not a valid output of {value.__class__.__name__}")
return getattr(value, output.method)

def _process_connection_or_parameter(self, key, value):
_input = self._get_or_create_input(key)
# We need to check if callable AND if it is a method from a class that inherits from Component
if isinstance(value, Component):
# We need to find the Output that can connect to an input of the current component
# if there's more than one output that matches, we need to raise an error
# because we don't know which one to connect to
value = self._find_matching_output_method(value)
if callable(value) and self._inherits_from_component(value):
try:
self._method_is_valid_output(value)
Expand Down
10 changes: 4 additions & 6 deletions src/backend/tests/unit/custom/custom_component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ def test_set_invalid_output():
chatoutput.set(input_value=chatinput.build_config)


def test_set_invalid_input():
def test_set_component():
crewai_agent = CrewAIAgentComponent()
task = SequentialTaskComponent()
with pytest.raises(
ValueError,
match="You set CrewAI Agent as value for `agent`. You should pass one of the following: 'build_output'",
):
task.set(agent=crewai_agent)
task.set(agent=crewai_agent)
assert task._edges[0]["source"] == crewai_agent._id
assert crewai_agent in task._components
15 changes: 15 additions & 0 deletions src/backend/tests/unit/graph/graph/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import pytest

from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish

Expand Down Expand Up @@ -139,3 +141,16 @@ def test_graph_functional_start_end():
assert len(results) == len(ids) + 1
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
assert results[-1] == Finish()


def test_graph_set_with_invalid_component():
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
with pytest.raises(ValueError, match="There are multiple outputs"):
chat_output.set(sender_name=chat_input)


def test_graph_set_with_valid_component():
tool = YfinanceToolComponent()
tool_calling_agent = ToolCallingAgentComponent()
tool_calling_agent.set(tools=[tool])