Skip to content

Commit

Permalink
Fixed a bug with the next agent in function results, added ON_CONDITI…
Browse files Browse the repository at this point in the history
…ON test, test tidy ups for comments
  • Loading branch information
marklysze committed Nov 25, 2024
1 parent ae823c9 commit e4a334e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 77 deletions.
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def generate_swarm_tool_reply(
if content.context_variables != {}:
self._context_variables.update(content.context_variables)
if content.agent is not None:
self._next_agent = content.agent
next_agent = content.agent
elif isinstance(content, Agent):
self._next_agent = content
next_agent = content

tool_responses_inner.append(tool_response)
contents.append(str(tool_response["content"]))
Expand Down
156 changes: 81 additions & 75 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import os
import sys
from typing import Any, Dict
from unittest.mock import MagicMock, Mock, call, patch
from unittest.mock import patch

import pytest

from autogen import ConversableAgent, UserProxyAgent, config_list_from_json
from autogen.agentchat.contrib.swarm_agent import (
__CONTEXT_VARIABLES_PARAM_NAME__,
AFTER_WORK,
Expand All @@ -15,18 +12,15 @@
SwarmResult,
initiate_swarm_chat,
)
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.agentchat.user_proxy_agent import UserProxyAgent

TEST_MESSAGES = [{"role": "user", "content": "Initial message"}]


def test_swarm_agent_initialization():
"""Test SwarmAgent initialization with valid and invalid parameters"""

# Valid initialization
agent = SwarmAgent("test_agent")
assert agent.name == "test_agent"
assert agent.human_input_mode == "NEVER"

# Invalid functions parameter
with pytest.raises(TypeError):
SwarmAgent("test_agent", functions="invalid")
Expand Down Expand Up @@ -79,15 +73,11 @@ def test_callable(x: int) -> SwarmAgent:

def test_on_condition():
"""Test ON_CONDITION initialization"""
agent = SwarmAgent("test")
condition = ON_CONDITION(agent=agent, condition="test condition")
assert condition.agent == agent
assert condition.condition == "test condition"

# Test with a ConversableAgent
test_conversable_agent = ConversableAgent("test_conversable_agent")
with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"):
condition = ON_CONDITION(agent=test_conversable_agent, condition="test condition")
_ = ON_CONDITION(agent=test_conversable_agent, condition="test condition")


def test_receiving_agent():
Expand Down Expand Up @@ -159,7 +149,6 @@ def test_swarm_transitions():
initial_agent=agent1, messages=multiple_messages, agents=[agent1, agent2]
)

assert isinstance(last_speaker, SwarmAgent)
assert last_speaker == agent1


Expand All @@ -174,11 +163,9 @@ def test_after_work_options():
def mock_generate_oai_reply(*args, **kwargs):
return True, "This is a mock response from the agent."

# Mock an LLM response by overriding the generate_oai_reply function
for agent in [agent1, agent2]:
for reply_func_tuple in agent._reply_func_list:
if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply":
reply_func_tuple["reply_func"] = mock_generate_oai_reply
# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply)

# 1. Test TERMINATE
agent1.after_work = AFTER_WORK(AfterWorkOption.TERMINATE)
Expand Down Expand Up @@ -228,6 +215,50 @@ def test_callable(last_speaker, messages, groupchat, context_variables):
assert chat_result.chat_history[3]["name"] == "agent2"


def test_on_condition_handoff():
"""Test ON_CONDITION in handoffs"""

testing_llm_config = {
"config_list": [
{
"model": "gpt-4o",
"api_key": "SAMPLE_API_KEY",
}
]
}

agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)

agent1.register_hand_off(hand_to=ON_CONDITION(agent2, "always take me to agent 2"))

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
return True, "This is a mock response from the agent."

# Fake generate_oai_reply
def mock_generate_oai_reply_tool(*args, **kwargs):
return True, {
"role": "assistant",
"name": "agent1",
"tool_calls": [{"type": "function", "function": {"name": "transfer_to_agent2"}}],
}

# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent1,
messages=TEST_MESSAGES,
agents=[agent1, agent2],
max_rounds=5,
)

# We should have transferred to agent2 after agent1 has finished
assert chat_result.chat_history[3]["name"] == "agent2"


def test_temporary_user_proxy():
"""Test that temporary user proxy agent name is cleared"""
agent1 = SwarmAgent("agent1")
Expand All @@ -242,7 +273,7 @@ def test_temporary_user_proxy():
assert message.get("name") != "_User"


