Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into add-tool-imports
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Dec 19, 2024
2 parents 5a58008 + 7f9a2b6 commit ec6d6be
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 45 deletions.
167 changes: 134 additions & 33 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"

__TOOL_EXECUTOR_NAME__ = "Tool_Execution"


class AfterWorkOption(Enum):
TERMINATE = "TERMINATE"
Expand All @@ -36,6 +38,14 @@ class AfterWorkOption(Enum):

@dataclass
class AFTER_WORK:
"""Handles the next step in the conversation when an agent doesn't suggest a tool call or a handoff
Args:
agent: The agent to hand off to or the after work option. Can be a SwarmAgent, a string name of a SwarmAgent, an AfterWorkOption, or a Callable.
The Callable signature is:
def my_after_work_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]:
"""

agent: Union[AfterWorkOption, "SwarmAgent", str, Callable]

def __post_init__(self):
Expand All @@ -45,8 +55,20 @@ def __post_init__(self):

@dataclass
class ON_CONDITION:
"""Defines a condition for transitioning to another agent or nested chats
Args:
target: The agent to hand off to or the nested chat configuration. Can be a SwarmAgent or a Dict.
If a Dict, it should follow the convention of the nested chat configuration, with the exception of a carryover configuration which is unique to Swarms.
Swarm Nested chat documentation: https://ag2ai.github.io/ag2/docs/topics/swarm#registering-handoffs-to-a-nested-chat
condition: The condition for transitioning to the target agent, evaluated by the LLM to determine whether to call the underlying function/tool which does the transition.
available: Optional condition to determine if this ON_CONDITION is available. Can be a Callable or a string.
If a string, it will look up the value of the context variable with that name, which should be a bool.
"""

target: Union["SwarmAgent", Dict[str, Any]] = None
condition: str = ""
available: Optional[Union[Callable, str]] = None

def __post_init__(self):
# Ensure valid types
Expand All @@ -58,9 +80,21 @@ def __post_init__(self):
# Ensure they have a condition
assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"

if self.available is not None:
assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string"


@dataclass
class UPDATE_SYSTEM_MESSAGE:
"""Update the agent's system message before they reply
Args:
update_function: The string or function to update the agent's system message. Can be a string or a Callable.
If a string, it will be used as a template and substitute the context variables.
If a Callable, it should have the signature:
def my_update_function(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
"""

update_function: Union[Callable, str]

