Skip to content

Commit

Permalink
Merge pull request #236 from ag2ai/add-tool-imports-refactoring-and-docs
Browse files Browse the repository at this point in the history
Add tool imports refactoring and docs
  • Loading branch information
davorrunje authored Dec 19, 2024
2 parents ec6d6be + b41328f commit 166537e
Show file tree
Hide file tree
Showing 13 changed files with 682 additions and 551 deletions.
29 changes: 28 additions & 1 deletion autogen/interop/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,36 @@ def _sanitize_name(s: str) -> str:


class CrewAIInteroperability(Interoperable):
def convert_tool(self, tool: Any) -> Tool:
"""
A class implementing the `Interoperable` protocol for converting CrewAI tools
to a general `Tool` format.
This class takes a `CrewAITool` and converts it into a standard `Tool` object.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given CrewAI tool into a general `Tool` format.
This method ensures that the provided tool is a valid `CrewAITool`, sanitizes
the tool's name, processes its description, and prepares a function to interact
with the tool's arguments. It then returns a standardized `Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `CrewAITool`.
**kwargs (Any): Additional arguments, which are not supported by this method.
Returns:
Tool: A standardized `Tool` object converted from the CrewAI tool.
Raises:
ValueError: If the provided tool is not an instance of `CrewAITool`, or if
any additional arguments are passed.
"""
if not isinstance(tool, CrewAITool):
raise ValueError(f"Expected an instance of `crewai.tools.BaseTool`, got {type(tool)}")
if kwargs:
raise ValueError(f"The CrewAIInteroperability does not support any additional arguments, got {kwargs}")

# needed for type checking
crewai_tool: CrewAITool = tool # type: ignore[no-any-unimported]
Expand Down
59 changes: 57 additions & 2 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,83 @@


class Interoperability:
"""
A class to handle interoperability between different tool types.
This class allows the conversion of tools to various interoperability classes and provides functionality
for retrieving and registering interoperability classes.
"""

_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes()

def __init__(self) -> None:
"""
Initializes an instance of the Interoperability class.
This constructor does not perform any specific actions as the class is primarily used for its class
methods to manage interoperability classes.
"""
pass

def convert_tool(self, *, tool: Any, type: str) -> Tool:
def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool:
"""
Converts a given tool to an instance of a specified interoperability type.
Args:
tool (Any): The tool object to be converted.
type (str): The type of interoperability to convert the tool to.
**kwargs (Any): Additional arguments to be passed during conversion.
Returns:
Tool: The converted tool.
Raises:
ValueError: If the interoperability class for the provided type is not found.
"""
interop_cls = self.get_interoperability_class(type)
interop = interop_cls()
return interop.convert_tool(tool)
return interop.convert_tool(tool, **kwargs)

@classmethod
def get_interoperability_class(cls, type: str) -> Type[Interoperable]:
"""
Retrieves the interoperability class corresponding to the specified type.
Args:
type (str): The type of the interoperability class to retrieve.
Returns:
Type[Interoperable]: The interoperability class type.
Raises:
ValueError: If no interoperability class is found for the provided type.
"""
if type not in cls._interoperability_classes:
raise ValueError(f"Interoperability class {type} not found")
return cls._interoperability_classes[type]

@classmethod
def supported_types(cls) -> List[str]:
"""
Returns a sorted list of all supported interoperability types.
Returns:
List[str]: A sorted list of strings representing the supported interoperability types.
"""
return sorted(cls._interoperability_classes.keys())

@classmethod
def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None:
"""
Registers a new interoperability class with the given name.
Args:
name (str): The name to associate with the interoperability class.
interoperability_class (Type[Interoperable]): The class implementing the Interoperable protocol.
Raises:
ValueError: If the provided class does not implement the Interoperable protocol.
"""
if not issubclass(interoperability_class, Interoperable):
raise ValueError(
f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}"
Expand Down
22 changes: 21 additions & 1 deletion autogen/interop/interoperable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,24 @@

@runtime_checkable
class Interoperable(Protocol):
def convert_tool(self, tool: Any) -> Tool: ...
"""
A Protocol defining the interoperability interface for tool conversion.
This protocol ensures that any class implementing it provides the method
`convert_tool` to convert a given tool into a desired format or type.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given tool to a desired format or type.
This method should be implemented by any class adhering to the `Interoperable` protocol.
Args:
tool (Any): The tool object to be converted.
**kwargs (Any): Additional parameters to pass during the conversion process.
Returns:
Tool: The converted tool in the desired format or type.
"""
...
31 changes: 30 additions & 1 deletion autogen/interop/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,38 @@


class LangchainInteroperability(Interoperable):
def convert_tool(self, tool: Any) -> Tool:
"""
A class implementing the `Interoperable` protocol for converting Langchain tools
into a general `Tool` format.
This class takes a `LangchainTool` and converts it into a standard `Tool` object,
ensuring compatibility between Langchain tools and other systems that expect
the `Tool` format.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given Langchain tool into a general `Tool` format.
This method verifies that the provided tool is a valid `LangchainTool`,
processes the tool's input and description, and returns a standardized
`Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `LangchainTool`.
**kwargs (Any): Additional arguments, which are not supported by this method.
Returns:
Tool: A standardized `Tool` object converted from the Langchain tool.
Raises:
ValueError: If the provided tool is not an instance of `LangchainTool`, or if
any additional arguments are passed.
"""
if not isinstance(tool, LangchainTool):
raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}")
if kwargs:
raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}")

