Skip to content

Commit

Permalink
Implemented update_agent_state hook, UPDATE_SYSTEM_MESSAGE
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <mark@sze.family>
  • Loading branch information
marklysze committed Nov 30, 2024
1 parent e768a69 commit 40c0b47
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 25 deletions.
2 changes: 2 additions & 0 deletions autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .contrib.swarm_agent import (
AFTER_WORK,
ON_CONDITION,
UPDATE_SYSTEM_MESSAGE,
AfterWorkOption,
SwarmAgent,
SwarmResult,
Expand Down Expand Up @@ -39,4 +40,5 @@
"ON_CONDITION",
"AFTER_WORK",
"AfterWorkOption",
"UPDATE_SYSTEM_MESSAGE",
]
73 changes: 48 additions & 25 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ def __post_init__(self):
assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent"


@dataclass
class UPDATE_SYSTEM_MESSAGE:
update_function: Union[Callable, str]

def __post_init__(self):
if isinstance(self.update_function, str):
pass
elif isinstance(self.update_function, Callable):
sig = signature(self.update_function)
if len(sig.parameters) != 2:
raise ValueError("Update function must accept two parameters, context_variables and messages")
if sig.return_annotation != str:
raise ValueError("Update function must return a string")
else:
raise ValueError("Update function must be either a string or a callable")


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
messages: Union[List[Dict[str, Any]], str],
Expand Down Expand Up @@ -118,10 +135,6 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Swarm transition function to determine and prepare the next agent in the conversation"""
next_agent = determine_next_agent(last_speaker, groupchat)

if next_agent and isinstance(next_agent, SwarmAgent):
# Update their state
next_agent.update_state(context_variables, groupchat.messages)

return next_agent

def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat):
Expand Down Expand Up @@ -301,9 +314,9 @@ def __init__(
# use in the tool execution agent to transfer to the next agent
self._next_agent = None

self.register_update_states_functions(update_state_functions)
self.register_update_state_functions(update_state_functions)

def register_update_states_functions(self, functions: Optional[Union[List[Callable], Callable]]):
def register_update_state_functions(self, functions: Optional[Union[List[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
You can add your own validation or precondition functions here.
Expand All @@ -312,9 +325,6 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab
functions (List[Callable[[], None]]): A list of functions to be registered. Each function
is called when the agent is selected and before it speaks.
"""

# TEMP - THIS WILL BE UPDATED TO UTILISE A NEW HOOK - update_agent_state

if functions is None:
return
if not isinstance(functions, list) and not isinstance(functions, Callable):
Expand All @@ -324,7 +334,35 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab
functions = [functions]

for func in functions:
self.register_hook("update_states_once_selected", func)
if isinstance(func, UPDATE_SYSTEM_MESSAGE):

# Wrapper function that allows this to be used in the update_agent_state hook
# Its primary purpose, however, is just to update the agent's system message
# Outer function to create a closure with the update function
def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE):
def update_system_message_wrapper(
context_variables: Dict[str, Any], messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
if isinstance(update_func.update_function, str):
# Templates like "My context variable passport is {passport}" will
# use the context_variables for substitution
sys_message = OpenAIWrapper.instantiate(
template=update_func.update_function,
context=context_variables,
allow_format_str_template=True,
)
else:
sys_message = update_func.update_function(context_variables, messages)

self.update_system_message(sys_message)
return messages

return update_system_message_wrapper

self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func))

else:
self.register_hook(hookable_method="update_agent_state", hook=func)

def _set_to_tool_execution(self):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
Expand Down Expand Up @@ -497,21 +535,6 @@ def add_functions(self, func_list: List[Callable]):
for func in func_list:
self.add_single_function(func)

def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]):
"""Updates the state of the agent prior to reply"""

# TEMP - THIS WILL BE REPLACED BY A NEW HOOK - update_agent_state

for hook in self.hook_lists["update_states_once_selected"]:
result = hook(self, context_variables, messages)

if result is None:
continue

returned_variables, returned_messages = result
self._context_variables.update(returned_variables)
messages = self.process_all_messages_before_reply(returned_messages)


# Forward references for SwarmAgent in SwarmResult
SwarmResult.update_forward_refs()
25 changes: 25 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def __init__(
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
"update_agent_state": [],
}

def _validate_llm_config(self, llm_config):
Expand Down Expand Up @@ -2046,6 +2047,9 @@ def generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables.
messages = self.process_update_agent_states(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_received_message(messages)
Expand Down Expand Up @@ -2116,6 +2120,9 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables.
messages = self.process_update_agent_states(messages)

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages_before_reply(messages)
Expand Down Expand Up @@ -2802,6 +2809,24 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_update_agent_states(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to update the agent's state.
Primarily used to update context variables.
Will, potentially, modify the messages.
"""
hook_list = self.hook_lists["update_agent_state"]

# If no hooks are registered, or if there are no messages to process, return the original message list.
if len(hook_list) == 0 or messages is None:
return messages

# Call each hook (in order of registration) to process the messages.
processed_messages = messages
for hook in hook_list:
processed_messages = hook(self._context_variables, processed_messages)
return processed_messages

def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
Expand Down

0 comments on commit 40c0b47

Please sign in to comment.