Skip to content

Commit

Permalink
Merge pull request #137 from ag2ai/convagentcontextvar
Browse files Browse the repository at this point in the history
Add context variables to ConversableAgent
  • Loading branch information
marklysze authored Dec 4, 2024
2 parents 0a9e847 + fba4d08 commit a6840cf
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 11 deletions.
48 changes: 47 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
description: Optional[str] = None,
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
response_format: Optional[BaseModel] = None,
):
"""
Expand Down Expand Up @@ -136,7 +137,11 @@ def __init__(
resume previous had conversations. Defaults to an empty chat history.
silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of
silent in each function.
response_format(BaseModel): Used to specify structured response format for the agent. Not available for all LLMs.
context_variables (dict or None): Context variables that provide a persistent context for the agent.
Note: Will maintain a reference to the passed in context variables (enabling a shared context)
Only used in Swarms at this stage:
https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent
response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
Expand Down Expand Up @@ -196,6 +201,8 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)

self._context_variables = context_variables if context_variables is not None else {}

# Setting up code execution.
# Do not register code execution reply if code execution is disabled.
if code_execution_config is not False:
Expand Down Expand Up @@ -523,6 +530,45 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
),
)

def get_context(self, key: str, default: Any = None) -> Any:
"""
Get a context variable by key.
Args:
key: The key to look up
default: Value to return if key doesn't exist
Returns:
The value associated with the key, or default if not found
"""
return self._context_variables.get(key, default)

def set_context(self, key: str, value: Any) -> None:
"""
Set a context variable.
Args:
key: The key to set
value: The value to associate with the key
"""
self._context_variables[key] = value

def update_context(self, context_variables: Dict[str, Any]) -> None:
"""
Update multiple context variables at once.
Args:
context_variables: Dictionary of variables to update/add
"""
self._context_variables.update(context_variables)

def pop_context(self, key: str, default: Any = None) -> Any:
"""
Remove and return a context variable.
Args:
key: The key to remove
default: Value to return if key doesn't exist
Returns:
The value that was removed, or default if key not found
"""
return self._context_variables.pop(key, default)

@property
def system_message(self) -> str:
"""Return the system message."""
Expand Down
77 changes: 67 additions & 10 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,20 +1527,77 @@ def test_handle_carryover():
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


@pytest.mark.skipif(skip_openai, reason=reason)
def test_context_variables():
# Test initialization with context_variables
initial_context = {"test_key": "test_value", "number": 42, "nested": {"inner": "value"}}
agent = ConversableAgent(name="context_test_agent", llm_config=False, context_variables=initial_context)

# Check that context was properly initialized
assert agent._context_variables == initial_context

# Test initialization without context_variables
agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False)
assert agent_no_context._context_variables == {}

# Test get_context
assert agent.get_context("test_key") == "test_value"
assert agent.get_context("number") == 42
assert agent.get_context("nested") == {"inner": "value"}
assert agent.get_context("non_existent") is None
assert agent.get_context("non_existent", default="default") == "default"

# Test set_context
agent.set_context("new_key", "new_value")
assert agent.get_context("new_key") == "new_value"

# Test overwriting existing value
agent.set_context("test_key", "updated_value")
assert agent.get_context("test_key") == "updated_value"

# Test update_context
new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"}
agent.update_context(new_values)
assert agent.get_context("bulk_key1") == "bulk_value1"
assert agent.get_context("bulk_key2") == "bulk_value2"
assert agent.get_context("test_key") == "bulk_updated_value"

# Test pop_context
# Pop existing key
popped_value = agent.pop_context("bulk_key1")
assert popped_value == "bulk_value1"
assert agent.get_context("bulk_key1") is None

# Pop with default value
default_value = "default_value"
popped_default = agent.pop_context("non_existent", default=default_value)
assert popped_default == default_value

# Pop without default (should return None)
popped_none = agent.pop_context("another_non_existent")
assert popped_none is None

# Verify final state of context
expected_final_context = {
"number": 42,
"nested": {"inner": "value"},
"new_key": "new_value",
"bulk_key2": "bulk_value2",
"test_key": "bulk_updated_value",
}
assert agent._context_variables == expected_final_context


if __name__ == "__main__":
# test_trigger()
# test_context()
# test_max_consecutive_auto_reply()
# test_generate_code_execution_reply()
# test_conversable_agent()
# test_no_llm_config()
# test_handle_carryover():
# test_max_turn()
# test_process_before_send()
# test_message_func()

test_summary()
test_adding_duplicate_function_warning()
# test_summary()
# test_adding_duplicate_function_warning()
# test_function_registration_e2e_sync()

test_process_gemini_carryover()
test_process_carryover()
# test_process_gemini_carryover()
# test_process_carryover()
test_context_variables()

0 comments on commit a6840cf

Please sign in to comment.