# needed for type checking
langchain_tool: LangchainTool = tool # type: ignore[no-any-unimported]
Expand Down
60 changes: 58 additions & 2 deletions autogen/interop/pydantic_ai/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0


import warnings
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional
Expand All @@ -17,11 +18,38 @@


class PydanticAIInteroperability(Interoperable):
"""
A class implementing the `Interoperable` protocol for converting Pydantic AI tools
into a general `Tool` format.
This class takes a `PydanticAITool` and converts it into a standard `Tool` object,
ensuring compatibility between Pydantic AI tools and other systems that expect
the `Tool` format. It also provides a mechanism for injecting context parameters
into the tool's function.
"""

@staticmethod
def inject_params( # type: ignore[no-any-unimported]
ctx: Optional[RunContext[Any]],
tool: PydanticAITool,
) -> Callable[..., Any]:
"""
Wraps the tool's function to inject context parameters and handle retries.
This method ensures that context parameters are properly passed to the tool
when invoked and that retries are managed according to the tool's settings.
Args:
ctx (Optional[RunContext[Any]]): The run context, which may include dependencies
and retry information.
tool (PydanticAITool): The Pydantic AI tool whose function is to be wrapped.
Returns:
Callable[..., Any]: A wrapped function that includes context injection and retry handling.
Raises:
ValueError: If the tool fails after the maximum number of retries.
"""
max_retries = tool.max_retries if tool.max_retries is not None else 1
f = tool.function

Expand Down Expand Up @@ -54,14 +82,42 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:

return wrapper

def convert_tool(self, tool: Any, deps: Any = None) -> AG2PydanticAITool:
def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool:
"""
Converts a given Pydantic AI tool into a general `Tool` format.
This method verifies that the provided tool is a valid `PydanticAITool`,
handles context dependencies if necessary, and returns a standardized `Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `PydanticAITool`.
deps (Any, optional): The dependencies to inject into the context, required if
the tool takes a context. Defaults to None.
**kwargs (Any): Additional arguments that are not used in this method.
Returns:
AG2PydanticAITool: A standardized `Tool` object converted from the Pydantic AI tool.
Raises:
ValueError: If the provided tool is not an instance of `PydanticAITool`, or if
dependencies are missing for tools that require a context.
UserWarning: If the `deps` argument is provided for a tool that does not take a context.
"""
if not isinstance(tool, PydanticAITool):
raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}")

# needed for type checking
pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported]

if deps is not None:
if tool.takes_ctx and deps is None:
raise ValueError("If the tool takes a context, the `deps` argument must be provided")
if not tool.takes_ctx and deps is not None:
warnings.warn(
"The `deps` argument is provided but will be ignored because the tool does not take a context.",
UserWarning,
)

if tool.takes_ctx:
ctx = RunContext(
deps=deps,
retry=0,
Expand Down
33 changes: 33 additions & 0 deletions autogen/tools/pydantic_ai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,33 @@


class PydanticAITool(Tool):
"""
A class representing a Pydantic AI Tool that extends the general Tool functionality
with additional functionality specific to Pydantic AI tools.
This class inherits from the Tool class and adds functionality for registering
tools with a ConversableAgent, along with providing additional schema information
specific to Pydantic AI tools, such as parameters and function signatures.
Attributes:
parameters_json_schema (Dict[str, Any]): A schema describing the parameters
that the tool's function expects.
"""

def __init__(
self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any]
) -> None:
"""
Initializes a PydanticAITool object with the provided name, description,
function, and parameter schema.
Args:
name (str): The name of the tool.
description (str): A description of what the tool does.
func (Callable[..., Any]): The function that is executed when the tool is called.
parameters_json_schema (Dict[str, Any]): A schema describing the parameters
that the function accepts.
"""
super().__init__(name, description, func)
self._func_schema = {
"type": "function",
Expand All @@ -26,4 +50,13 @@ def __init__(
}

def register_for_llm(self, agent: ConversableAgent) -> None:
"""
Registers the tool with the ConversableAgent for use with a language model (LLM).
This method updates the agent's tool signature to include the function schema,
allowing the agent to invoke the tool correctly during interactions with the LLM.
Args:
agent (ConversableAgent): The agent with which the tool will be registered.
"""
agent.update_tool_signature(self._func_schema, is_remove=False)
30 changes: 30 additions & 0 deletions autogen/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@


class Tool:
"""
A class representing a Tool that can be used by an agent for various tasks.
This class encapsulates a tool with a name, description, and an executable function.
The tool can be registered with a ConversableAgent for use either with an LLM or for direct execution.
Attributes:
name (str): The name of the tool.
description (str): A brief description of the tool's purpose or function.
func (Callable[..., Any]): The function to be executed when the tool is called.
"""

def __init__(self, name: str, description: str, func: Callable[..., Any]) -> None:
"""Create a new Tool object.
Expand All @@ -41,7 +53,25 @@ def func(self) -> Callable[..., Any]:
return self._func

def register_for_llm(self, agent: ConversableAgent) -> None:
"""
Registers the tool for use with a ConversableAgent's language model (LLM).
This method registers the tool so that it can be invoked by the agent during
interactions with the language model.
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_llm(name=self._name, description=self._description)(self._func)

def register_for_execution(self, agent: ConversableAgent) -> None:
"""
Registers the tool for direct execution by a ConversableAgent.
This method registers the tool so that it can be executed by the agent,
typically outside of the context of an LLM interaction.
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_execution(name=self._name)(self._func)
Loading

0 comments on commit 166537e

Please sign in to comment.