From e4a334e2808b9ecfd2c8e1831ba603b16b13cd78 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 25 Nov 2024 19:10:08 +0000 Subject: [PATCH] Fixed a bug with the next agent in function results, added ON_CONDITION test, test tidy ups for comments --- autogen/agentchat/contrib/swarm_agent.py | 4 +- test/agentchat/contrib/test_swarm.py | 156 ++++++++++++----------- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 1e305858d8..c1c790a906 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -407,9 +407,9 @@ def generate_swarm_tool_reply( if content.context_variables != {}: self._context_variables.update(content.context_variables) if content.agent is not None: - self._next_agent = content.agent + next_agent = content.agent elif isinstance(content, Agent): - self._next_agent = content + next_agent = content tool_responses_inner.append(tool_response) contents.append(str(tool_response["content"])) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 29bddef7ed..a9d1693bf8 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,11 +1,8 @@ -import os -import sys from typing import Any, Dict -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import patch import pytest -from autogen import ConversableAgent, UserProxyAgent, config_list_from_json from autogen.agentchat.contrib.swarm_agent import ( __CONTEXT_VARIABLES_PARAM_NAME__, AFTER_WORK, @@ -15,6 +12,8 @@ SwarmResult, initiate_swarm_chat, ) +from autogen.agentchat.conversable_agent import ConversableAgent +from autogen.agentchat.user_proxy_agent import UserProxyAgent TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] @@ -22,11 +21,6 @@ def test_swarm_agent_initialization(): """Test SwarmAgent initialization with valid and invalid parameters""" - # Valid initialization - agent = SwarmAgent("test_agent") - assert agent.name == "test_agent" - assert agent.human_input_mode == "NEVER" - # Invalid functions parameter with pytest.raises(TypeError): SwarmAgent("test_agent", functions="invalid") @@ -79,15 +73,11 @@ def test_callable(x: int) -> SwarmAgent: def test_on_condition(): """Test ON_CONDITION initialization""" - agent = SwarmAgent("test") - condition = ON_CONDITION(agent=agent, condition="test condition") - assert condition.agent == agent - assert condition.condition == "test condition" # Test with a ConversableAgent test_conversable_agent = ConversableAgent("test_conversable_agent") with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): - condition = ON_CONDITION(agent=test_conversable_agent, condition="test condition") + _ = ON_CONDITION(agent=test_conversable_agent, condition="test condition") def test_receiving_agent(): @@ -159,7 +149,6 @@ def test_swarm_transitions(): initial_agent=agent1, messages=multiple_messages, agents=[agent1, agent2] ) - assert isinstance(last_speaker, SwarmAgent) assert last_speaker == agent1 @@ -174,11 +163,9 @@ def test_after_work_options(): def mock_generate_oai_reply(*args, **kwargs): return True, "This is a mock response from the agent." - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - reply_func_tuple["reply_func"] = mock_generate_oai_reply + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) # 1. Test TERMINATE agent1.after_work = AFTER_WORK(AfterWorkOption.TERMINATE) @@ -228,6 +215,50 @@ def test_callable(last_speaker, messages, groupchat, context_variables): assert chat_result.chat_history[3]["name"] == "agent2" +def test_on_condition_handoff(): + """Test ON_CONDITION in handoffs""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + + agent1.register_hand_off(hand_to=ON_CONDITION(agent2, "always take me to agent 2")) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Fake generate_oai_reply + def mock_generate_oai_reply_tool(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": "transfer_to_agent2"}}], + } + + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2], + max_rounds=5, + ) + + # We should have transferred to agent2 after agent1 has finished + assert chat_result.chat_history[3]["name"] == "agent2" + + def test_temporary_user_proxy(): """Test that temporary user proxy agent name is cleared""" agent1 = SwarmAgent("agent1") @@ -242,7 +273,7 @@ def test_temporary_user_proxy(): assert message.get("name") != "_User" -def test_context_variables_updating(): +def test_context_variables_updating_multi_tools(): """Test context variables handling in tool calls""" testing_llm_config = { @@ -258,12 +289,17 @@ def test_context_variables_updating(): test_context_variables = {"my_key": 0} # Increment the context variable - def test_func(context_variables: Dict[str, Any], param1: str) -> str: + def test_func_1(context_variables: Dict[str, Any], param1: str) -> str: context_variables["my_key"] += 1 - return SwarmResult(values=f"Test {param1}", context_variables=context_variables, agent=agent1) + return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) + + # Increment the context variable + def test_func_2(context_variables: Dict[str, Any], param2: str) -> str: + context_variables["my_key"] += 100 + return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", functions=[test_func], llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func], llm_config=testing_llm_config) + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -274,37 +310,32 @@ def mock_generate_oai_reply_tool(*args, **kwargs): return True, { "role": "assistant", "name": "agent1", - "tool_calls": [{"type": "function", "function": {"name": "test_func", "arguments": '{"param1": "test"}'}}], + "tool_calls": [ + {"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}}, + {"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}}, + ], } - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - if agent == agent1: - reply_func_tuple["reply_func"] = mock_generate_oai_reply - elif agent == agent2: - reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool - - # Test message with a tool call - tool_call_messages = [ - {"role": "user", "content": "Initial message"}, - ] + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent2, - messages=tool_call_messages, + messages=TEST_MESSAGES, agents=[agent1, agent2], context_variables=test_context_variables, max_rounds=3, ) # Ensure we've incremented the context variable - assert context_vars["my_key"] == 1 + # in both tools, updated values should traverse + # 0 + 1 (func 1) + 100 (func 2) = 101 + assert context_vars["my_key"] == 101 -def test_context_variables_updating_multi_tools(): - """Test context variables handling in tool calls""" +def test_function_transfer(): + """Tests a function call that has a transfer to agent in the SwarmResult""" testing_llm_config = { "config_list": [ @@ -323,13 +354,8 @@ def test_func_1(context_variables: Dict[str, Any], param1: str) -> str: context_variables["my_key"] += 1 return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) - # Increment the context variable - def test_func_2(context_variables: Dict[str, Any], param2: str) -> str: - context_variables["my_key"] += 100 - return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func_1], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -342,36 +368,22 @@ def mock_generate_oai_reply_tool(*args, **kwargs): "name": "agent1", "tool_calls": [ {"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}}, - {"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}}, ], } - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - if agent == agent1: - reply_func_tuple["reply_func"] = mock_generate_oai_reply - elif agent == agent2: - reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool - - # Test message with a tool call - tool_call_messages = [ - {"role": "user", "content": "Initial message"}, - ] + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent2, - messages=tool_call_messages, + messages=TEST_MESSAGES, agents=[agent1, agent2], context_variables=test_context_variables, - max_rounds=3, + max_rounds=4, ) - # Ensure we've incremented the context variable - # in both tools, updated values should traverse - # 0 + 1 (func 1) + 100 (func 2) = 101 - assert context_vars["my_key"] == 101 + assert chat_result.chat_history[3]["name"] == "agent1" def test_invalid_parameters(): @@ -401,12 +413,6 @@ def test_non_swarm_in_hand_off(): with pytest.raises(AssertionError, match="Invalid After Work value"): agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent)) - with pytest.raises(AssertionError, match="Invalid After Work value"): - agent1.register_hand_off(hand_to=AFTER_WORK(0)) - - with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): - agent1.register_hand_off(hand_to=ON_CONDITION(0, "Testing")) - with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): agent1.register_hand_off(hand_to=ON_CONDITION(bad_agent, "Testing"))