diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 6c3c12e6ce..c41820bf9b 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -12,6 +12,7 @@ from .contrib.swarm_agent import ( AFTER_WORK, ON_CONDITION, + UPDATE_SYSTEM_MESSAGE, AfterWorkOption, SwarmAgent, SwarmResult, @@ -39,4 +40,5 @@ "ON_CONDITION", "AFTER_WORK", "AfterWorkOption", + "UPDATE_SYSTEM_MESSAGE", ] diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index af8f2272bc..1465e60564 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -49,6 +49,23 @@ def __post_init__(self): assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent" +@dataclass +class UPDATE_SYSTEM_MESSAGE: + update_function: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.update_function, str): + pass + elif isinstance(self.update_function, Callable): + sig = signature(self.update_function) + if len(sig.parameters) != 2: + raise ValueError("Update function must accept two parameters, context_variables and messages") + if sig.return_annotation != str: + raise ValueError("Update function must return a string") + else: + raise ValueError("Update function must be either a string or a callable") + + def initiate_swarm_chat( initial_agent: "SwarmAgent", messages: Union[List[Dict[str, Any]], str], @@ -118,10 +135,6 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): """Swarm transition function to determine and prepare the next agent in the conversation""" next_agent = determine_next_agent(last_speaker, groupchat) - if next_agent and isinstance(next_agent, SwarmAgent): - # Update their state - next_agent.update_state(context_variables, groupchat.messages) - return next_agent def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat): @@ -301,9 +314,9 @@ def __init__( # use in the tool execution agent to transfer to the next agent self._next_agent = None - self.register_update_states_functions(update_state_functions) + self.register_update_state_functions(update_state_functions) - def register_update_states_functions(self, functions: Optional[Union[List[Callable], Callable]]): + def register_update_state_functions(self, functions: Optional[Union[List[Callable], Callable]]): """ Register functions that will be called when the agent is selected and before it speaks. You can add your own validation or precondition functions here. @@ -312,9 +325,6 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab functions (List[Callable[[], None]]): A list of functions to be registered. Each function is called when the agent is selected and before it speaks. """ - - # TEMP - THIS WILL BE UPDATED TO UTILISE A NEW HOOK - update_agent_state - if functions is None: return if not isinstance(functions, list) and not isinstance(functions, Callable): @@ -324,7 +334,35 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab functions = [functions] for func in functions: - self.register_hook("update_states_once_selected", func) + if isinstance(func, UPDATE_SYSTEM_MESSAGE): + + # Wrapper function that allows this to be used in the update_agent_state hook + # Its primary purpose, however, is just to update the agent's system message + # Outer function to create a closure with the update function + def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): + def update_system_message_wrapper( + context_variables: Dict[str, Any], messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + if isinstance(update_func.update_function, str): + # Templates like "My context variable passport is {passport}" will + # use the context_variables for substitution + sys_message = OpenAIWrapper.instantiate( + template=update_func.update_function, + context=context_variables, + allow_format_str_template=True, + ) + else: + sys_message = update_func.update_function(context_variables, messages) + + self.update_system_message(sys_message) + return messages + + return update_system_message_wrapper + + self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) + + else: + self.register_hook(hookable_method="update_agent_state", hook=func) def _set_to_tool_execution(self): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. @@ -497,21 +535,6 @@ def add_functions(self, func_list: List[Callable]): for func in func_list: self.add_single_function(func) - def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]): - """Updates the state of the agent prior to reply""" - - # TEMP - THIS WILL BE REPLACED BY A NEW HOOK - update_agent_state - - for hook in self.hook_lists["update_states_once_selected"]: - result = hook(self, context_variables, messages) - - if result is None: - continue - - returned_variables, returned_messages = result - self._context_variables.update(returned_variables) - messages = self.process_all_messages_before_reply(returned_messages) - # Forward references for SwarmAgent in SwarmResult SwarmResult.update_forward_refs() diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index b558038eec..e05510e8a5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -261,6 +261,7 @@ def __init__( "process_last_received_message": [], "process_all_messages_before_reply": [], "process_message_before_send": [], + "update_agent_state": [], } def _validate_llm_config(self, llm_config): @@ -2046,6 +2047,9 @@ def generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + messages = self.process_update_agent_states(messages) + # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. messages = self.process_last_received_message(messages) @@ -2116,6 +2120,9 @@ async def a_generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + messages = self.process_update_agent_states(messages) + # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. messages = self.process_all_messages_before_reply(messages) @@ -2802,6 +2809,24 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) + def process_update_agent_states(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to update the agent's state. + Primarily used to update context variables. + Will, potentially, modify the messages. + """ + hook_list = self.hook_lists["update_agent_state"] + + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + processed_messages = hook(self._context_variables, processed_messages) + return processed_messages + def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages.