Skip to content

Commit

Permalink
Merge pull request #107 from ag2ai/swarmnestedchat
Browse files Browse the repository at this point in the history
Swarm: Support for nested chat as a hand off
  • Loading branch information
qingyun-wu authored Dec 5, 2024
2 parents b0effa6 + 4deb3c5 commit 44ea19e
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 25 deletions.
214 changes: 195 additions & 19 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import copy
import inspect
import json
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -42,12 +43,18 @@ def __post_init__(self):

@dataclass
class ON_CONDITION:
    """Conditional hand-off: transfer to ``target`` when ``condition`` applies.

    ``target`` is either a ``SwarmAgent`` (a direct agent-to-agent transfer) or
    a dictionary describing nested chats (read elsewhere in this module through
    keys such as ``"chat_queue"``). ``condition`` is the text registered as the
    description of the generated transfer function.

    Raises:
        AssertionError: If ``target`` is set but is neither a ``SwarmAgent``
            nor a dict, or if ``condition`` is not a non-empty string.
    """

    # Destination of the hand-off; None is tolerated by the validation below.
    target: Optional[Union["SwarmAgent", Dict[str, Any]]] = None
    # Description of when this hand-off should be taken; must not be blank.
    condition: str = ""

    def __post_init__(self):
        # Ensure valid types. The dict check comes first so that nested-chat
        # targets are accepted without having to consult SwarmAgent at all.
        if self.target is not None:
            assert isinstance(self.target, dict) or isinstance(
                self.target, SwarmAgent
            ), "'target' must be a SwarmAgent or a Dict"

        # Ensure they have a condition
        assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"


def initiate_swarm_chat(
Expand Down Expand Up @@ -96,18 +103,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

swarm_agent_names = [agent.name for agent in agents]

tool_execution = SwarmAgent(
name="Tool_Execution",
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution(context_variables=context_variables)

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
Expand Down Expand Up @@ -173,6 +174,43 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
else:
raise ValueError("Invalid After Work condition")

def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]):
    """Build one proxy SwarmAgent per nested-chat hand-off stored on ``agent``.

    For every entry in ``agent._nested_chat_handoffs`` a dedicated agent is
    created, the nested chats are registered on it, it is told to hand back to
    ``agent`` once its work is done, and ``agent`` receives an ON_CONDITION
    hand-off pointing at it. Each created agent is appended to
    ``nested_chat_agents`` (the list is mutated in place).
    """
    for index, handoff in enumerate(agent._nested_chat_handoffs, start=1):
        chats_config: Dict[str, Any] = handoff["nested_chats"]
        handoff_condition = handoff["condition"]

        # One proxy agent per hand-off, numbered from 1 within the parent agent
        proxy_agent = SwarmAgent(name=f"nested_chat_{agent.name}_{index}")

        # Fall back to the default summary reply function when none (or a falsy one) is given
        reply_func = chats_config.get("reply_func_from_nested_chats") or "summary_from_nested_chats"

        proxy_agent.register_nested_chats(
            chats_config["chat_queue"],
            reply_func_from_nested_chats=reply_func,
            config=chats_config.get("config", None),
            trigger=lambda sender: True,  # fire regardless of who the sender is
            position=0,
            use_async=chats_config.get("use_async", False),
        )

        # Once the nested chat has finished, control returns to the parent agent
        proxy_agent.register_hand_off(AFTER_WORK(agent=agent))
        nested_chat_agents.append(proxy_agent)

        # The parent reaches the nested chat by transferring to the proxy agent
        agent.register_hand_off(ON_CONDITION(proxy_agent, handoff_condition))

nested_chat_agents = []
for agent in agents:
create_nested_chats(agent, nested_chat_agents)

# Update tool execution agent with all the functions from all the agents
for agent in agents + nested_chat_agents:
tool_execution._function_map.update(agent._function_map)

swarm_agent_names = [agent.name for agent in agents + nested_chat_agents]

# If there's only one message and there's no identified swarm agent
# Start with a user proxy agent, creating one if they haven't passed one in
if len(messages) == 1 and "name" not in messages[0] and not user_agent:
Expand All @@ -181,7 +219,10 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
temp_user_proxy = []

