diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 6c3c12e6ce..c41820bf9b 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -12,6 +12,7 @@ from .contrib.swarm_agent import ( AFTER_WORK, ON_CONDITION, + UPDATE_SYSTEM_MESSAGE, AfterWorkOption, SwarmAgent, SwarmResult, @@ -39,4 +40,5 @@ "ON_CONDITION", "AFTER_WORK", "AfterWorkOption", + "UPDATE_SYSTEM_MESSAGE", ] diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 4e084377c4..c525c53db3 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -4,6 +4,8 @@ import copy import inspect import json +import re +import warnings from dataclasses import dataclass from enum import Enum from inspect import signature @@ -57,6 +59,29 @@ def __post_init__(self): assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string" +@dataclass +class UPDATE_SYSTEM_MESSAGE: + update_function: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.update_function, str): + # find all {var} in the string + vars = re.findall(r"\{(\w+)\}", self.update_function) + if len(vars) == 0: + warnings.warn("Update function string contains no variables. 
This is probably unintended.") + + elif isinstance(self.update_function, Callable): + sig = signature(self.update_function) + if len(sig.parameters) != 2: + raise ValueError( + "Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively" + ) + if sig.return_annotation != str: + raise ValueError("Update function must return a string") + else: + raise ValueError("Update function must be either a string or a callable") + + def initiate_swarm_chat( initial_agent: "SwarmAgent", messages: Union[List[Dict[str, Any]], str], @@ -107,12 +132,27 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any name="Tool_Execution", system_message="Tool Execution", ) - tool_execution._set_to_tool_execution(context_variables=context_variables) + tool_execution._set_to_tool_execution() + + # Update tool execution agent with all the functions from all the agents + for agent in agents: + tool_execution._function_map.update(agent._function_map) + + # Point all SwarmAgent's context variables to this function's context_variables + # providing a single (shared) context across all SwarmAgents in the swarm + for agent in agents + [tool_execution]: + agent._context_variables = context_variables INIT_AGENT_USED = False def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): - """Swarm transition function to determine the next agent in the conversation""" + """Swarm transition function to determine and prepare the next agent in the conversation""" + next_agent = determine_next_agent(last_speaker, groupchat) + + return next_agent + + def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat): + """Determine the next agent in the conversation""" nonlocal INIT_AGENT_USED if not INIT_AGENT_USED: INIT_AGENT_USED = True @@ -310,6 +350,9 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, + 
update_agent_state_before_reply: Optional[ + Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] + ] = None, **kwargs, ) -> None: super().__init__( @@ -335,23 +378,70 @@ def __init__( self.after_work = None - # Used only in the tool execution agent for context and transferring to the next agent - # Note: context variables are not stored for each agent - self._context_variables = {} + # Used in the tool execution agent to transfer to the next agent self._next_agent = None # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat # List of Dictionaries containing the nested_chats and condition self._nested_chat_handoffs = [] - def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None): + self.register_update_agent_state_before_reply(update_agent_state_before_reply) + + def register_update_agent_state_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): + """ + Register functions that will be called when the agent is selected and before it speaks. + You can add your own validation or precondition functions here. + + Args: + functions (List[Callable[[], None]]): A list of functions to be registered. Each function + is called when the agent is selected and before it speaks. 
+ """ + if functions is None: + return + if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]: + raise ValueError("functions must be a list of callables") + + if not isinstance(functions, list): + functions = [functions] + + for func in functions: + if isinstance(func, UPDATE_SYSTEM_MESSAGE): + + # Wrapper function that allows this to be used in the update_agent_state hook + # Its primary purpose, however, is just to update the agent's system message + # Outer function to create a closure with the update function + def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): + def update_system_message_wrapper( + agent: ConversableAgent, messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + if isinstance(update_func.update_function, str): + # Templates like "My context variable passport is {passport}" will + # use the context_variables for substitution + sys_message = OpenAIWrapper.instantiate( + template=update_func.update_function, + context=agent._context_variables, + allow_format_str_template=True, + ) + else: + sys_message = update_func.update_function(agent, messages) + + agent.update_system_message(sys_message) + return messages + + return update_system_message_wrapper + + self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) + + else: + self.register_hook(hookable_method="update_agent_state", hook=func) + + def _set_to_tool_execution(self): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. This agent will be used internally and should not be visible to the user. - It will execute the tool calls and update the context_variables and next_agent accordingly. + It will execute the tool calls and update the referenced context_variables and next_agent accordingly. 
""" self._next_agent = None - self._context_variables = context_variables or {} self._reply_func_list.clear() self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply) @@ -491,6 +581,7 @@ def generate_swarm_tool_reply( return False, None def add_single_function(self, func: Callable, name=None, description=""): + """Add a single function to the agent, removing context variables for LLM use""" if name: func._name = name else: diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index f1fadcca27..ffd6923721 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -264,6 +264,7 @@ def __init__( "process_last_received_message": [], "process_all_messages_before_reply": [], "process_message_before_send": [], + "update_agent_state": [], } def _validate_llm_config(self, llm_config): @@ -2091,6 +2092,9 @@ def generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + self.update_agent_state_before_reply(messages) + # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. messages = self.process_last_received_message(messages) @@ -2161,6 +2165,9 @@ async def a_generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + self.update_agent_state_before_reply(messages) + # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. 
messages = self.process_all_messages_before_reply(messages) @@ -2847,6 +2854,18 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) + def update_agent_state_before_reply(self, messages: List[Dict]) -> None: + """ + Calls any registered capability hooks to update the agent's state. + Primarily used to update context variables. + Will, potentially, modify the messages. + """ + hook_list = self.hook_lists["update_agent_state"] + + # Call each hook (in order of registration) to process the messages. + for hook in hook_list: + hook(self, messages) + def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 7f8bf43a7f..ae2f3cf9b9 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict +from typing import Any, Dict, List from unittest.mock import MagicMock, patch import pytest @@ -10,6 +10,7 @@ __CONTEXT_VARIABLES_PARAM_NAME__, AFTER_WORK, ON_CONDITION, + UPDATE_SYSTEM_MESSAGE, AfterWorkOption, SwarmAgent, SwarmResult, @@ -461,6 +462,101 @@ def test_initialization(): ) +def test_update_system_message(): + """Tests the update_agent_state_before_reply functionality with multiple scenarios""" + + # Test container to capture system messages + class MessageContainer: + def __init__(self): + self.captured_sys_message = "" + + message_container = MessageContainer() + + # 1. 
Test with a callable function + def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str: + return f"System message with {agent.get_context('test_var')} and {len(messages)} messages" + + # 2. Test with a string template + template_message = "Template message with {test_var}" + + # Create agents with different update configurations + agent1 = SwarmAgent("agent1", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) + + agent2 = SwarmAgent("agent2", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) + + # Mock the reply function to capture the system message + def mock_generate_oai_reply(*args, **kwargs): + # Capture the system message for verification + message_container.captured_sys_message = args[0]._oai_system_message[0]["content"] + return True, "Mock response" + + # Register mock reply for both agents + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + # Test context and messages + test_context = {"test_var": "test_value"} + test_messages = [{"role": "user", "content": "Test message"}] + + # Run chat with first agent (using callable function) + chat_result1, context_vars1, last_speaker1 = initiate_swarm_chat( + initial_agent=agent1, messages=test_messages, agents=[agent1], context_variables=test_context, max_rounds=2 + ) + + # Verify callable function result + assert message_container.captured_sys_message == "System message with test_value and 1 messages" + + # Reset captured message + message_container.captured_sys_message = "" + + # Run chat with second agent (using string template) + chat_result2, context_vars2, last_speaker2 = initiate_swarm_chat( + initial_agent=agent2, messages=test_messages, agents=[agent2], context_variables=test_context, max_rounds=2 + ) + + # Verify template result + assert message_container.captured_sys_message == "Template message with test_value" + + # Test 
invalid update function + with pytest.raises(ValueError, match="Update function must be either a string or a callable"): + SwarmAgent("agent3", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(123)) + + # Test invalid callable (wrong number of parameters) + def invalid_update_function(context_variables): + return "Invalid function" + + with pytest.raises(ValueError, match="Update function must accept two parameters"): + SwarmAgent("agent4", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) + + # Test invalid callable (wrong return type) + def invalid_return_function(context_variables, messages) -> dict: + return {} + + with pytest.raises(ValueError, match="Update function must return a string"): + SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) + + # Test multiple update functions + def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: + return "Another update" + + agent6 = SwarmAgent( + "agent6", + update_agent_state_before_reply=[ + UPDATE_SYSTEM_MESSAGE(custom_update_function), + UPDATE_SYSTEM_MESSAGE(another_update_function), + ], + ) + + agent6.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result6, context_vars6, last_speaker6 = initiate_swarm_chat( + initial_agent=agent6, messages=test_messages, agents=[agent6], context_variables=test_context, max_rounds=2 + ) + + # Verify last update function took effect + assert message_container.captured_sys_message == "Another update" + + def test_string_agent_params_for_transfer(): """Test that string agent parameters are handled correctly without using real LLMs.""" # Define test configuration diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 1724eea982..82a0a3cc83 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -159,8 +159,58 @@ "])\n", "\n", 
"agent_2.handoff(hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)]) # Terminate the chat if no handoff is suggested\n", - "```\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Update Agent state before replying\n", + "\n", + "It can be useful to update a swarm agent's state before they reply. For example, using an agent's context variables you could change their system message based on the state of the workflow.\n", + "\n", + "When initialising a swarm agent, use the `update_agent_state_before_reply` parameter to register updates that run after the agent is selected, but before they reply.\n", + "\n", + "`update_agent_state_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", "\n", + "- `UPDATE_SYSTEM_MESSAGE` provides a simple way to update the agent's system message via a string template whose `{variable}` placeholders are replaced with the values of context variables, or a Callable that returns a string\n", + "- Callable with two parameters of type `ConversableAgent` for the agent and `List[Dict[str, Any]]` for the messages, and does not return a value\n", + "\n", + "Below is an example of setting these up when creating a Swarm agent.\n", + "\n", + "```python\n", + "# Creates a system message string\n", + "def create_system_prompt_function(my_agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str:\n", + " preferred_name = my_agent.get_context(\"preferred_name\", \"(name not provided)\")\n", + "\n", + " # Note that the returned string will be treated like an f-string using the context variables\n", + " return (\"You are a customer service representative helping a customer named \"\n", + " + preferred_name\n", + " + \" and their passport number is '{passport_number}'.\")\n", + "\n", + "# Function to update an Agent's state\n", + "def my_callable_state_update_function(my_agent: ConversableAgent, messages: List[Dict[str, Any]]) -> None:\n", + " my_agent.set_context(\"context_key\", 43)\n", + "
my_agent.update_system_message(\"You are a customer service representative.\")\n", + "\n", + "# Create the SwarmAgent and set agent updates\n", + "customer_service = SwarmAgent(\n", + " name=\"CustomerServiceRep\",\n", + " system_message=\"You are a customer service representative.\",\n", + " update_agent_state_before_reply=[\n", + " UPDATE_SYSTEM_MESSAGE(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", + " UPDATE_SYSTEM_MESSAGE(create_system_prompt_function),\n", + " my_callable_state_update_function]\n", + " ...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "### Initialize SwarmChat with `initiate_swarm_chat`\n", "\n", "After a set of swarm agents are created, you can initiate a swarm chat by calling `initiate_swarm_chat`.\n", @@ -185,7 +235,7 @@ "\n", "> How are context variables updated?\n", "\n", - "The context variables will only be updated through custom function calls when returning a `SwarmResult` object. In fact, all interactions with context variables will be done through function calls (accessing and updating). The context variables dictionary is a reference, and any modification will be done in place.\n", + "In a swarm, the context variables are shared amongst Swarm agents. As context variables are available at the agent level, you can use the context variable getters/setters on the agent to view and change the shared context variables. If you're working with a function that returns a `SwarmResult`, you should update the passed-in context variables and return them in the `SwarmResult`; this ensures the shared context is updated.\n", "\n", "> What is the difference between ON_CONDITION and AFTER_WORK?\n", "\n",