Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tool support for pydantic ai #230

Merged
merged 15 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)

self.client = OpenAIWrapper(**self.llm_config)

def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool):
"""update a tool_signature in the LLM configuration for tool_call.

Args:
Expand Down
2 changes: 2 additions & 0 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .helpers import get_all_interoperability_classes
from .interoperable import Interoperable

__all__ = ["Interoperable"]


class Interoperability:
_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes()
Expand Down
19 changes: 19 additions & 0 deletions autogen/interop/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

import sys

if sys.version_info < (3, 9):
raise ImportError("This submodule is only supported for Python versions 3.9 and above")

try:
import pydantic_ai.tools
except ImportError:
raise ImportError(
"Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]"
)

from .pydantic_ai import PydanticAIInteroperability

__all__ = ["PydanticAIInteroperability"]
86 changes: 86 additions & 0 deletions autogen/interop/pydantic_ai/pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0


from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional

from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticAITool

from ...tools import PydanticAITool as AG2PydanticAITool
from ..interoperability import Interoperable

__all__ = ["PydanticAIInteroperability"]


class PydanticAIInteroperability(Interoperable):
@staticmethod
def inject_params( # type: ignore[no-any-unimported]
ctx: Optional[RunContext[Any]],
tool: PydanticAITool,
) -> Callable[..., Any]:
max_retries = tool.max_retries if tool.max_retries is not None else 1
f = tool.function

@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if tool.current_retry >= max_retries:
raise ValueError(f"{tool.name} failed after {max_retries} retries")

try:
if ctx is not None:
kwargs.pop("ctx", None)
ctx.retry = tool.current_retry
result = f(**kwargs, ctx=ctx)
else:
result = f(**kwargs)
tool.current_retry = 0
except Exception as e:
tool.current_retry += 1
raise e

return result

sig = signature(f)
if ctx is not None:
new_params = [param for name, param in sig.parameters.items() if name != "ctx"]
else:
new_params = list(sig.parameters.values())

wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined]

return wrapper

def convert_tool(self, tool: Any, deps: Any = None) -> AG2PydanticAITool:
if not isinstance(tool, PydanticAITool):
raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}")

# needed for type checking
pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported]

if deps is not None:
ctx = RunContext(
deps=deps,
retry=0,
# All messages send to or returned by a model.
# This is mostly used on pydantic_ai Agent level.
messages=None, # TODO: check in the future if this is needed on Tool level
tool_name=pydantic_ai_tool.name,
)
else:
ctx = None

func = PydanticAIInteroperability.inject_params(
ctx=ctx,
tool=pydantic_ai_tool,
)

return AG2PydanticAITool(
name=pydantic_ai_tool.name,
description=pydantic_ai_tool.description,
func=func,
parameters_json_schema=pydantic_ai_tool._parameters_json_schema,
)
2 changes: 1 addition & 1 deletion autogen/oai/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
CEREBRAS_PRICING_1K = {
# Convert pricing per million to per thousand tokens.
"llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
"llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
"llama-3.3-70b": (0.85 / 1000, 1.20 / 1000),
}


Expand Down
3 changes: 2 additions & 1 deletion autogen/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .pydantic_ai_tool import PydanticAITool
from .tool import Tool

__all__ = ["Tool"]
__all__ = ["PydanticAITool", "Tool"]
29 changes: 29 additions & 0 deletions autogen/tools/pydantic_ai_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict

from autogen.agentchat.conversable_agent import ConversableAgent

from .tool import Tool

__all__ = ["PydanticAITool"]


class PydanticAITool(Tool):
def __init__(
self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any]
) -> None:
super().__init__(name, description, func)
self._func_schema = {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters_json_schema,
},
}

def register_for_llm(self, agent: ConversableAgent) -> None:
agent.update_tool_signature(self._func_schema, is_remove=False)
6 changes: 3 additions & 3 deletions notebook/tools_crewai_tools_integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"from crewai_tools import FileWriterTool, ScrapeWebsiteTool\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interoperability.crewai import CrewAIInteroperability"
"from autogen.interop.crewai import CrewAIInteroperability"
]
},
{
Expand Down Expand Up @@ -168,7 +168,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -182,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.10.16"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions notebook/tools_langchain_tools_integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"from langchain_community.utilities import WikipediaAPIWrapper\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interoperability.langchain import LangchainInteroperability"
"from autogen.interop.langchain import LangchainInteroperability"
]
},
{
Expand Down Expand Up @@ -112,7 +112,7 @@
"ag2_tool = langchain_interop.convert_tool(langchain_tool)\n",
"\n",
"ag2_tool.register_for_execution(user_proxy)\n",
"ag2_tool.register_for_llm(chatbot)\n"
"ag2_tool.register_for_llm(chatbot)"
]
},
{
Expand Down
183 changes: 183 additions & 0 deletions notebook/tools_pydantic_ai_tools_integration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Integrating PydanticAI Tools with the AG2 Framework\n",
"\n",
"In this tutorial, we demonstrate how to integrate [PydanticAI Tools](https://ai.pydantic.dev/tools/) into the AG2 framework. This process enables smooth interoperability between the two systems, allowing developers to leverage PydanticAI's powerful tools within AG2's flexible agent-based architecture. By the end of this guide, you will understand how to configure agents, convert PydanticAI tools for use in AG2, and validate the integration with a practical example.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation\n",
"To integrate LangChain tools into the AG2 framework, install the required dependencies:\n",
"\n",
"```bash\n",
"pip install ag2[interop-pydantic-ai]\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n",
"\n",
"Import necessary modules and tools.\n",
"- `BaseModel`: Used to define data structures for tool inputs and outputs.\n",
"- `RunContext`: Provides context during the execution of tools.\n",
"- `PydanticAITool`: Represents a tool in the PydanticAI framework.\n",
"- `AssistantAgent` and `UserProxyAgent`: Agents that facilitate communication in the AG2 framework.\n",
"- `PydanticAIInteroperability`: A bridge for integrating PydanticAI tools with the AG2 framework."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import Optional\n",
"\n",
"from pydantic import BaseModel\n",
"from pydantic_ai import RunContext\n",
"from pydantic_ai.tools import Tool as PydanticAITool\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interop.pydantic_ai import PydanticAIInteroperability"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent Configuration\n",
"\n",
"Configure the agents for the interaction.\n",
"- `config_list` defines the LLM configurations, including the model and API key.\n",
"- `UserProxyAgent` simulates user inputs without requiring actual human interaction (set to `NEVER`).\n",
"- `AssistantAgent` represents the AI agent, configured with the LLM settings."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"config_list = [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}]\n",
"user_proxy = UserProxyAgent(\n",
" name=\"User\",\n",
" human_input_mode=\"NEVER\",\n",
")\n",
"\n",
"chatbot = AssistantAgent(\n",
" name=\"chatbot\",\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Integration\n",
"\n",
"Integrate the PydanticAI tool with AG2.\n",
"\n",
"- Define a `Player` model using `BaseModel` to structure the input data.\n",
"- Use `RunContext` to securely inject dependencies (like the `Player` instance) into the tool function without exposing them to the LLM.\n",
"- Implement `get_player` to define the tool's functionality, accessing `ctx.deps` for injected data.\n",
"- Convert the tool to an AG2-compatible format with `PydanticAIInteroperability` and register it for execution and LLM communication.\n",
"- Convert the PydanticAI tool into an AG2-compatible format using `convert_tool`.\n",
"- Register the tool for both execution and communication with the LLM by associating it with the `user_proxy` and `chatbot`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Player(BaseModel):\n",
" name: str\n",
" age: int\n",
"\n",
"\n",
"def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -> str: # type: ignore[valid-type]\n",
" \"\"\"Get the player's name.\n",
"\n",
" Args:\n",
" additional_info: Additional information which can be used.\n",
" \"\"\"\n",
" return f\"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}\" # type: ignore[attr-defined]\n",
"\n",
"\n",
"pydantic_ai_interop = PydanticAIInteroperability()\n",
"pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True)\n",
"\n",
"# player will be injected as a dependency\n",
"player = Player(name=\"Luka\", age=25)\n",
"ag2_tool = pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player)\n",
"\n",
"ag2_tool.register_for_execution(user_proxy)\n",
"ag2_tool.register_for_llm(chatbot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Initiate a conversation between the `UserProxyAgent` and the `AssistantAgent`.\n",
"\n",
"- Use the `initiate_chat` method to send a message from the `user_proxy` to the `chatbot`.\n",
"- In this example, the user requests the chatbot to retrieve player information, providing \"goal keeper\" as additional context.\n",
"- The `Player` instance is securely injected into the tool using `RunContext`, ensuring the chatbot can retrieve and use this data during the interaction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"user_proxy.initiate_chat(\n",
" recipient=chatbot, message=\"Get player, for additional information use 'goal keeper'\", max_turns=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading