Skip to content

Commit

Permalink
feat: add ComponentTool to support converting Component to Tool (#3412)
Browse files Browse the repository at this point in the history
* feat: Add ComponentTool to convert a Component to a Tool

* test(component): add unit test for ComponentTool with ChatInput input.

* feat: Add method to convert Component to ComponentTool.

* feat: Add unit test for ChatInput to Tool conversion.

* chore: add comment

* test: fix assertion

---------

Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
  • Loading branch information
ogabrielluiz and italojohnny authored Aug 19, 2024
1 parent 149c96d commit 75dbb68
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/backend/base/langflow/base/tools/component_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any

from langchain_core.tools import BaseTool, ToolException

from langflow.custom.custom_component.component import Component


class ComponentTool(BaseTool):
name: str
description: str
component: "Component"

def __init__(self, component: "Component") -> None:
"""Initialize the tool."""
from langflow.io.schema import create_input_schema

name = component.name or component.__class__.__name__
description = component.description or ""
args_schema = create_input_schema(component.inputs)
super().__init__(name=name, description=description, args_schema=args_schema, component=component)
# self.component = component

@property
def args(self) -> dict:
schema = self.get_input_schema()
return schema.schema()["properties"]

def _run(
self,
*args: Any,
**kwargs: Any,
) -> dict:
"""Use the tool."""
try:
results, _ = self.component(**kwargs)
return results
except Exception as e:
raise ToolException(f"Error running {self.name}: {e}")


ComponentTool.update_forward_refs()
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,9 @@ def build(self, **kwargs):

def _get_fallback_input(self, **kwargs):
return Input(**kwargs)

def to_tool(self):
# TODO: This is a temporary solution to avoid circular imports
from langflow.base.tools.component_tool import ComponentTool

return ComponentTool(component=self)
62 changes: 62 additions & 0 deletions src/backend/tests/unit/base/tools/test_component_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest

from langflow.base.tools.component_tool import ComponentTool
from langflow.components.inputs.ChatInput import ChatInput


@pytest.fixture
def client():
pass


def test_component_tool():
chat_input = ChatInput()
component_tool = ComponentTool(component=chat_input)
assert component_tool.name == "ChatInput"
assert component_tool.description == chat_input.description
assert component_tool.args == {
"input_value": {
"default": "",
"description": "Message to be passed as input.",
"title": "Input Value",
"type": "string",
},
"should_store_message": {
"default": True,
"description": "Store the message in the history.",
"title": "Should Store Message",
"type": "boolean",
},
"sender": {
"default": "User",
"description": "Type of sender.",
"enum": ["Machine", "User"],
"title": "Sender",
"type": "string",
},
"sender_name": {
"default": "User",
"description": "Name of the sender.",
"title": "Sender Name",
"type": "string",
},
"session_id": {
"default": "",
"description": "The session ID of the chat. If empty, the current session ID parameter will be used.",
"title": "Session Id",
"type": "string",
},
"files": {
"default": "",
"description": "Files to be sent with the message.",
"items": {"type": "string"},
"title": "Files",
"type": "array",
},
}
assert component_tool.component == chat_input

result = component_tool.invoke(input=dict(input_value="test"))
assert isinstance(result, dict)
assert hasattr(result["message"], "get_text")
assert result["message"].get_text() == "test"
16 changes: 16 additions & 0 deletions src/backend/tests/unit/custom/component/test_component_to_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from langflow.components.inputs.ChatInput import ChatInput


@pytest.fixture
def client():
pass


def test_component_to_tool():
chat_input = ChatInput()
tool = chat_input.to_tool()
assert tool.name == "ChatInput"
assert tool.description == "Get chat inputs from the Playground."
assert tool.component._id == chat_input._id

0 comments on commit 75dbb68

Please sign in to comment.