groupchat = GroupChat(
agents=[tool_execution] + agents + ([user_agent] if user_agent is not None else temp_user_proxy),
agents=[tool_execution]
+ agents
+ nested_chat_agents
+ ([user_agent] if user_agent is not None else temp_user_proxy),
messages=[], # Set to empty. We will resume the conversation with the messages
max_round=max_rounds,
speaker_selection_method=swarm_transition,
Expand Down Expand Up @@ -294,10 +335,15 @@ def __init__(

self.after_work = None

# use in the tool execution agent to transfer to the next agent
# Used only in the tool execution agent for context and transferring to the next agent
# Note: context variables are not stored for each agent
self._context_variables = {}
self._next_agent = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
Expand Down Expand Up @@ -342,16 +388,25 @@ def transfer_to_agent_name() -> SwarmAgent:
self.after_work = transit
elif isinstance(transit, ON_CONDITION):

# Create closure with current loop transit value
# to ensure the condition matches the one in the loop
def make_transfer_function(current_transit):
def transfer_to_agent() -> "SwarmAgent":
return current_transit.agent
if isinstance(transit.target, SwarmAgent):
# Transition to agent

# Create closure with current loop transit value
# to ensure the condition matches the one in the loop
def make_transfer_function(current_transit: ON_CONDITION):
def transfer_to_agent() -> "SwarmAgent":
return current_transit.target

return transfer_to_agent

return transfer_to_agent
transfer_func = make_transfer_function(transit)
self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition)

elif isinstance(transit.target, Dict):
# Transition to a nested chat
# We will store them here and establish them in the initiate_swarm_chat
self._nested_chat_handoffs.append({"nested_chats": transit.target, "condition": transit.condition})

transfer_func = make_transfer_function(transit)
self.add_single_function(transfer_func, f"transfer_to_{transit.agent.name}", transit.condition)
else:
raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK")

Expand Down Expand Up @@ -469,6 +524,127 @@ def add_functions(self, func_list: List[Callable]):
for func in func_list:
self.add_single_function(func)

@staticmethod
def process_nested_chat_carryover(
    chat: Dict[str, Any],
    recipient: "ConversableAgent",
    messages: List[Dict[str, Any]],
    sender: "ConversableAgent",
    trim_n_messages: int = 0,
) -> None:
    """Process carryover messages for a nested chat (typically for the first chat of a swarm).

    The result is written back to ``chat["message"]`` as
    ``<existing message>\\nContext:\\n<carryover>`` (the leading message and its
    newline are omitted when the chat has no message).

    ``chat["carryover_config"]`` is a dictionary containing:
        "summary_method": how to condense the parent messages, one of
            "all" - the contents of all messages are joined with newlines
                (messages without a "content" key, or with None content, are skipped)
            "last_msg" - only the content of the last message is used
            "reflection_with_llm" - an LLM summarises the messages and the
                summary is used
            Callable - invoked as
                ``my_method(recipient, messages, summary_args)`` and expected to
                return a string (or a list of message dictionaries)
        "summary_args": optional arguments for the summary method; a missing or
            falsy value is replaced with an empty dict

    Args:
        chat: The chat dictionary containing the carryover configuration; its
            "message" entry is overwritten.
        recipient: The recipient agent.
        messages: The messages from the parent chat. They are deep-copied and
            never modified.
        sender: The sender agent.
        trim_n_messages: The number of latest messages to drop before the
            carryover is built. Zero (the default) or a negative value keeps
            every message.

    Raises:
        KeyError: If ``chat`` has no "carryover_config" entry.
        ValueError: If the configuration has no "summary_method", if the
            method is not one of the supported values, or if the produced
            carryover is neither a string nor a list.
        IndexError: If "last_msg" is requested but no messages remain after
            trimming.
    """

    def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[str, Any]]]) -> str:
        """Concatenate the carryover message to the chat message."""
        prefix = f"{chat_message}\n" if chat_message else ""

        if isinstance(carryover_message, str):
            content = carryover_message
        elif isinstance(carryover_message, list):
            content = "\n".join(
                msg["content"] for msg in carryover_message if "content" in msg and msg["content"] is not None
            )
        else:
            raise ValueError("Carryover message must be a string or a list of dictionaries")

        return f"{prefix}Context:\n{content}"

    carryover_config = chat["carryover_config"]

    if "summary_method" not in carryover_config:
        raise ValueError("Carryover configuration must contain a 'summary_method' key")

    carryover_summary_method = carryover_config["summary_method"]
    carryover_summary_args = carryover_config.get("summary_args") or {}

    chat_message = chat.get("message", "")

    # Work on a deep copy so the parent chat's messages are left untouched.
    content_messages = copy.deepcopy(messages)

    # Only slice when there is something to trim: `seq[:-0]` is `seq[:0]`,
    # which would silently discard every message.
    if trim_n_messages > 0:
        content_messages = content_messages[:-trim_n_messages]

    if carryover_summary_method == "all":
        # Put a string concatenated value of all parent messages into the first message
        # (e.g. message = <first nested chat message>\nContext: \n<swarm message 1>\n<swarm message 2>\n...)
        carry_over_message = concat_carryover(chat_message, content_messages)

    elif carryover_summary_method == "last_msg":
        # (e.g. message = <first nested chat message>\nContext: \n<last swarm message>)
        carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"])

    elif carryover_summary_method == "reflection_with_llm":
        # (e.g. message = <first nested chat message>\nContext: \n<llm summary>)

        # Add the messages to the nested chat agent for reflection (we'll clear after reflection)
        chat["recipient"]._oai_messages[sender] = content_messages

        carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary(
            sender=sender,
            recipient=chat["recipient"],  # Chat recipient LLM config will be used for the reflection
            summary_args=carryover_summary_args,
        )

        # NOTE(review): the messages were stored on chat["recipient"] above, but the
        # history cleared here belongs to `recipient` - confirm this is intentional.
        recipient._oai_messages[sender] = []

        carry_over_message = concat_carryover(chat_message, carry_over_message_llm)

    elif callable(carryover_summary_method):
        # (e.g. message = <first nested chat message>\nContext: \n<function's return string>)
        carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args)

        carry_over_message = concat_carryover(chat_message, carry_over_message_result)

    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            "Carryover 'summary_method' must be 'all', 'last_msg', 'reflection_with_llm' or a Callable, "
            f"got {carryover_summary_method!r}"
        )

    chat["message"] = carry_over_message