def __post_init__(self):
Expand Down Expand Up @@ -107,9 +141,9 @@ def initiate_swarm_chat(
- REVERT_TO_USER : Revert to the user agent if a user agent is provided. If not provided, terminate the conversation.
- STAY : Stay with the last speaker.
Callable: A custom function that takes the current agent, messages, groupchat, and context_variables as arguments and returns the next agent. The function should return None to terminate.
Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a SwarmAgent (by reference or string name).
```python
def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat, context_variables: Optional[Dict[str, Any]]) -> Optional[SwarmAgent]:
def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]:
```
Returns:
ChatResult: Conversations chat history.
Expand All @@ -129,7 +163,7 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
messages = [{"role": "user", "content": messages}]

tool_execution = SwarmAgent(
name="Tool_Execution",
name=__TOOL_EXECUTOR_NAME__,
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution()
Expand All @@ -138,11 +172,6 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
for agent in agents:
tool_execution._function_map.update(agent._function_map)

# Point all SwarmAgent's context variables to this function's context_variables
# providing a single (shared) context across all SwarmAgents in the swarm
for agent in agents + [tool_execution]:
agent._context_variables = context_variables

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
Expand Down Expand Up @@ -191,34 +220,40 @@ def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat):
if (user_agent and last_speaker == user_agent) or groupchat.messages[-1]["role"] == "tool":
return last_swarm_speaker

# No agent selected via hand-offs (tool calls)
# Assume the work is Done
# override if agent-level after_work is defined, else use the global after_work
tmp_after_work = last_swarm_speaker.after_work if last_swarm_speaker.after_work is not None else after_work
if isinstance(tmp_after_work, AFTER_WORK):
tmp_after_work = tmp_after_work.agent

if isinstance(tmp_after_work, SwarmAgent):
return tmp_after_work
elif isinstance(tmp_after_work, AfterWorkOption):
if tmp_after_work == AfterWorkOption.TERMINATE or (
user_agent is None and tmp_after_work == AfterWorkOption.REVERT_TO_USER
):
# Resolve after_work condition (agent-level overrides global)
after_work_condition = (
last_swarm_speaker.after_work if last_swarm_speaker.after_work is not None else after_work
)
if isinstance(after_work_condition, AFTER_WORK):
after_work_condition = after_work_condition.agent

# Evaluate callable after_work
if isinstance(after_work_condition, Callable):
after_work_condition = after_work_condition(last_speaker, groupchat.messages, groupchat)

if isinstance(after_work_condition, str): # Agent name in a string
if after_work_condition in swarm_agent_names:
return groupchat.agent_by_name(name=after_work_condition)
else:
raise ValueError(f"Invalid agent name in after_work: {after_work_condition}")
elif isinstance(after_work_condition, SwarmAgent):
return after_work_condition
elif isinstance(after_work_condition, AfterWorkOption):
if after_work_condition == AfterWorkOption.TERMINATE:
return None
elif tmp_after_work == AfterWorkOption.REVERT_TO_USER:
return user_agent
elif tmp_after_work == AfterWorkOption.STAY:
elif after_work_condition == AfterWorkOption.REVERT_TO_USER:
return None if user_agent is None else user_agent
elif after_work_condition == AfterWorkOption.STAY:
return last_speaker
elif isinstance(tmp_after_work, Callable):
return tmp_after_work(last_speaker, groupchat.messages, groupchat, context_variables)
else:
raise ValueError("Invalid After Work condition")
raise ValueError("Invalid After Work condition or return value from callable")

def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]):
"""Create nested chat agents and register nested chats"""
for i, nested_chat_handoff in enumerate(agent._nested_chat_handoffs):
nested_chats: Dict[str, Any] = nested_chat_handoff["nested_chats"]
condition = nested_chat_handoff["condition"]
available = nested_chat_handoff["available"]

# Create a nested chat agent specifically for this nested chat
nested_chat_agent = SwarmAgent(name=f"nested_chat_{agent.name}_{i + 1}")
Expand All @@ -239,7 +274,7 @@ def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent])
nested_chat_agents.append(nested_chat_agent)

# Nested chat is triggered through an agent transfer to this nested chat agent
agent.register_hand_off(ON_CONDITION(nested_chat_agent, condition))
agent.register_hand_off(ON_CONDITION(nested_chat_agent, condition, available))

nested_chat_agents = []
for agent in agents:
Expand All @@ -249,6 +284,10 @@ def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent])
for agent in agents + nested_chat_agents:
tool_execution._function_map.update(agent._function_map)

# Add conditional functions to the tool_execution agent
for func_name, (func, on_condition) in agent._conditional_functions.items():
tool_execution._function_map[func_name] = func

swarm_agent_names = [agent.name for agent in agents + nested_chat_agents]

# If there's only one message and there's no identified swarm agent
Expand All @@ -270,6 +309,11 @@ def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent])
manager = GroupChatManager(groupchat)
clear_history = True

# Point all SwarmAgent's context variables to this function's context_variables
# providing a single (shared) context across all SwarmAgents in the swarm
for agent in agents + [tool_execution] + [manager]:
agent._context_variables = context_variables

