Skip to content

Commit

Permalink
Merge pull request #231 from ag2ai/realtime-agent-start-swarm-at-startup
Browse files Browse the repository at this point in the history
Realtime agent start swarm at startup
  • Loading branch information
sternakt authored Dec 18, 2024
2 parents 1681bb8 + 1201f0c commit aed9fd5
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 58 deletions.
29 changes: 20 additions & 9 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

# import asyncio
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional

import anyio
import websockets
from asyncer import TaskGroup, asyncify, create_task_group, syncify

from autogen.agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat

from .function_observer import FunctionObserver

logger = logging.getLogger(__name__)


class Client(ABC):
def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
Expand Down Expand Up @@ -56,10 +61,11 @@ async def function_result(self, call_id, result):
await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

async def send_text(self, text: str):
async def send_text(self, *, role: str, text: str):
await self._openai_ws.send(json.dumps({"type": "response.cancel"}))
text_item = {
"type": "conversation.item.create",
"item": {"type": "message", "role": "system", "content": [{"type": "input_text", "text": text}]},
"item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
}
await self._openai_ws.send(json.dumps(text_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))
Expand All @@ -72,25 +78,25 @@ async def initialize_session(self):
"turn_detection": {"type": "server_vad"},
"voice": self._agent.voice,
"instructions": self._agent.system_message,
"modalities": ["audio"],
"modalities": ["audio", "text"],
"temperature": 0.8,
}
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
update = {"type": "session.update", "session": session_options}
print("Sending session update:", json.dumps(update), flush=True)
logger.info("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))
print("Sending session update finished", flush=True)
logger.info("Sending session update finished")

async def _read_from_client(self):
try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
print(f"Error in _read_from_client: {e}")
logger.warning(f"Error in _read_from_client: {e}")

async def run(self):
async with websockets.connect(
Expand All @@ -108,6 +114,11 @@ async def run(self):
self.tg.soonify(self._read_from_client)()
for observer in self._observers:
self.tg.soonify(observer.run)()

def run_task(self, task, *args: Any, **kwargs: Any):
self.tg.soonify(task)(*args, **kwargs)
if self._agent._start_swarm_chat:
self.tg.soonify(asyncify(initiate_swarm_chat))(
initial_agent=self._agent._initial_agent,
agents=self._agent._agents,
user_agent=self._agent,
messages="Find out what the user wants.",
after_work=AfterWorkOption.REVERT_TO_USER,
)
7 changes: 5 additions & 2 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

import asyncio
import json
import logging

from asyncer import asyncify
from pydantic import BaseModel

from .realtime_observer import RealtimeObserver

logger = logging.getLogger(__name__)


class FunctionObserver(RealtimeObserver):
def __init__(self, agent):
Expand All @@ -21,7 +24,7 @@ def __init__(self, agent):

async def update(self, response):
if response.get("type") == "response.function_call_arguments.done":
print(f"Received event: {response['type']}", response)
logger.info(f"Received event: {response['type']}", response)
await self.call_function(
call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"])
)
Expand All @@ -34,13 +37,13 @@ async def call_function(self, call_id, name, kwargs):
result = await func(**kwargs)
except Exception:
result = "Function call failed"
logger.warning(f"Function call failed: {name}")

if isinstance(result, BaseModel):
result = result.model_dump_json()
elif not isinstance(result, str):
result = json.dumps(result)

print(f"Function call result: {result}")
await self._client.function_result(call_id, result)

async def run(self):
Expand Down
86 changes: 47 additions & 39 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# SPDX-License-Identifier: MIT

# import asyncio
import asyncio
import json
import logging
from abc import ABC, abstractmethod
Expand All @@ -29,12 +30,16 @@
logger = logging.getLogger(__name__)

SWARM_SYSTEM_MESSAGE = (
"You are a helpful voice assistant. Your task is to listen to user and to create tasks based on his/her inputs. E.g. if a user wishes to make change to his flight, you should create a task for it.\n"
"DO NOT ask any additional information about the task from the user. Start the task as soon as possible.\n"
"You have to assume that every task can be successfully completed by the swarm of agents and your only role is to create tasks for them.\n"
# "While the task is being executed, please keep the user on the line and inform him/her about the progress by calling the 'get_task_status' function. You might also get additional questions or status reports from the agents working on the task.\n"
# "Once the task is done, inform the user that the task is completed and ask if you can help with anything else.\n"
"Do not create unethical or illegal tasks.\n"
"You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs."
"Only call the 'answer_task_question' function when you have the answer from the user."
"You can communicate and will communicate using audio output only."
)

QUESTION_ROLE = "user"
QUESTION_MESSAGE = (
"I have a question/information for the myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
"repeat the question to me **WITH AUDIO OUTPUT** and then call 'answer_task_question' AFTER YOU GET THE ANSWER FROM ME\n\n"
"The question is: '{}'\n\n"
)


Expand Down Expand Up @@ -85,6 +90,9 @@ def __init__(

self._answer_event: anyio.Event = anyio.Event()
self._answer: str = ""
self._start_swarm_chat = False
self._initial_agent = None
self._agents = None

def register_swarm(
self,
Expand All @@ -102,32 +110,23 @@ def register_swarm(

self._oai_system_message = [{"content": system_message, "role": "system"}]

@self.register_handover(name="create_task", description="Create a task given by the user")
async def create_task(task_input: str) -> str:
self._client.run_task(
asyncify(initiate_swarm_chat),
initial_agent=initial_agent,
agents=agents,
user_agent=self,
messages=task_input,
after_work=AfterWorkOption.REVERT_TO_USER,
)

return "Task created successfully."

def _get_task_status(task_id: str) -> Generator[None, str, None]:
while True:
for s in [
"The task is in progress, agents are working on it. ETA is 1 minute",
"The task is successfully completed.",
]:
yield s

it = _get_task_status("task_id")

@self.register_handover(name="get_task_status", description="Get the status of the task")
async def get_task_status(task_id: str) -> str:
return next(it)
self._start_swarm_chat = True
self._initial_agent = initial_agent
self._agents = agents

# def _get_task_status(task_id: str) -> Generator[None, str, None]:
# while True:
# for s in [
# "The task is in progress, agents are working on it. ETA is 1 minute",
# "The task is successfully completed.",
# ]:
# yield s

# it = _get_task_status("task_id")

# @self.register_handover(name="get_task_status", description="Get the status of the task")
# async def get_task_status(task_id: str) -> str:
# return next(it)

self.register_handover(name="answer_task_question", description="Answer question from the task")(
self.set_answer
Expand Down Expand Up @@ -179,6 +178,21 @@ async def get_answer(self) -> str:
await self._answer_event.wait()
return self._answer

async def ask_question(self, question: str, question_timeout: int) -> str:
self.reset_answer()
await anyio.sleep(1)
await self._client.send_text(role=QUESTION_ROLE, text=question)

async def _check_event_set(timeout: int = question_timeout) -> None:
for _ in range(timeout):
if self._answer_event.is_set():
return True
await anyio.sleep(1)
return False

while not await _check_event_set():
await self._client.send_text(role=QUESTION_ROLE, text=question)

def check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -187,13 +201,7 @@ def check_termination_and_human_reply(
) -> Tuple[bool, Union[str, None]]:
async def get_input():
async with create_task_group() as tg:
self.reset_answer()
tg.soonify(self._client.send_text)(
"I have a question/information for the user from the agent working on a task. DO NOT ANSWER YOURSELF, "
"INFORM THE USER **WITH AUDIO OUTPUT** AND THEN CALL 'answer_task_question' TO PROPAGETE THE "
f"USER ANSWER TO THE AGENT WORKING ON THE TASK. The question is: '{messages[-1]['content']}'\n\n",
)
await self.get_answer()
tg.soonify(self.ask_question)(QUESTION_MESSAGE.format(messages[-1]["content"]), 20)

syncify(get_input)()

Expand Down
19 changes: 11 additions & 8 deletions autogen/agentchat/realtime_agent/twilio_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import base64
import json
import logging

from fastapi import WebSocketDisconnect

Expand All @@ -24,6 +25,8 @@
]
SHOW_TIMING_MATH = False

logger = logging.getLogger(__name__)


class TwilioAudioAdapter(RealtimeObserver):
def __init__(self, websocket):
Expand All @@ -40,7 +43,7 @@ def __init__(self, websocket):
async def update(self, response):
"""Receive events from the OpenAI Realtime API, send audio back to Twilio."""
if response["type"] in LOG_EVENT_TYPES:
print(f"Received event: {response['type']}", response)
logger.info(f"Received event: {response['type']}", response)

if response.get("type") == "response.audio.delta" and "delta" in response:
audio_payload = base64.b64encode(base64.b64decode(response["delta"])).decode("utf-8")
Expand All @@ -50,7 +53,7 @@ async def update(self, response):
if self.response_start_timestamp_twilio is None:
self.response_start_timestamp_twilio = self.latest_media_timestamp
if SHOW_TIMING_MATH:
print(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")
logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

# Update last_assistant_item safely
if response.get("item_id"):
Expand All @@ -60,24 +63,24 @@ async def update(self, response):

# Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
if response.get("type") == "input_audio_buffer.speech_started":
print("Speech started detected.")
logger.info("Speech started detected.")
if self.last_assistant_item:
print(f"Interrupting response with id: {self.last_assistant_item}")
logger.info(f"Interrupting response with id: {self.last_assistant_item}")
await self.handle_speech_started_event()

async def handle_speech_started_event(self):
"""Handle interruption when the caller's speech starts."""
print("Handling speech started event.")
logger.info("Handling speech started event.")
if self.mark_queue and self.response_start_timestamp_twilio is not None:
elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
if SHOW_TIMING_MATH:
print(
logger.info(
f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
)

if self.last_assistant_item:
if SHOW_TIMING_MATH:
print(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")
logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

truncate_event = {
"type": "conversation.item.truncate",
Expand Down Expand Up @@ -111,7 +114,7 @@ async def run(self):
await openai_ws.send(json.dumps(audio_append))
elif data["event"] == "start":
self.stream_sid = data["start"]["streamSid"]
print(f"Incoming stream has started {self.stream_sid}")
logger.info(f"Incoming stream has started {self.stream_sid}")
self.response_start_timestamp_twilio = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
Expand Down
1 change: 1 addition & 0 deletions notebook/agentchat_realtime_swarm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@
"outputs": [],
"source": [
"# import asyncio\n",
"import logging\n",
"from datetime import datetime\n",
"from time import time\n",
"\n",
Expand Down

0 comments on commit aed9fd5

Please sign in to comment.