Skip to content

Commit

Permalink
"Proposed changes: add hook update_states_once_selected"
Browse files Browse the repository at this point in the history
  • Loading branch information
linmou committed Nov 29, 2024
1 parent 1460bbb commit 22189f1
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
code_execution_config=False,
system_message_func: Optional[Callable] = None,
update_state_functions: Optional[Union[List[Callable], Callable]] = None,
**kwargs,
) -> None:
super().__init__(
Expand Down Expand Up @@ -296,11 +296,30 @@ def __init__(
# use in the tool execution agent to transfer to the next agent
self._context_variables = {}
self._next_agent = None

self._system_message_func = system_message_func

self.hook_lists['update_states_once_selected'] = []

self.register_update_states_functions(update_state_functions)

def register_update_states_functions(self, functions: Optional[Union[List[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
You can add your own validation or precondition functions here.
Args:
functions (List[Callable[[], None]]): A list of functions to be registered. Each function
is called when the agent is selected and before it speaks.
"""
if functions is None: return
if not isinstance(functions, list) and not isinstance(functions, Callable):
raise ValueError("functions must be a list of callables")

if isinstance(functions, Callable):
functions = [functions]

for func in functions:
self.register_hook('update_states_once_selected', func)

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
Expand Down Expand Up @@ -477,14 +496,17 @@ def update_state(self, context_variables: Optional[Dict[str, Any]], messages: Li
"""Updates the state of the agent, system message so far. This is called when they're selected and just before they speak."""

for hook in self.hook_lists['update_states_once_selected']:
sig = signature(hook)
returned_variables, messages = hook(self, context_variables, messages)
# if returned_variables is not None:
# context_variables.update(returned_variables)
result = hook(self, context_variables, messages)

if result is None: continue

returned_variables, returned_messages = result
context_variables.update(returned_variables)
messages = self.process_all_messages_before_reply(returned_messages)

This comment has been minimized.

Copy link
@linmou

linmou Nov 29, 2024

Author Collaborator

process_all_messages_before_reply should not be used here. I change it to messages = returned_messages in the next commit



if self._system_message_func:
self.update_system_message(self._system_message_func(context_variables, messages))
# if self._system_message_func:
# self.update_system_message(self._system_message_func(context_variables, messages))


# Forward references for SwarmAgent in SwarmResult
Expand Down

0 comments on commit 22189f1

Please sign in to comment.