Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RealtimeAgent for Real-Time Conversational AI Support in ag2 Framework #196

Merged
merged 62 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
a3f0676
WIP
sternakt Dec 12, 2024
793bf68
WIP: Draft initial version of realtime agent
sternakt Dec 12, 2024
24ea3c2
remove openai_ws from observers (half way through)
davorinrusevljan Dec 12, 2024
b79e8a1
Merge pull request #198 from davorinrusevljan/realtime-agent-rush
sternakt Dec 13, 2024
942e3d0
WIP add handover registration to realtime agent
sternakt Dec 13, 2024
efa4c8b
Merge branch 'realtime-agent' of github.com:ag2ai/ag2 into realtime-a…
sternakt Dec 13, 2024
97f12e0
Implement function calling and register handover to realtime agent
sternakt Dec 13, 2024
96cc3bd
Run pre-commit
sternakt Dec 13, 2024
6145e62
Separated client out from RealtimeAgent
davorinrusevljan Dec 13, 2024
299c862
Merge pull request #206 from davorinrusevljan/realtime-agent-rush
sternakt Dec 13, 2024
3469f13
WIP: Integrate swarm into RealtimeAgent
sternakt Dec 17, 2024
1a26555
wip: refactoring a bit
davorrunje Dec 17, 2024
0b3aa50
WIP
sternakt Dec 17, 2024
995f0fd
wip: refactoring
davorrunje Dec 17, 2024
85fa907
Rework into anyio
sternakt Dec 17, 2024
6b058a6
Merge pull request #225 from ag2ai/realtime-agent-refactoring-anyio
sternakt Dec 17, 2024
99ed595
WIP: Cleanup example notebook
sternakt Dec 17, 2024
1681bb8
WIP: Cleanup example notebook
sternakt Dec 17, 2024
75af6b8
WIP
sternakt Dec 18, 2024
5bd052d
Add question polling
sternakt Dec 18, 2024
e8fc5a2
Sync question asking
sternakt Dec 18, 2024
1201f0c
Replace prints with logging
sternakt Dec 18, 2024
aed9fd5
Merge pull request #231 from ag2ai/realtime-agent-start-swarm-at-startup
sternakt Dec 18, 2024
90e0516
Init RealtimeAgent blogpost
sternakt Dec 18, 2024
39d5a0e
Add realtime agent swarm image
sternakt Dec 18, 2024
d3e1864
WIP: blog
sternakt Dec 19, 2024
147f171
Prepare blogpost
sternakt Dec 19, 2024
ff2f47a
Update autogen/agentchat/realtime_agent/realtime_agent.py
sternakt Dec 19, 2024
e2153fb
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
f04deb2
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
6e4ed4f
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
da27ec8
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
3c5eb4d
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
1ac6dbc
Update notebook/agentchat_realtime_swarm.ipynb
sternakt Dec 19, 2024
4b6229e
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
71a1aa6
Update website/blog/2024-12-18-RealtimeAgent/index.mdx
sternakt Dec 19, 2024
425218c
Revise agentchat_realtime_swarm.ipynb
sternakt Dec 19, 2024
1cc619a
Refactor RealtimeAgent.ask_question
sternakt Dec 19, 2024
4cab66e
Remove RealtimeOpenAIClient mention from the blogpost
sternakt Dec 19, 2024
acef7b4
Add docs
sternakt Dec 19, 2024
c193859
Remove prints
sternakt Dec 19, 2024
bd6d640
realtime deps moved to main deps
davorrunje Dec 19, 2024
f361341
Blog polishing
sternakt Dec 19, 2024
8d5b41a
Merge branch 'realtime-agent' of https://github.com/ag2ai/ag2 into re…
sternakt Dec 19, 2024
21e0c68
Update notebook title
marklysze Dec 19, 2024
7290c06
Twilio spelling corrected
marklysze Dec 19, 2024
449ad8c
Updated setups for ag2 and autogen packages
marklysze Dec 19, 2024
ea01e71
Notebook text tweaks
marklysze Dec 19, 2024
c68633e
Update response message to match video
marklysze Dec 19, 2024
6c01a66
websocket realtime wip(1)
davorinrusevljan Dec 19, 2024
0b023a1
websocket realtime wip(2)
davorinrusevljan Dec 19, 2024
4ae306e
websocket realtime wip(3)
davorinrusevljan Dec 19, 2024
ba15132
websocket realtime wip(4)
davorinrusevljan Dec 19, 2024
972b53b
websocket realtime wip(5)
davorinrusevljan Dec 19, 2024
9d959db
websocket realtime wip(6)
davorinrusevljan Dec 19, 2024
0ec36f1
websocket realtime wip(7)
davorinrusevljan Dec 19, 2024
8344e79
Merge pull request #241 from davorinrusevljan/realtime-agent-rush
davorrunje Dec 19, 2024
6da2e00
Merge remote-tracking branch 'origin/main' into realtime-agent
davorrunje Dec 19, 2024
65c5574
pre-commit run on all files
davorrunje Dec 19, 2024
a2aefe9
Merge remote-tracking branch 'origin/main' into realtime-agent
davorrunje Dec 19, 2024
1342f75
websockets upgraded to 14.x
davorrunje Dec 19, 2024
d3fa24d
merge with main
davorrunje Dec 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions autogen/agentchat/realtime_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .twilio_observer import TwilioAudioAdapter