if len(messages) > 1:
last_agent, last_message = manager.resume(messages=messages)
clear_history = False
Expand Down Expand Up @@ -337,6 +381,7 @@ class SwarmAgent(ConversableAgent):
Additional args:
functions (List[Callable]): A list of functions to register with the agent.
update_agent_state_before_reply (List[Callable]): A list of functions, including UPDATE_SYSTEM_MESSAGEs, called to update the agent before it replies.
"""

def __init__(
Expand Down Expand Up @@ -387,6 +432,13 @@ def __init__(

self.register_update_agent_state_before_reply(update_agent_state_before_reply)

# Store conditional functions (and their ON_CONDITION instances) to add/remove later when transitioning to this agent
self._conditional_functions = {}

# Register the hook to update agent state (except tool executor)
if name != __TOOL_EXECUTOR_NAME__:
self.register_hook("update_agent_state", self._update_conditional_functions)

def register_update_agent_state_before_reply(self, functions: Optional[Union[List[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
Expand Down Expand Up @@ -490,16 +542,50 @@ def transfer_to_agent() -> "SwarmAgent":
return transfer_to_agent

transfer_func = make_transfer_function(transit)
self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition)

# Store function to add/remove later based on it being 'available'
# Function names are made unique and allow multiple ON_CONDITIONS to the same agent
base_func_name = f"transfer_{self.name}_to_{transit.target.name}"
func_name = base_func_name
count = 2
while func_name in self._conditional_functions:
func_name = f"{base_func_name}_{count}"
count += 1

# Store function to add/remove later based on it being 'available'
self._conditional_functions[func_name] = (transfer_func, transit)

elif isinstance(transit.target, Dict):
# Transition to a nested chat
# We will store them here and establish them in the initiate_swarm_chat
self._nested_chat_handoffs.append({"nested_chats": transit.target, "condition": transit.condition})
self._nested_chat_handoffs.append(
{"nested_chats": transit.target, "condition": transit.condition, "available": transit.available}
)

else:
raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK")

@staticmethod
def _update_conditional_functions(agent: Agent, messages: Optional[List[Dict]] = None) -> None:
"""Updates the agent's functions based on the ON_CONDITION's available condition."""
for func_name, (func, on_condition) in agent._conditional_functions.items():
is_available = True

if on_condition.available is not None:
if isinstance(on_condition.available, Callable):
is_available = on_condition.available(agent, next(iter(agent.chat_messages.values())))
elif isinstance(on_condition.available, str):
is_available = agent.get_context(on_condition.available) or False

if is_available:
if func_name not in agent._function_map:
agent.add_single_function(func, func_name, on_condition.condition)
else:
# Remove function using the stored name
if func_name in agent._function_map:
agent.update_tool_signature(func_name, is_remove=True)
del agent._function_map[func_name]

def generate_swarm_tool_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -621,6 +707,7 @@ def process_nested_chat_carryover(
recipient: ConversableAgent,
messages: List[Dict[str, Any]],
sender: ConversableAgent,
config: Any,
trim_n_messages: int = 0,
) -> None:
"""Process carryover messages for a nested chat (typically for the first chat of a swarm)
Expand Down Expand Up @@ -666,7 +753,12 @@ def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[
carryover_summary_method = carryover_config["summary_method"]
carryover_summary_args = carryover_config.get("summary_args") or {}

chat_message = chat.get("message", "")
chat_message = ""
message = chat.get("message")

# If the message is a callable, run it and get the result
if message:
chat_message = message(recipient, messages, sender, config) if callable(message) else message

# deep copy and trim the latest messages
content_messages = copy.deepcopy(messages)
Expand Down Expand Up @@ -725,15 +817,24 @@ def _summary_from_nested_chats(
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""

# Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for swarm nested chat carryover as these are the messages for the transition to the nested chat agent
restore_chat_queue_message = False
if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]:
SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, 2)
if "message" in chat_queue[0]:
# As we're updating the message in the nested chat queue, we need to restore it after finishing this nested chat.
restore_chat_queue_message = True
original_chat_queue_message = chat_queue[0]["message"]
SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, config, 2)

chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = sender.initiate_chats(chat_to_run)

# We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses
if restore_chat_queue_message:
chat_queue[0]["message"] = original_chat_queue_message

return True, res[-1].summary


Expand Down
Loading

0 comments on commit ec6d6be

Please sign in to comment.