Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing fixed for realtime agent #254

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def initiate_swarm_chat(
user_agent: Optional[UserProxyAgent] = None,
max_rounds: int = 20,
context_variables: Optional[dict[str, Any]] = None,
after_work: Optional[Union[AFTER_WORK, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE),
after_work: Optional[Union[AfterWorkOption, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE),
) -> tuple[ChatResult, dict[str, Any], "SwarmAgent"]:
"""Initialize and run a swarm chat
Expand Down
123 changes: 86 additions & 37 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,83 @@
# import asyncio
import json
import logging
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import anyio
import websockets
from asyncer import TaskGroup, asyncify, create_task_group, syncify
from asyncer import TaskGroup, asyncify, create_task_group
from websockets import connect
from websockets.asyncio.client import ClientConnection

from autogen.agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
from ..contrib.swarm_agent import AfterWorkOption, SwarmAgent, initiate_swarm_chat

from .function_observer import FunctionObserver
if TYPE_CHECKING:
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver

logger = logging.getLogger(__name__)


class OpenAIRealtimeClient:
"""(Experimental) Client for OpenAI Realtime API."""

def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
def __init__(
self, agent: "RealtimeAgent", audio_adapter: "RealtimeObserver", function_observer: "FunctionObserver"
) -> None:
"""(Experimental) Client for OpenAI Realtime API.
args:
agent: Agent instance
the agent to be used for the conversation
audio_adapter: RealtimeObserver
adapter for streaming the audio from the client
function_observer: FunctionObserver
observer for handling function calls
Args:
agent (RealtimeAgent): The agent that the client is associated with.
audio_adapter (RealtimeObserver): The audio adapter for the client.
function_observer (FunctionObserver): The function observer for the client.
"""
self._agent = agent
self._observers = []
self._openai_ws = None # todo factor out to OpenAIClient
self._observers: list["RealtimeObserver"] = []
self._openai_ws: Optional[ClientConnection] = None # todo factor out to OpenAIClient
self.register(audio_adapter)
self.register(function_observer)

# LLM config
llm_config = self._agent.llm_config

config = llm_config["config_list"][0]
config: dict[str, Any] = llm_config["config_list"][0] # type: ignore[index]

self.model = config["model"]
self.temperature = llm_config["temperature"]
self.api_key = config["api_key"]
self.model: str = config["model"]
self.temperature: float = llm_config["temperature"] # type: ignore[index]
self.api_key: str = config["api_key"]

# create a task group to manage the tasks
self.tg: Optional[TaskGroup] = None

def register(self, observer):
@property
def openai_ws(self) -> ClientConnection:
"""Get the OpenAI WebSocket connection."""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")
return self._openai_ws

def register(self, observer: "RealtimeObserver") -> None:
"""Register an observer to the client."""
observer.register_client(self)
self._observers.append(observer)

async def notify_observers(self, message):
"""Notify all observers of a message from the OpenAI Realtime API."""
async def notify_observers(self, message: dict[str, Any]) -> None:
"""Notify all observers of a message from the OpenAI Realtime API.
Args:
message (dict[str, Any]): The message from the OpenAI Realtime API.
"""
for observer in self._observers:
await observer.update(message)

async def function_result(self, call_id, result):
"""Send the result of a function call to the OpenAI Realtime API."""
async def function_result(self, call_id: str, result: str) -> None:
"""Send the result of a function call to the OpenAI Realtime API.
Args:
call_id (str): The ID of the function call.
result (str): The result of the function call.
"""
result_item = {
"type": "conversation.item.create",
"item": {
Expand All @@ -73,11 +93,23 @@ async def function_result(self, call_id, result):
"output": result,
},
}
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

async def send_text(self, *, role: str, text: str):
"""Send a text message to the OpenAI Realtime API."""
async def send_text(self, *, role: str, text: str) -> None:
"""Send a text message to the OpenAI Realtime API.
Args:
role (str): The role of the message.
text (str): The text of the message.
"""

if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

await self._openai_ws.send(json.dumps({"type": "response.cancel"}))
text_item = {
"type": "conversation.item.create",
Expand All @@ -87,7 +119,7 @@ async def send_text(self, *, role: str, text: str):
await self._openai_ws.send(json.dumps({"type": "response.create"}))

# todo override in specific clients
async def initialize_session(self):
async def initialize_session(self) -> None:
"""Control initial session with OpenAI."""
session_update = {
# todo: move to config
Expand All @@ -100,25 +132,35 @@ async def initialize_session(self):
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
"""Send a session update to the OpenAI Realtime API."""
async def session_update(self, session_options: dict[str, Any]) -> None:
"""Send a session update to the OpenAI Realtime API.
Args:
session_options (dict[str, Any]): The session options to update.
"""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

update = {"type": "session.update", "session": session_options}
logger.info("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))
logger.info("Sending session update finished")

async def _read_from_client(self):
async def _read_from_client(self) -> None:
"""Read messages from the OpenAI Realtime API."""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
logger.warning(f"Error in _read_from_client: {e}")

async def run(self):
async def run(self) -> None:
"""Run the client."""
async with websockets.connect(
async with connect(
f"wss://api.openai.com/v1/realtime?model={self.model}",
additional_headers={
"Authorization": f"Bearer {self.api_key}",
Expand All @@ -127,17 +169,24 @@ async def run(self):
) as openai_ws:
self._openai_ws = openai_ws
await self.initialize_session()
# await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self._observers])
async with create_task_group() as tg:
self.tg = tg
self.tg.soonify(self._read_from_client)()
for observer in self._observers:
self.tg.soonify(observer.run)()

initial_agent = self._agent._initial_agent
agents = self._agent._agents
user_agent = self._agent

if not (initial_agent and agents):
raise RuntimeError("Swarm not registered.")

if self._agent._start_swarm_chat:
self.tg.soonify(asyncify(initiate_swarm_chat))(
initial_agent=self._agent._initial_agent,
agents=self._agent._agents,
user_agent=self._agent,
initial_agent=initial_agent,
agents=agents,
user_agent=user_agent, # type: ignore[arg-type]
messages="Find out what the user wants.",
after_work=AfterWorkOption.REVERT_TO_USER,
)
36 changes: 25 additions & 11 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,52 @@
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any

from asyncer import asyncify
from pydantic import BaseModel

from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from .realtime_agent import RealtimeAgent

logger = logging.getLogger(__name__)


class FunctionObserver(RealtimeObserver):
"""Observer for handling function calls from the OpenAI Realtime API."""

def __init__(self, agent):
def __init__(self, agent: "RealtimeAgent") -> None:
"""Observer for handling function calls from the OpenAI Realtime API.

Args:
agent: Agent instance
the agent to be used for the conversation
agent (RealtimeAgent): The realtime agent attached to the observer.
"""
super().__init__()
self._agent = agent

async def update(self, response):
"""Handle function call events from the OpenAI Realtime API."""
async def update(self, response: dict[str, Any]) -> None:
"""Handle function call events from the OpenAI Realtime API.

Args:
response (dict[str, Any]): The response from the OpenAI Realtime API.
"""
if response.get("type") == "response.function_call_arguments.done":
logger.info(f"Received event: {response['type']}", response)
await self.call_function(
call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"])
)

async def call_function(self, call_id, name, kwargs):
"""Call a function registered with the agent."""
async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
"""Call a function registered with the agent.

Args:
call_id (str): The ID of the function call.
name (str): The name of the function to call.
kwargs (dict[str, Any]): The arguments to pass to the function.
"""

if name in self._agent.realtime_functions:
_, func = self._agent.realtime_functions[name]
func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
Expand All @@ -54,19 +68,19 @@ async def call_function(self, call_id, name, kwargs):
elif not isinstance(result, str):
result = json.dumps(result)

await self._client.function_result(call_id, result)
await self.client.function_result(call_id, result)

async def run(self):
async def run(self) -> None:
"""Run the observer.

Initialize the session with the OpenAI Realtime API.
"""
await self.initialize_session()

async def initialize_session(self):
async def initialize_session(self) -> None:
"""Add registered tools to OpenAI with a session update."""
session_update = {
"tools": [schema for schema, _ in self._agent.realtime_functions.values()],
"tool_choice": "auto",
}
await self._client.session_update(session_update)
await self.client.session_update(session_update)
Loading
Loading