@staticmethod
def _summary_from_nested_chats(
    chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
    """Run the queued nested chats and return the summary of the last one.

    Swarm-specific variant of ConversableAgent's ``_summary_from_nested_chats``.
    The chats to run are selected from ``chat_queue`` and initiated by
    ``sender``; the summary of the final chat result is returned.

    Swarm additions:
        - ``messages`` carries the parent chat's messages.
        - The first chat in the queue (and only the first) may contain a
          ``carryover_config`` dictionary that controls how parent messages are
          folded into that chat's message, e.g.
          ``{"summary_method": "reflection_with_llm", "summary_args": None}``.
          ``summary_method`` can be "last_msg", "all", "reflection_with_llm" or
          a Callable, which is invoked with the recipient agent, the messages
          and the summary args. The carryover is appended to the first chat's
          message.

    Returns:
        Tuple[bool, str]: The first element is always True; the second is the
        summary of the last chat, or None when there was no chat to run.
    """
    # Only the head of the queue may carry parent context. The two most recent
    # parent messages are trimmed because they are the messages for the
    # transition to the nested chat agent.
    carries_over = len(chat_queue) > 0 and "carryover_config" in chat_queue[0]
    if carries_over:
        SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, 2)

    pending_chats = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
    if pending_chats:
        results = sender.initiate_chats(pending_chats)
        return True, results[-1].summary

    return True, None


# Resolve the forward reference to SwarmAgent inside SwarmResult. This runs at
# module level after the SwarmAgent methods above.
# NOTE(review): presumably SwarmResult is a pydantic model declared earlier in
# this module with a "SwarmAgent" string annotation - confirm.
SwarmResult.update_forward_refs()
10 changes: 5 additions & 5 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_on_condition():

# Test with a ConversableAgent
test_conversable_agent = ConversableAgent("test_conversable_agent")
with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"):
_ = ON_CONDITION(agent=test_conversable_agent, condition="test condition")
with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"):
_ = ON_CONDITION(target=test_conversable_agent, condition="test condition")


def test_receiving_agent():
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_on_condition_handoff():
agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)

agent1.register_hand_off(hand_to=ON_CONDITION(agent2, "always take me to agent 2"))
agent1.register_hand_off(hand_to=ON_CONDITION(target=agent2, condition="always take me to agent 2"))

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
Expand Down Expand Up @@ -428,8 +428,8 @@ def test_non_swarm_in_hand_off():
with pytest.raises(AssertionError, match="Invalid After Work value"):
agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent))

with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"):
agent1.register_hand_off(hand_to=ON_CONDITION(bad_agent, "Testing"))
with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"):
agent1.register_hand_off(hand_to=ON_CONDITION(target=bad_agent, condition="Testing"))

with pytest.raises(ValueError, match="hand_to must be a list of ON_CONDITION or AFTER_WORK"):
agent1.register_hand_off(0)
Expand Down
Loading

0 comments on commit 44ea19e

Please sign in to comment.