Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separated client out from RealtimeAgent #206

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import json
from abc import ABC, abstractmethod

import websockets

from .function_observer import FunctionObserver


class Client(ABC):
def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
self._agent = agent
self._observers = []
self._openai_ws = None # todo factor out to OpenAIClient
self.register(audio_adapter)
self.register(function_observer)

def register(self, observer):
observer.register_client(self)
self._observers.append(observer)

async def notify_observers(self, message):
for observer in self._observers:
await observer.update(message)

async def function_result(self, call_id, result):
result_item = {
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

# todo override in specific clients
async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self._agent.voice,
"instructions": self._agent.system_message,
"modalities": ["text", "audio"],
"temperature": 0.8,
}
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))

async def _read_from_client(self):
try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")

async def run(self):

async with websockets.connect(
"wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
additional_headers={
"Authorization": f"Bearer {self._agent.llm_config['config_list'][0]['api_key']}",
"OpenAI-Beta": "realtime=v1",
},
) as openai_ws:
self._openai_ws = openai_ws
await self.initialize_session()
await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self._observers])
12 changes: 6 additions & 6 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


class FunctionObserver(RealtimeObserver):
def __init__(self, registered_functions):
def __init__(self, agent):
super().__init__()
self.registered_functions = registered_functions
self._agent = agent

async def update(self, response):
if response.get("type") == "response.function_call_arguments.done":
Expand All @@ -24,16 +24,16 @@ async def update(self, response):
)

async def call_function(self, call_id, name, kwargs):
_, func = self.registered_functions[name]
await self.client.function_result(call_id, func(**kwargs))
_, func = self._agent.registered_functions[name]
await self._client.function_result(call_id, func(**kwargs))

async def run(self):
await self.initialize_session()

async def initialize_session(self):
"""Add tool to OpenAI."""
session_update = {
"tools": [schema for schema, _ in self.registered_functions.values()],
"tools": [schema for schema, _ in self._agent.registered_functions.values()],
"tool_choice": "auto",
}
await self.client.session_update(session_update)
await self._client.session_update(session_update)
64 changes: 5 additions & 59 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from autogen.agentchat.agent import Agent, LLMAgent
from autogen.function_utils import get_function_schema

from .client import Client
from .function_observer import FunctionObserver
from .realtime_observer import RealtimeObserver

Expand All @@ -40,71 +41,16 @@ def __init__(
context_variables: Optional[Dict[str, Any]] = None,
voice: str = "alloy",
):

self._client = Client(self, audio_adapter, FunctionObserver(self))
self.llm_config = llm_config
self._oai_system_message = [{"content": system_message, "role": "system"}]
self.voice = voice
self.observers = []
self.openai_ws = None
self.registered_functions = {}

self.register(audio_adapter)

def register(self, observer):
observer.register_client(self)
self.observers.append(observer)

async def notify_observers(self, message):
for observer in self.observers:
await observer.update(message)

async def function_result(self, call_id, result):
result_item = {
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
await self.openai_ws.send(json.dumps(result_item))
await self.openai_ws.send(json.dumps({"type": "response.create"}))

async def _read_from_client(self):
try:
async for openai_message in self.openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")
self._oai_system_message = [{"content": system_message, "role": "system"}] # todo still needed?

async def run(self):
self.register(FunctionObserver(registered_functions=self.registered_functions))
async with websockets.connect(
"wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
additional_headers={
"Authorization": f"Bearer {self.llm_config['config_list'][0]['api_key']}",
"OpenAI-Beta": "realtime=v1",
},
) as openai_ws:
self.openai_ws = openai_ws
await self.initialize_session()
await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self.observers])

async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self.voice,
"instructions": self.system_message,
"modalities": ["text", "audio"],
"temperature": 0.8,
}
await self.session_update(session_update)

async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update))
await self.openai_ws.send(json.dumps(update))
await self._client.run()

def register_handover(
self,
Expand Down
4 changes: 2 additions & 2 deletions autogen/agentchat/realtime_agent/realtime_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

class RealtimeObserver(ABC):
def __init__(self):
self.client = None
self._client = None

def register_client(self, client):
self.client = client
self._client = client

@abstractmethod
async def run(self, openai_ws):
Expand Down
7 changes: 3 additions & 4 deletions autogen/agentchat/realtime_agent/twilio_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class TwilioAudioAdapter(RealtimeObserver):
def __init__(self, websocket):
super().__init__()
self.client = None
self.websocket = websocket

# Connection specific state
Expand Down Expand Up @@ -86,7 +85,7 @@ async def handle_speech_started_event(self):
"content_index": 0,
"audio_end_ms": elapsed_time,
}
await self.client.openai_ws.send(json.dumps(truncate_event))
await self._client.openai_ws.send(json.dumps(truncate_event))

await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

Expand All @@ -101,7 +100,7 @@ async def send_mark(self):
self.mark_queue.append("responsePart")

async def run(self):
openai_ws = self.client.openai_ws
openai_ws = self._client._openai_ws
await self.initialize_session()

async for message in self.websocket.iter_text():
Expand All @@ -126,4 +125,4 @@ async def initialize_session(self):
"input_audio_format": "g711_ulaw",
"output_audio_format": "g711_ulaw",
}
await self.client.session_update(session_update)
await self._client.session_update(session_update)
Loading