Skip to content

Commit

Permalink
Added SwarmResult, start of context variables, progress on run_chat
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze committed Nov 15, 2024
1 parent 0b5a6c1 commit fb1ad59
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 51 deletions.
79 changes: 55 additions & 24 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import re
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from ..code_utils import content_str
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
Expand All @@ -23,7 +23,7 @@
from .agent import Agent
from .contrib.capabilities import transform_messages
from .conversable_agent import ConversableAgent
from .swarm import SwarmAgent
from .swarm import SwarmAgent, SwarmResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,6 +111,7 @@ def custom_speaker_selection_func(
- select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional)
- select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional)
- role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
- context_variables: dictionary of context variables for use with swarm-based group chats
"""

agents: List[Agent]
Expand Down Expand Up @@ -150,6 +151,7 @@ def custom_speaker_selection_func(
select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None
select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None
role_for_select_speaker_messages: Optional[str] = "system"
context_variables: Optional[Dict] = None

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin", "swarm"]
_VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
Expand Down Expand Up @@ -279,8 +281,13 @@ def __post_init__(self):
raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")

# Ensure, for swarms, all agents are swarm agents
if self.speaker_selection_method == "swarm" and not all(isinstance(agent, SwarmAgent) for agent in self.agents):
raise ValueError("All agents must be of type SwarmAgent when using the 'swarm' speaker selection method.")
if self.speaker_selection_method == "swarm":
"""MS TEMP REMOVE
if not all(isinstance(agent, SwarmAgent) for agent in self.agents):
raise ValueError("All agents must be of type SwarmAgent when using the 'swarm' speaker selection method.")
"""
if not isinstance(self.context_variables, dict):
self.context_variables = {}

@property
def agent_names(self) -> List[str]:
Expand Down Expand Up @@ -426,30 +433,23 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
return random.choice(agents)

def swarm_select_speaker(self, last_speaker: Agent, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
"""Select the next speaker using the swarm pattern. Note that this does not need to cater for when the agent is continuing to speak."""
messages = self.messages

# TODO TODO TODO

# Always start with the first speaker
if len(messages) <= 1:
return last_speaker

# If the last message is a tool call, the last agent should execute it
if "tool_calls" in messages[-1]:
return last_speaker # If it's a tool_call then the agent executes it
last_message = messages[-1]

# If the last message is a tool response, check if the tool response is the name of the next agent
# Otherwise return the last agent before the tool call
if "tool_responses" in messages[-1]:
tool_call_msg = messages[-1].get("content", "")
if self.agent_by_name(name=tool_call_msg):
return self.agent_by_name(name=messages[-1].get("content", ""))
return self.agent_by_name(name=messages[-2].get("name", ""))
# elif last_speaker in [flight_modification, flight_cancel, flight_change, lost_baggage, triage_agent]:
# return user
# If the last message is a TRANSFER message, extract agent name and return them
if "content" in last_message and last_message["content"].startswith("TRANSFER:"):
agent_name = last_message["content"].split(":")[1].strip()
if self.agent_by_name(name=agent_name):
return self.agent_by_name(agent_name)

else:
return self.agent_by_name(name=messages[-2].get("name", ""))
# Otherwise, return the agent before the previous one
return self.agent_by_name(name=messages[-2].get("name", ""))

def _prepare_and_select_agents(
self,
Expand Down Expand Up @@ -1166,13 +1166,13 @@ def run_chat(
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
context_variables: Optional[Dict] = {}, # For Swarms
) -> Tuple[bool, Optional[str]]:
"""Run a group chat."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
next_speaker = None # The next swarm agent to speak, determined by the current swarm agent
groupchat = config
send_introductions = getattr(groupchat, "send_introductions", False)
silent = getattr(self, "_silent", False)
Expand All @@ -1187,7 +1187,6 @@ def run_chat(

# Swarm
if self.groupchat.speaker_selection_method == "swarm":
context_variables = copy.deepcopy(context_variables)
config.allow_repeat_speaker = True # Swarms allow the last speaker to be the next speaker

if self.client_cache is not None:
Expand All @@ -1205,13 +1204,45 @@ def run_chat(
# The conversation is over or it's the last round
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
if next_speaker:
# Speaker has already been selected (swarm)
speaker = next_speaker
next_speaker = None
else:
speaker = groupchat.select_speaker(speaker, self)

if not silent:
iostream = IOStream.get_default()
iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True)

# Update the context_variables on the agent
if self.groupchat.speaker_selection_method == "swarm" and isinstance(speaker, SwarmAgent):
speaker.context_variables.update(groupchat.context_variables)

# let the speaker speak
reply = speaker.generate_reply(sender=self)

# If we have a swarm reply, update context variables
if isinstance(reply, SwarmResult):
if reply.context_variables:
self.groupchat.context_variables.update(reply.context_variables)

reply_value = "\n".join(reply.values)

if reply.next_agent is not None:
next_speaker = groupchat.agent_by_name(reply.next_agent)
else:
# If there are multiple replies, it indicates multiple tool calls
# In this case we will see if any of the replies contains an agent Transfer and set the reply to that
if len(reply.values) > 1:
for content in reply.values:
if content in groupchat.agent_names:
reply_value = content
break

# Replaces the swarm result with string value
reply = reply_value

except KeyboardInterrupt:
# let the admin agent speak if interrupted
if groupchat.admin_name in groupchat.agent_names:
Expand Down
1 change: 0 additions & 1 deletion autogen/agentchat/swarm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .swarm_agent import *
from .swarm_core import *
118 changes: 97 additions & 21 deletions autogen/agentchat/swarm/swarm_agent.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,50 @@
import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from inspect import signature
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel

import autogen
from autogen.agentchat import Agent, ConversableAgent
from autogen.function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from autogen.function_utils import get_function_schema, remove_parameter_from_function_schema
from autogen.oai import OpenAIWrapper


def parse_json_object(response: str) -> dict:
return json.loads(response)


# Parameter name for context variables
# Use the value in functions and they will be substituted with the context variables:
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"


class SwarmResult:
"""
Encapsulates the possible return values for a swarm agent function.
arguments:
values (str): The result values as a string. Can be many due to multiple tool calls.
agent (SwarmAgent): The swarm agent instance, if applicable.
context_variables (dict): A dictionary of context variables.
"""

values: List[str] = []
agent: Optional["SwarmAgent"] = None
context_variables: dict = {}

def __init__(
self,
values: str, # Text response, could be the next agent name as well
next_agent: str = None, # The name of the next agent if known
context_variables: dict = {},
) -> None:
self.values = values
self.next_agent = next_agent
self.context_variables = context_variables


class SwarmAgent(ConversableAgent):
def __init__(
self,
Expand All @@ -25,6 +56,7 @@ def __init__(
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
context_variables: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -37,39 +69,76 @@ def __init__(
description=description,
**kwargs,
)

if isinstance(functions, list):
self.add_functions(functions)
elif isinstance(functions, Callable):
self.add_single_function(functions)

self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_reply_with_tool_calls)
self.context_variables = context_variables or {}

def update_context_variables(self, context_variables: Dict[str, Any]) -> None:
pass

# return str or any instance of BaseModel from pydantic

def generate_reply_with_tool_calls(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[OpenAIWrapper] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
) -> Tuple[bool, SwarmResult]:

client = self.client if config is None else config
if client is None:
return False, None
if messages is None:
messages = self._oai_messages[sender]

messages = self._oai_system_message + [{"role": "user", "content": input}]
response = self.client.create(messages=messages)

if isinstance(response, ChatCompletion):
response = self.client.extract_text_or_completion_object(response)
if isinstance(response, str):
return response
elif isinstance(response, dict):
_, func_response = self.generate_tool_calls_reply([response])
return [response, func_response]
response = self._generate_oai_reply_from_client(client, self._oai_system_message + messages, self.client_cache)

if isinstance(response, str):
return True, SwarmResult(
values=[response],
next_agent=self.name,
)
elif isinstance(response, dict):
# Tool calls, inject context_variables back in to the response before executing the tools
if "tool_calls" in response:
for tool_call in response["tool_calls"]:
if tool_call["type"] == "function":
function_name = tool_call["function"]["name"]

# Check if this function exists in our function map
if function_name in self._function_map:
func = self._function_map[function_name] # Get the original function

# Check if function has context_variables parameter
sig = signature(func)
needs_context = __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters

if needs_context:
# Parse existing arguments
try:
current_args = json.loads(tool_call["function"]["arguments"])
except json.JSONDecodeError:
current_args = {}

# Inject context_variables
updated_args = {"context_variables": self.context_variables, **current_args}

# Update the tool call with new arguments
tool_call["function"]["arguments"] = json.dumps(updated_args)

_, func_response = self.generate_tool_calls_reply([response])

return_values = []
for response in func_response["tool_responses"]:
return_values.append(response["content"])

return True, SwarmResult(
values=return_values,
next_agent=None,
)
else:
raise ValueError("Invalid response type:", type(response))

Expand All @@ -79,12 +148,19 @@ def add_single_function(self, func: Callable, description=""):
if description:
func._description = description
else:
func._description = func.__doc__
# Use function's docstring, strip whitespace, fall back to empty string
func._description = (func.__doc__ or "").strip()

f = get_function_schema(func, name=func._name, description=func._description)
self.update_tool_signature(f, is_remove=False)

# Remove the context_variable parameter from the function signature stored in self.llm_config["tools"]
# This is done to prevent the context_variable parameter from being passed to the function when it is called
# by the LLM
f_no_context_variable = remove_parameter_from_function_schema(f, __CONTEXT_VARIABLES_PARAM_NAME__)
self.update_tool_signature(f_no_context_variable, is_remove=False)

self.register_function({func._name: self._wrap_function(func)})

def add_functions(self, func_list: List[Callable]):
for func in func_list:
self.add_single_function(func["func"])
self.add_single_function(func)
4 changes: 0 additions & 4 deletions autogen/agentchat/swarm/swarm_core.py

This file was deleted.

Loading

0 comments on commit fb1ad59

Please sign in to comment.