Skip to content

Commit

Permalink
Merge branch 'main' into mcts_reason
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyCNM authored Dec 17, 2024
2 parents 14e225e + 2d996c0 commit bba250b
Show file tree
Hide file tree
Showing 61 changed files with 2,478 additions and 772 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-graph-rag-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
pip install pytest
- name: Install FalkorDB SDK when on linux
run: |
pip install -e .[graph_rag_falkor_db]
pip install -e .[graph-rag-falkor-db]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
Expand Down
2 changes: 1 addition & 1 deletion MAINTAINERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
| Eric Moore | [emooreatx](https://github.com/emooreatx) | IBM | all|
| Evan David | [evandavid1](https://github.com/evandavid1) | - | all |
| Tvrtko Sternak | [sternakt](https://github.com/sternakt) | airt.ai | structured output |

| Jiacheng Shang | [Eric-Shang](https://github.com/Eric-Shang) | Toast | RAG |

**Pending Maintainers list (Marked with \*, Waiting for explicit approval from the maintainers)**
| Name | GitHub Handle | Organization | Features |
Expand Down
2 changes: 2 additions & 0 deletions autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .contrib.swarm_agent import (
AFTER_WORK,
ON_CONDITION,
UPDATE_SYSTEM_MESSAGE,
AfterWorkOption,
SwarmAgent,
SwarmResult,
Expand Down Expand Up @@ -44,6 +45,7 @@
"ON_CONDITION",
"AFTER_WORK",
"AfterWorkOption",
"UPDATE_SYSTEM_MESSAGE",
"ReasoningAgent",
"visualize_tree",
"ThinkNode",
Expand Down
36 changes: 25 additions & 11 deletions autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
If None, FalkorDB will auto generate an ontology from the input docs.
"""
self.name = name
self.ontology_table_name = name + "_ontology"
self.host = host
self.port = port
self.username = username
Expand All @@ -65,7 +66,7 @@ def connect_db(self):
"""
if self.name in self.falkordb.list_graphs():
try:
self.ontology = self._load_ontology_from_db(self.name)
self.ontology = self._load_ontology_from_db()
except Exception:
warnings.warn("Graph Ontology is not loaded.")

Expand Down Expand Up @@ -103,6 +104,8 @@ def init_db(self, input_doc: List[Document]):
sources=sources,
model=self.model,
)
# Save Ontology to graph for future access.
self._save_ontology_to_db(self.ontology)

self.knowledge_graph = KnowledgeGraph(
name=self.name,
Expand All @@ -118,9 +121,6 @@ def init_db(self, input_doc: List[Document]):

# Establishing a chat session will maintain the history
self._chat_session = self.knowledge_graph.chat_session()

# Save Ontology to graph for future access.
self._save_ontology_to_db(self.name, self.ontology)
else:
raise ValueError("No input documents could be loaded.")

Expand Down Expand Up @@ -149,17 +149,31 @@ def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryR

return GraphStoreQueryResult(answer=response["response"], results=[])

def __get_ontology_storage_graph(self, graph_name: str) -> Graph:
ontology_table_name = graph_name + "_ontology"
return self.falkordb.select_graph(ontology_table_name)
def delete(self) -> bool:
"""
Delete graph and its data from database.
"""
all_graphs = self.falkordb.list_graphs()
if self.name in all_graphs:
self.falkordb.select_graph(self.name).delete()
if self.ontology_table_name in all_graphs:
self.falkordb.select_graph(self.ontology_table_name).delete()
return True

def __get_ontology_storage_graph(self) -> Graph:
return self.falkordb.select_graph(self.ontology_table_name)

def _save_ontology_to_db(self, graph_name: str, ontology: Ontology):
def _save_ontology_to_db(self, ontology: Ontology):
"""
Save graph ontology to a separate table with {graph_name}_ontology
"""
graph = self.__get_ontology_storage_graph(graph_name)
if self.ontology_table_name in self.falkordb.list_graphs():
raise ValueError("Knowledge graph {} is already created.".format(self.name))
graph = self.__get_ontology_storage_graph()
ontology.save_to_graph(graph)

def _load_ontology_from_db(self, graph_name: str) -> Ontology:
graph = self.__get_ontology_storage_graph(graph_name)
def _load_ontology_from_db(self) -> Ontology:
if self.ontology_table_name not in self.falkordb.list_graphs():
raise ValueError("Knowledge graph {} has not been created.".format(self.name))
graph = self.__get_ontology_storage_graph()
return Ontology.from_graph(graph)
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
database: str = "neo4j",
username: str = "neo4j",
password: str = "neo4j",
llm: LLM = OpenAI(model="gpt-3.5-turbo", temperature=0.0),
llm: LLM = OpenAI(model="gpt-4o", temperature=0.0),
embedding: BaseEmbedding = OpenAIEmbedding(model_name="text-embedding-3-small"),
entities: Optional[TypeAlias] = None,
relations: Optional[TypeAlias] = None,
Expand Down
16 changes: 13 additions & 3 deletions autogen/agentchat/contrib/llamaindex_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,22 @@
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.chat_engine.types import AgentChatResponse
from pydantic import BaseModel
from pydantic import __version__ as pydantic_version

# let's Avoid: AttributeError: type object 'Config' has no attribute 'copy'
# check for v1 like in autogen/_pydantic.py
is_pydantic_v1 = pydantic_version.startswith("1.")
if not is_pydantic_v1:
from pydantic import ConfigDict

Config = ConfigDict(arbitrary_types_allowed=True)
else:

class Config:
arbitrary_types_allowed = True

# Add Pydantic configuration to allow arbitrary types
# Added to mitigate PydanticSchemaGenerationError
class Config:
arbitrary_types_allowed = True

BaseModel.model_config = Config

except ImportError as e:
Expand Down
107 changes: 99 additions & 8 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import copy
import inspect
import json
import re
import warnings
from dataclasses import dataclass
from enum import Enum
from inspect import signature
Expand Down Expand Up @@ -57,6 +59,29 @@ def __post_init__(self):
assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"


@dataclass
class UPDATE_SYSTEM_MESSAGE:
update_function: Union[Callable, str]

def __post_init__(self):
if isinstance(self.update_function, str):
# find all {var} in the string
vars = re.findall(r"\{(\w+)\}", self.update_function)
if len(vars) == 0:
warnings.warn("Update function string contains no variables. This is probably unintended.")

elif isinstance(self.update_function, Callable):
sig = signature(self.update_function)
if len(sig.parameters) != 2:
raise ValueError(
"Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively"
)
if sig.return_annotation != str:
raise ValueError("Update function must return a string")
else:
raise ValueError("Update function must be either a string or a callable")


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
messages: Union[List[Dict[str, Any]], str],
Expand Down Expand Up @@ -107,12 +132,27 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
name="Tool_Execution",
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution(context_variables=context_variables)
tool_execution._set_to_tool_execution()

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

# Point all SwarmAgent's context variables to this function's context_variables
# providing a single (shared) context across all SwarmAgents in the swarm
for agent in agents + [tool_execution]:
agent._context_variables = context_variables

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Swarm transition function to determine the next agent in the conversation"""
"""Swarm transition function to determine and prepare the next agent in the conversation"""
next_agent = determine_next_agent(last_speaker, groupchat)

return next_agent

def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Determine the next agent in the conversation"""
nonlocal INIT_AGENT_USED
if not INIT_AGENT_USED:
INIT_AGENT_USED = True
Expand Down Expand Up @@ -310,6 +350,9 @@ def __init__(
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
code_execution_config=False,
update_agent_state_before_reply: Optional[
Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE]
] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -335,23 +378,70 @@ def __init__(

self.after_work = None

# Used only in the tool execution agent for context and transferring to the next agent
# Note: context variables are not stored for each agent
self._context_variables = {}
# Used in the tool execution agent to transfer to the next agent
self._next_agent = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
self.register_update_agent_state_before_reply(update_agent_state_before_reply)

def register_update_agent_state_before_reply(self, functions: Optional[Union[List[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
You can add your own validation or precondition functions here.
Args:
functions (List[Callable[[], None]]): A list of functions to be registered. Each function
is called when the agent is selected and before it speaks.
"""
if functions is None:
return
if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]:
raise ValueError("functions must be a list of callables")

if not isinstance(functions, list):
functions = [functions]

for func in functions:
if isinstance(func, UPDATE_SYSTEM_MESSAGE):

# Wrapper function that allows this to be used in the update_agent_state hook
# Its primary purpose, however, is just to update the agent's system message
# Outer function to create a closure with the update function
def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE):
def update_system_message_wrapper(
agent: ConversableAgent, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
if isinstance(update_func.update_function, str):
# Templates like "My context variable passport is {passport}" will
# use the context_variables for substitution
sys_message = OpenAIWrapper.instantiate(
template=update_func.update_function,
context=agent._context_variables,
allow_format_str_template=True,
)
else:
sys_message = update_func.update_function(agent, messages)

agent.update_system_message(sys_message)
return messages

return update_system_message_wrapper

self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func))

else:
self.register_hook(hookable_method="update_agent_state", hook=func)

def _set_to_tool_execution(self):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
It will execute the tool calls and update the context_variables and next_agent accordingly.
It will execute the tool calls and update the referenced context_variables and next_agent accordingly.
"""
self._next_agent = None
self._context_variables = context_variables or {}
self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply)

Expand Down Expand Up @@ -491,6 +581,7 @@ def generate_swarm_tool_reply(
return False, None

def add_single_function(self, func: Callable, name=None, description=""):
"""Add a single function to the agent, removing context variables for LLM use"""
if name:
func._name = name
else:
Expand Down
19 changes: 19 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
"update_agent_state": [],
}

def _validate_llm_config(self, llm_config):
Expand Down Expand Up @@ -2091,6 +2092,9 @@ def generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables.
self.update_agent_state_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_received_message(messages)
Expand Down Expand Up @@ -2161,6 +2165,9 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables.
self.update_agent_state_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages_before_reply(messages)
Expand Down Expand Up @@ -2847,6 +2854,18 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def update_agent_state_before_reply(self, messages: List[Dict]) -> None:
"""
Calls any registered capability hooks to update the agent's state.
Primarily used to update context variables.
Will, potentially, modify the messages.
"""
hook_list = self.hook_lists["update_agent_state"]

# Call each hook (in order of registration) to process the messages.
for hook in hook_list:
hook(self, messages)

def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
Expand Down
19 changes: 18 additions & 1 deletion autogen/oai/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,24 @@ def __init__(self, **kwargs: Any):
if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
# if haven't got any access_key or secret_key in environment variable or via arguments then
if (
self._aws_access_key is None
or self._aws_access_key == ""
or self._aws_secret_key is None
or self._aws_secret_key == ""
):

# attempts to get client from attached role of managed service (lambda, ec2, ecs, etc.)
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=bedrock_config)
else:
session = boto3.Session(
aws_access_key_id=self._aws_access_key,
aws_secret_access_key=self._aws_secret_key,
aws_session_token=self._aws_session_token,
profile_name=self._aws_profile_name,
)
self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

def message_retrieval(self, response):
"""Retrieve the messages from the response."""
Expand Down
Loading

0 comments on commit bba250b

Please sign in to comment.