def test_context_variables_updating():
def test_context_variables_updating_multi_tools():
"""Test context variables handling in tool calls"""

testing_llm_config = {
Expand All @@ -258,12 +289,17 @@ def test_context_variables_updating():
test_context_variables = {"my_key": 0}

# Increment the context variable
def test_func(context_variables: Dict[str, Any], param1: str) -> str:
def test_func_1(context_variables: Dict[str, Any], param1: str) -> str:
context_variables["my_key"] += 1
return SwarmResult(values=f"Test {param1}", context_variables=context_variables, agent=agent1)
return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1)

# Increment the context variable
def test_func_2(context_variables: Dict[str, Any], param2: str) -> str:
context_variables["my_key"] += 100
return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1)

agent1 = SwarmAgent("agent1", functions=[test_func], llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", functions=[test_func], llm_config=testing_llm_config)
agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config)

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
Expand All @@ -274,37 +310,32 @@ def mock_generate_oai_reply_tool(*args, **kwargs):
return True, {
"role": "assistant",
"name": "agent1",
"tool_calls": [{"type": "function", "function": {"name": "test_func", "arguments": '{"param1": "test"}'}}],
"tool_calls": [
{"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}},
{"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}},
],
}

# Mock an LLM response by overriding the generate_oai_reply function
for agent in [agent1, agent2]:
for reply_func_tuple in agent._reply_func_list:
if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply":
if agent == agent1:
reply_func_tuple["reply_func"] = mock_generate_oai_reply
elif agent == agent2:
reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool

# Test message with a tool call
tool_call_messages = [
{"role": "user", "content": "Initial message"},
]
# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent2,
messages=tool_call_messages,
messages=TEST_MESSAGES,
agents=[agent1, agent2],
context_variables=test_context_variables,
max_rounds=3,
)

# Ensure we've incremented the context variable
assert context_vars["my_key"] == 1
# in both tools, updated values should traverse
# 0 + 1 (func 1) + 100 (func 2) = 101
assert context_vars["my_key"] == 101


def test_context_variables_updating_multi_tools():
"""Test context variables handling in tool calls"""
def test_function_transfer():
"""Tests a function call that has a transfer to agent in the SwarmResult"""

testing_llm_config = {
"config_list": [
Expand All @@ -323,13 +354,8 @@ def test_func_1(context_variables: Dict[str, Any], param1: str) -> str:
context_variables["my_key"] += 1
return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1)

# Increment the context variable
def test_func_2(context_variables: Dict[str, Any], param2: str) -> str:
context_variables["my_key"] += 100
return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1)

agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", functions=[test_func_1], llm_config=testing_llm_config)

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
Expand All @@ -342,36 +368,22 @@ def mock_generate_oai_reply_tool(*args, **kwargs):
"name": "agent1",
"tool_calls": [
{"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}},
{"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}},
],
}

# Mock an LLM response by overriding the generate_oai_reply function
for agent in [agent1, agent2]:
for reply_func_tuple in agent._reply_func_list:
if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply":
if agent == agent1:
reply_func_tuple["reply_func"] = mock_generate_oai_reply
elif agent == agent2:
reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool

# Test message with a tool call
tool_call_messages = [
{"role": "user", "content": "Initial message"},
]
# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent2,
messages=tool_call_messages,
messages=TEST_MESSAGES,
agents=[agent1, agent2],
context_variables=test_context_variables,
max_rounds=3,
max_rounds=4,
)

# Ensure we've incremented the context variable
# in both tools, updated values should traverse
# 0 + 1 (func 1) + 100 (func 2) = 101
assert context_vars["my_key"] == 101
assert chat_result.chat_history[3]["name"] == "agent1"


def test_invalid_parameters():
Expand Down Expand Up @@ -401,12 +413,6 @@ def test_non_swarm_in_hand_off():
with pytest.raises(AssertionError, match="Invalid After Work value"):
agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent))

with pytest.raises(AssertionError, match="Invalid After Work value"):
agent1.register_hand_off(hand_to=AFTER_WORK(0))

with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"):
agent1.register_hand_off(hand_to=ON_CONDITION(0, "Testing"))

with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"):
agent1.register_hand_off(hand_to=ON_CONDITION(bad_agent, "Testing"))

Expand Down

0 comments on commit e4a334e

Please sign in to comment.