Skip to content

Commit

Permalink
fix: update swarm_agent to handle agent lookup and add relevant unit …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
bassilkhilo committed Dec 3, 2024
1 parent 8c97793 commit f1f273f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
6 changes: 6 additions & 0 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
if isinstance(next_agent, str):
next_agent = groupchat.agent_by_name(name=next_agent)

# If no agent is found, raise an error or handle it appropriately
if next_agent is None:
raise ValueError(
f"No agent found with the name '{next_agent}'. Ensure the agent exists in the group chat."
)

return next_agent

# get the last swarm agent
Expand Down
65 changes: 65 additions & 0 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
import os
import pprint
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest
from dotenv import load_dotenv

from autogen.agentchat.contrib.swarm_agent import (
__CONTEXT_VARIABLES_PARAM_NAME__,
Expand Down Expand Up @@ -461,5 +464,67 @@ def test_initialization():
)


def test_string_agent_params_for_transfer():
"""Test that string agent parameters are handled correctly."""

# Load environment variables
load_dotenv()

# Fetch OpenAI API key from the environment
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("OPENAI_API_KEY is not set in the environment variables.")

# Initialize context variables
context_variables = {}

# Define test messages
messages = [
{
"role": "user",
"content": "Begin by calling the hello_world() function.",
}
]

# Define a simple function for testing
def hello_world(context_variables: dict) -> SwarmResult:
value = "Hello, World!"
return SwarmResult(values=value, context_variables=context_variables, agent="agent_2")

# Define LLM configuration
llm_config = {
"cache_seed": 42,
"model": "gpt-4o-2024-08-06",
"api_key": openai_api_key,
}

# Create SwarmAgent instances
agent_1 = SwarmAgent(
name="agent_1",
system_message="Your task is to call hello_world() function.",
llm_config=llm_config,
functions=[hello_world],
)

agent_2 = SwarmAgent(
name="agent_2",
system_message="Your task is to let the user know what the previous agent said.",
llm_config=llm_config,
functions=[],
)

# Initiate the swarm chat
chat_result, final_context, last_active_agent = initiate_swarm_chat(
initial_agent=agent_1,
agents=[agent_1, agent_2],
context_variables=context_variables,
messages=messages,
after_work=AFTER_WORK(AfterWorkOption.TERMINATE),
)

# Print the chat result for debugging or verification
# pprint.pprint(chat_result)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit f1f273f

Please sign in to comment.