Skip to content

Commit

Permalink
Merge pull request #206 from davorinrusevljan/realtime-agent-rush
Browse files Browse the repository at this point in the history
Separated client out from RealtimeAgent
  • Loading branch information
sternakt authored Dec 13, 2024
2 parents 96cc3bd + 6145e62 commit 299c862
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 71 deletions.
75 changes: 75 additions & 0 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import json
from abc import ABC, abstractmethod

import websockets

from .function_observer import FunctionObserver


class Client(ABC):
def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
self._agent = agent
self._observers = []
self._openai_ws = None # todo factor out to OpenAIClient
self.register(audio_adapter)
self.register(function_observer)

def register(self, observer):
observer.register_client(self)
self._observers.append(observer)

async def notify_observers(self, message):
for observer in self._observers:
await observer.update(message)

async def function_result(self, call_id, result):
result_item = {
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

# todo override in specific clients
async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self._agent.voice,
"instructions": self._agent.system_message,
"modalities": ["text", "audio"],
"temperature": 0.8,
}
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))

async def _read_from_client(self):
try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")

async def run(self):

async with websockets.connect(
"wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
additional_headers={
"Authorization": f"Bearer {self._agent.llm_config['config_list'][0]['api_key']}",
"OpenAI-Beta": "realtime=v1",
},
) as openai_ws:
self._openai_ws = openai_ws
await self.initialize_session()
await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self._observers])
12 changes: 6 additions & 6 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


class FunctionObserver(RealtimeObserver):
def __init__(self, registered_functions):
def __init__(self, agent):
super().__init__()
self.registered_functions = registered_functions
self._agent = agent

async def update(self, response):
if response.get("type") == "response.function_call_arguments.done":
Expand All @@ -24,16 +24,16 @@ async def update(self, response):
)

async def call_function(self, call_id, name, kwargs):
_, func = self.registered_functions[name]
await self.client.function_result(call_id, func(**kwargs))
_, func = self._agent.registered_functions[name]
await self._client.function_result(call_id, func(**kwargs))

async def run(self):
await self.initialize_session()

async def initialize_session(self):
"""Add tool to OpenAI."""
session_update = {
"tools": [schema for schema, _ in self.registered_functions.values()],
"tools": [schema for schema, _ in self._agent.registered_functions.values()],
"tool_choice": "auto",
}
await self.client.session_update(session_update)
await self._client.session_update(session_update)
64 changes: 5 additions & 59 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from autogen.agentchat.agent import Agent, LLMAgent
from autogen.function_utils import get_function_schema

from .client import Client
from .function_observer import FunctionObserver
from .realtime_observer import RealtimeObserver

Expand All @@ -40,71 +41,16 @@ def __init__(
context_variables: Optional[Dict[str, Any]] = None,
voice: str = "alloy",
):

self._client = Client(self, audio_adapter, FunctionObserver(self))
self.llm_config = llm_config
self._oai_system_message = [{"content": system_message, "role": "system"}]
self.voice = voice
self.observers = []
self.openai_ws = None
self.registered_functions = {}

self.register(audio_adapter)

def register(self, observer):
observer.register_client(self)
self.observers.append(observer)

async def notify_observers(self, message):
for observer in self.observers:
await observer.update(message)

async def function_result(self, call_id, result):
result_item = {
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
await self.openai_ws.send(json.dumps(result_item))
await self.openai_ws.send(json.dumps({"type": "response.create"}))

async def _read_from_client(self):
try:
async for openai_message in self.openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")
self._oai_system_message = [{"content": system_message, "role": "system"}] # todo still needed?

async def run(self):
self.register(FunctionObserver(registered_functions=self.registered_functions))
async with websockets.connect(
"wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
additional_headers={
"Authorization": f"Bearer {self.llm_config['config_list'][0]['api_key']}",
"OpenAI-Beta": "realtime=v1",
},
) as openai_ws:
self.openai_ws = openai_ws
await self.initialize_session()
await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self.observers])

async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self.voice,
"instructions": self.system_message,
"modalities": ["text", "audio"],
"temperature": 0.8,
}
await self.session_update(session_update)

async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update))
await self.openai_ws.send(json.dumps(update))
await self._client.run()

def register_handover(
self,
Expand Down
4 changes: 2 additions & 2 deletions autogen/agentchat/realtime_agent/realtime_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

class RealtimeObserver(ABC):
def __init__(self):
self.client = None
self._client = None

def register_client(self, client):
self.client = client
self._client = client

@abstractmethod
async def run(self, openai_ws):
Expand Down
7 changes: 3 additions & 4 deletions autogen/agentchat/realtime_agent/twilio_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class TwilioAudioAdapter(RealtimeObserver):
def __init__(self, websocket):
super().__init__()
self.client = None
self.websocket = websocket

# Connection specific state
Expand Down Expand Up @@ -86,7 +85,7 @@ async def handle_speech_started_event(self):
"content_index": 0,
"audio_end_ms": elapsed_time,
}
await self.client.openai_ws.send(json.dumps(truncate_event))
await self._client.openai_ws.send(json.dumps(truncate_event))

await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

Expand All @@ -101,7 +100,7 @@ async def send_mark(self):
self.mark_queue.append("responsePart")

async def run(self):
openai_ws = self.client.openai_ws
openai_ws = self._client._openai_ws
await self.initialize_session()

async for message in self.websocket.iter_text():
Expand All @@ -126,4 +125,4 @@ async def initialize_session(self):
"input_audio_format": "g711_ulaw",
"output_audio_format": "g711_ulaw",
}
await self.client.session_update(session_update)
await self._client.session_update(session_update)

0 comments on commit 299c862

Please sign in to comment.