__all__ = [
"RealtimeAgent",
"FunctionObserver",
"TwilioAudioAdapter",
]
75 changes: 75 additions & 0 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import json
from abc import ABC, abstractmethod

import websockets

from .function_observer import FunctionObserver


class Client(ABC):
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
self._agent = agent
self._observers = []
self._openai_ws = None # todo factor out to OpenAIClient
self.register(audio_adapter)
self.register(function_observer)

def register(self, observer):
observer.register_client(self)
self._observers.append(observer)

async def notify_observers(self, message):
for observer in self._observers:
await observer.update(message)

async def function_result(self, call_id, result):
result_item = {
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

# todo override in specific clients
async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self._agent.voice,
"instructions": self._agent.system_message,
"modalities": ["text", "audio"],
"temperature": 0.8,
}
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))

async def _read_from_client(self):
try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")

async def run(self):

async with websockets.connect(
"wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
additional_headers={
"Authorization": f"Bearer {self._agent.llm_config['config_list'][0]['api_key']}",
"OpenAI-Beta": "realtime=v1",
},
) as openai_ws:
self._openai_ws = openai_ws
await self.initialize_session()
await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self._observers])
39 changes: 39 additions & 0 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import json

from .realtime_observer import RealtimeObserver


class FunctionObserver(RealtimeObserver):
def __init__(self, agent):
super().__init__()
self._agent = agent

async def update(self, response):
if response.get("type") == "response.function_call_arguments.done":
print("!" * 50)
print(f"Received event: {response['type']}", response)
await self.call_function(
call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"])
)

async def call_function(self, call_id, name, kwargs):
_, func = self._agent.registered_functions[name]
await self._client.function_result(call_id, func(**kwargs))

async def run(self):
await self.initialize_session()

async def initialize_session(self):
"""Add tool to OpenAI."""
session_update = {
"tools": [schema for schema, _ in self._agent.registered_functions.values()],
"tool_choice": "auto",
}
await self._client.session_update(session_update)
85 changes: 85 additions & 0 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import asyncio
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union

import websockets

from autogen.agentchat.agent import Agent, LLMAgent
from autogen.function_utils import get_function_schema

from .client import Client
from .function_observer import FunctionObserver
from .realtime_observer import RealtimeObserver

F = TypeVar("F", bound=Callable[..., Any])


class RealtimeAgent(LLMAgent):
def __init__(
self,
name: str,
audio_adapter: RealtimeObserver,
system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
voice: str = "alloy",
):

self._client = Client(self, audio_adapter, FunctionObserver(self))
self.llm_config = llm_config
self.voice = voice
self.registered_functions = {}

self._oai_system_message = [{"content": system_message, "role": "system"}] # todo still needed?

async def run(self):
await self._client.run()

def register_handover(
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
self,
*,
description: str,
name: Optional[str] = None,
) -> Callable[[F], F]:
def _decorator(func: F, name=name) -> F:
"""Decorator for registering a function to be used by an agent.

Args:
func: the function to be registered.

Returns:
The function to be registered, with the _description attribute set to the function description.

Raises:
ValueError: if the function description is not provided and not propagated by a previous decorator.
RuntimeError: if the LLM config is not set up before registering a function.

"""
# get JSON schema for the function
name = name or func.__name__

schema = get_function_schema(func, name=name, description=description)["function"]
schema["type"] = "function"

self.registered_functions[name] = (schema, func)

return func

return _decorator
24 changes: 24 additions & 0 deletions autogen/agentchat/realtime_agent/realtime_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod


class RealtimeObserver(ABC):
def __init__(self):
self._client = None

def register_client(self, client):
self._client = client

@abstractmethod
async def run(self, openai_ws):
pass

@abstractmethod
async def update(self, message):
pass
128 changes: 128 additions & 0 deletions autogen/agentchat/realtime_agent/twilio_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import base64
import json

from fastapi import WebSocketDisconnect

from .realtime_observer import RealtimeObserver

LOG_EVENT_TYPES = [
"error",
"response.content.done",
"rate_limits.updated",
"response.done",
"input_audio_buffer.committed",
"input_audio_buffer.speech_stopped",
"input_audio_buffer.speech_started",
"session.created",
]
SHOW_TIMING_MATH = False


class TwilioAudioAdapter(RealtimeObserver):
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, websocket):
super().__init__()
self.websocket = websocket

# Connection specific state
self.stream_sid = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
self.mark_queue = []
self.response_start_timestamp_twilio = None

async def update(self, response):
"""Receive events from the OpenAI Realtime API, send audio back to Twilio."""
if response["type"] in LOG_EVENT_TYPES:
print(f"Received event: {response['type']}", response)

if response.get("type") == "response.audio.delta" and "delta" in response:
audio_payload = base64.b64encode(base64.b64decode(response["delta"])).decode("utf-8")
audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
await self.websocket.send_json(audio_delta)

if self.response_start_timestamp_twilio is None:
self.response_start_timestamp_twilio = self.latest_media_timestamp
if SHOW_TIMING_MATH:
print(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

# Update last_assistant_item safely
if response.get("item_id"):
self.last_assistant_item = response["item_id"]

await self.send_mark()

# Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
if response.get("type") == "input_audio_buffer.speech_started":
print("Speech started detected.")
if self.last_assistant_item:
print(f"Interrupting response with id: {self.last_assistant_item}")
await self.handle_speech_started_event()

async def handle_speech_started_event(self):
"""Handle interruption when the caller's speech starts."""
print("Handling speech started event.")
if self.mark_queue and self.response_start_timestamp_twilio is not None:
elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
if SHOW_TIMING_MATH:
print(
f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
)

if self.last_assistant_item:
if SHOW_TIMING_MATH:
print(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

truncate_event = {
"type": "conversation.item.truncate",
"item_id": self.last_assistant_item,
"content_index": 0,
"audio_end_ms": elapsed_time,
}
await self._client.openai_ws.send(json.dumps(truncate_event))

await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

self.mark_queue.clear()
self.last_assistant_item = None
self.response_start_timestamp_twilio = None

async def send_mark(self):
if self.stream_sid:
mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
await self.websocket.send_json(mark_event)
self.mark_queue.append("responsePart")

async def run(self):
openai_ws = self._client._openai_ws
await self.initialize_session()

async for message in self.websocket.iter_text():
data = json.loads(message)
if data["event"] == "media":
self.latest_media_timestamp = int(data["media"]["timestamp"])
audio_append = {"type": "input_audio_buffer.append", "audio": data["media"]["payload"]}
await openai_ws.send(json.dumps(audio_append))
elif data["event"] == "start":
self.stream_sid = data["start"]["streamSid"]
print(f"Incoming stream has started {self.stream_sid}")
self.response_start_timestamp_twilio = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
elif data["event"] == "mark":
if self.mark_queue:
self.mark_queue.pop(0)

async def initialize_session(self):
"""Control initial session with OpenAI."""
session_update = {
"input_audio_format": "g711_ulaw",
"output_audio_format": "g711_ulaw",
}
await self._client.session_update(session_update)
Loading
Loading