diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index d06999db34c..a92044f15b7 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -88,7 +88,9 @@ jobs:
fi
- name: Test with pytest skipping openai tests
if: matrix.python-version != '3.10' && matrix.os == 'ubuntu-latest'
+ # Remove the pip pin below once https://github.com/docker/docker-py/issues/3256 is resolved
run: |
+ pip install "requests<2.32.0"
pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
- name: Test with pytest skipping openai and docker tests
if: matrix.python-version != '3.10' && matrix.os != 'ubuntu-latest'
diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml
index 1bf71115d6b..b1b3e35e478 100644
--- a/.github/workflows/contrib-openai.yml
+++ b/.github/workflows/contrib-openai.yml
@@ -74,7 +74,43 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
-
+ AgentEvalTest:
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.10"]
+ runs-on: ${{ matrix.os }}
+ environment: openai1
+ steps:
+ # checkout to pr branch
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies
+ run: |
+ docker --version
+ python -m pip install --upgrade pip wheel
+ pip install -e .
+ python -c "import autogen"
+ pip install "pytest-cov>=5" pytest-asyncio
+ - name: Coverage
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+ AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
+ OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
+ run: |
+ pytest test/agentchat/contrib/agent_eval/test_agent_eval.py --cov=autogen --cov-report=xml
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
CompressionTest:
strategy:
matrix:
diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index f8dd1d46186..38fab877402 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -107,7 +107,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr poppler-utils
- pip install unstructured[all-docs]==0.13.0
+ pip install --no-cache-dir unstructured[all-docs]==0.13.0
- name: Install packages and dependencies for RetrieveChat
run: |
pip install -e .[retrievechat]
@@ -125,6 +125,35 @@ jobs:
file: ./coverage.xml
flags: unittests
+ AgentEvalTest:
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.10"]
+ runs-on: ${{ matrix.os }}
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for AgentEval
+ run: |
+ pip install -e .
+ - name: Coverage
+ run: |
+ pytest test/agentchat/contrib/agent_eval/ --skip-openai --cov=autogen --cov-report=xml
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
CompressionTest:
runs-on: ${{ matrix.os }}
strategy:
diff --git a/README.md b/README.md
index fabbff99b63..e78d4b91aad 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@
[![Discord](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://aka.ms/autogen-dc)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40pyautogen)](https://twitter.com/pyautogen)
+[![NuGet version](https://badge.fury.io/nu/AutoGen.Core.svg)](https://badge.fury.io/nu/AutoGen.Core)
# AutoGen
[📚 Cite paper](#related-papers).
@@ -14,13 +15,19 @@
-->
+:fire: May 13, 2024: [The Economist](https://www.economist.com/science-and-technology/2024/05/13/todays-ai-models-are-impressive-teams-of-them-will-be-formidable) published an article about multi-agent systems (MAS) following a January 2024 interview with [Chi Wang](https://github.com/sonichi).
+
+:fire: May 11, 2024: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://openreview.net/pdf?id=uAjxFFing2) received the best paper award at the [ICLR 2024 LLM Agents Workshop](https://llmagents.github.io/).
+
+:fire: Apr 26, 2024: [AutoGen.NET](https://microsoft.github.io/autogen-for-net/) is available for .NET developers!
+
:fire: Apr 17, 2024: Andrew Ng cited AutoGen in [The Batch newsletter](https://www.deeplearning.ai/the-batch/issue-245/) and [What's next for AI agentic workflows](https://youtu.be/sal78ACtGTc?si=JduUzN_1kDnMq0vF) at Sequoia Capital's AI Ascent (Mar 26).
:fire: Mar 3, 2024: What's new in AutoGen? 📰[Blog](https://microsoft.github.io/autogen/blog/2024/03/03/AutoGen-Update); 📺[Youtube](https://www.youtube.com/watch?v=j_mtwQiaLGU).
:fire: Mar 1, 2024: the first AutoGen multi-agent experiment on the challenging [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) benchmark achieved the No. 1 accuracy in all the three levels.
-:tada: Jan 30, 2024: AutoGen is highlighted by Peter Lee in Microsoft Research Forum [Keynote](https://t.co/nUBSjPDjqD).
+
:tada: Dec 31, 2023: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023).
@@ -28,13 +35,13 @@
-:tada: Nov 8, 2023: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff.
+:tada: Nov 8, 2023: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff from [FLAML](https://github.com/microsoft/FLAML).
-:tada: Nov 6, 2023: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U).
+
-:tada: Nov 1, 2023: AutoGen is the top trending repo on GitHub in October 2023.
+
-:tada: Oct 03, 2023: AutoGen spins off from FLAML on GitHub and has a major paper update (first version on Aug 16).
+
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
index b527f8e0bae..dd489c03625 100644
--- a/autogen/agentchat/chat.py
+++ b/autogen/agentchat/chat.py
@@ -195,7 +195,9 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover
]
- __post_carryover_processing(chat_info)
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
@@ -236,7 +238,10 @@ async def _dependent_chat_future(
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
- __post_carryover_processing(chat_info)
+
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
new file mode 100644
index 00000000000..6588a1ec611
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -0,0 +1,7 @@
+Agents for running the AgentEval pipeline.
+
+AgentEval is a process for evaluating an LLM-based system's performance on a given task.
+
+When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for evaluating a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
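+
+A minimal sketch of the pipeline (the `config_list` setup and the `Task` field values below are illustrative placeholders):
+
+```python
+from autogen.agentchat.contrib.agent_eval.agent_eval import generate_criteria, quantify_criteria
+from autogen.agentchat.contrib.agent_eval.task import Task
+
+llm_config = {"config_list": config_list}  # assumes a configured config_list
+
+task = Task(
+    name="Math problem solving",
+    description="Given any question, solve the problem as accurately as possible",
+    successful_response="...",  # an example of a successful run
+    failed_response="...",  # an example of a failed run
+)
+
+criteria = generate_criteria(llm_config=llm_config, task=task)
+results = quantify_criteria(
+    llm_config=llm_config,
+    criteria=criteria,
+    task=task,
+    test_case="...",  # the solution to assess
+    ground_truth="...",
+)
+```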
+
+For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
diff --git a/autogen/agentchat/contrib/agent_eval/agent_eval.py b/autogen/agentchat/contrib/agent_eval/agent_eval.py
new file mode 100644
index 00000000000..b48c65a66d2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/agent_eval.py
@@ -0,0 +1,101 @@
+from typing import Dict, List, Literal, Optional, Union
+
+import autogen
+from autogen.agentchat.contrib.agent_eval.criterion import Criterion
+from autogen.agentchat.contrib.agent_eval.critic_agent import CriticAgent
+from autogen.agentchat.contrib.agent_eval.quantifier_agent import QuantifierAgent
+from autogen.agentchat.contrib.agent_eval.subcritic_agent import SubCriticAgent
+from autogen.agentchat.contrib.agent_eval.task import Task
+
+
+def generate_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ task: Task = None,
+ additional_instructions: str = "",
+ max_round: int = 2,
+ use_subcritic: bool = False,
+):
+ """
+ Creates a list of criteria for evaluating the utility of a given task.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ task (Task): The task to evaluate.
+ additional_instructions (str): Additional instructions for the criteria agent.
+ max_round (int): The maximum number of rounds to run the conversation.
+ use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
+ Returns:
+ list: A list of Criterion objects for evaluating the utility of the given task.
+ """
+ critic = CriticAgent(
+ system_message=CriticAgent.DEFAULT_SYSTEM_MESSAGE + "\n" + additional_instructions,
+ llm_config=llm_config,
+ )
+
+ critic_user = autogen.UserProxyAgent(
+ name="critic_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ agents = [critic_user, critic]
+
+ if use_subcritic:
+ subcritic = SubCriticAgent(
+ llm_config=llm_config,
+ )
+ agents.append(subcritic)
+
+ groupchat = autogen.GroupChat(
+ agents=agents, messages=[], max_round=max_round, speaker_selection_method="round_robin"
+ )
+ critic_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+ critic_user.initiate_chat(critic_manager, message=task.get_sys_message())
+ criteria = critic_user.last_message()
+ content = criteria["content"]
+ # need to strip out any extra code around the returned json
+ content = content[content.find("[") : content.rfind("]") + 1]
+ criteria = Criterion.parse_json_str(content)
+ return criteria
+
+
+def quantify_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ criteria: List[Criterion] = None,
+ task: Task = None,
+ test_case: str = "",
+ ground_truth: str = "",
+):
+ """
+ Quantifies the performance of a system using the provided criteria.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ criteria ([Criterion]): A list of criteria for evaluating the utility of a given task.
+ task (Task): The task to evaluate.
+ test_case (str): The test case to evaluate.
+ ground_truth (str): The ground truth for the test case.
+ Returns:
+ dict: A dictionary where the keys are the criteria and the values are the assessed performance based on the accepted values for each criterion.
+ """
+ quantifier = QuantifierAgent(
+ llm_config=llm_config,
+ )
+
+ quantifier_user = autogen.UserProxyAgent(
+ name="quantifier_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ quantifier_user.initiate_chat( # noqa: F841
+ quantifier,
+ message=task.get_sys_message()
+ + "Evaluation dictionary: "
+ + Criterion.write_json(criteria)
+ + "actual test case to evaluate: "
+ + test_case,
+ )
+ quantified_results = quantifier_user.last_message()
+ return {"actual_success": ground_truth, "estimated_performance": quantified_results["content"]}
diff --git a/autogen/agentchat/contrib/agent_eval/criterion.py b/autogen/agentchat/contrib/agent_eval/criterion.py
new file mode 100644
index 00000000000..5efd121ec07
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/criterion.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+import json
+from typing import List
+
+from pydantic import BaseModel
+
+
+class Criterion(BaseModel):
+ """
+ A class that represents a criterion for agent evaluation.
+ """
+
+ name: str
+ description: str
+ accepted_values: List[str]
+ sub_criteria: List[Criterion] = list()
+
+ @staticmethod
+ def parse_json_str(criteria: str):
+ """
+ Create a list of Criterion objects from a json string.
+ Args:
+ criteria (str): Json string that represents the criteria.
+ Returns:
+ [Criterion]: A list of Criterion objects that represent the json criteria information.
+ """
+ return [Criterion(**crit) for crit in json.loads(criteria)]
+
+ @staticmethod
+ def write_json(criteria):
+ """
+ Create a json string from a list of Criterion objects.
+ Args:
+ criteria ([Criterion]): A list of Criterion objects.
+ Returns:
+ str: A json string that represents the list of Criterion objects.
+ """
+ return json.dumps([crit.model_dump() for crit in criteria], indent=2)
diff --git a/autogen/agentchat/contrib/agent_eval/critic_agent.py b/autogen/agentchat/contrib/agent_eval/critic_agent.py
new file mode 100644
index 00000000000..2f5e5598ba6
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/critic_agent.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class CriticAgent(ConversableAgent):
+ """
+ An agent for creating a list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant. You suggest criteria for evaluating different tasks. They should be distinguishable, quantifiable and not redundant.
+ Convert the evaluation criteria into a list where each item is a criterion represented by a dictionary as follows:
+ {"name": name of the criterion, "description": criterion description, "accepted_values": possible accepted inputs for this key}
+ Make sure "accepted_values" includes acceptable inputs for each key that are fine-grained and preferably multi-graded levels, and "description" includes the criterion description.
+ Output just the criteria string you have created, no code.
+ """
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating a list of criteria for evaluating the utility of a given task."
+
+ def __init__(
+ self,
+ name="critic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/quantifier_agent.py b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
new file mode 100644
index 00000000000..02a8f650fab
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class QuantifierAgent(ConversableAgent):
+ """
+ An agent for quantifying the performance of a system using the provided criteria.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """"You are a helpful assistant. You quantify the output of different tasks based on the given criteria.
+ The criterion is given in a json list format where each element is a distinct criteria.
+ The each element is a dictionary as follows {"name": name of the criterion, "description": criteria description , "accepted_values": possible accepted inputs for this key}
+ You are going to quantify each of the crieria for a given task based on the task description.
+ Return a dictionary where the keys are the criteria and the values are the assessed performance based on accepted values for each criteria.
+ Return only the dictionary, no code."""
+
+ DEFAULT_DESCRIPTION = "An AI agent for quantifying the performance of a system using the provided criteria."
+
+ def __init__(
+ self,
+ name="quantifier",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(name=name, system_message=system_message, description=description, **kwargs)
diff --git a/autogen/agentchat/contrib/agent_eval/subcritic_agent.py b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
new file mode 100755
index 00000000000..fa994ee7bda
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class SubCriticAgent(ConversableAgent):
+ """
+ An agent for creating subcriteria from a given list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant to the critic agent. You suggest sub criteria for evaluating different tasks based on the criteria provided by the critic agent (if you feel it is needed).
+ They should be distinguishable, quantifiable, and related to the overall theme of the critic's provided criteria.
+ You operate by taking in the description of the criteria. You then create a new key called sub_criteria where you provide the sub criteria for the given criteria.
+ The value of sub_criteria is a dictionary where the keys are the sub criteria and each value is as follows: {"description": sub criteria description, "accepted_values": possible accepted inputs for this key}
+ Do this for each criteria provided by the critic (removing the criteria's accepted values). "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels. "description" includes the criterion description.
+ Once you have created the sub criteria for the given criteria, you return the json (make sure to include the contents of the critic's dictionary in the final dictionary as well).
+ Make sure to return a valid json and no code"""
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating subcriteria from a given list of criteria."
+
+ def __init__(
+ self,
+ name="subcritic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/task.py b/autogen/agentchat/contrib/agent_eval/task.py
new file mode 100644
index 00000000000..9f96fbf79e2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/task.py
@@ -0,0 +1,37 @@
+import json
+
+from pydantic import BaseModel
+
+
+class Task(BaseModel):
+ """
+ Class representing a task for agent completion; includes example agent executions used for criteria generation.
+ """
+
+ name: str
+ description: str
+ successful_response: str
+ failed_response: str
+
+ def get_sys_message(self):
+ return f"""Task: {self.name}.
+ Task description: {self.description}
+ Task successful example: {self.successful_response}
+ Task failed example: {self.failed_response}
+ """
+
+ @staticmethod
+ def parse_json_str(task: str):
+ """
+ Create a Task object from a json string.
+ Args:
+ task (str): A json string that represents the task.
+ Returns:
+ Task: A Task object that represents the json task information.
+ """
+ json_data = json.loads(task)
+ return Task(
+     name=json_data.get("name"),
+     description=json_data.get("description"),
+     successful_response=json_data.get("successful_response"),
+     failed_response=json_data.get("failed_response"),
+ )
diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
index 173811842eb..44b10259f1b 100644
--- a/autogen/agentchat/contrib/capabilities/context_handling.py
+++ b/autogen/agentchat/contrib/capabilities/context_handling.py
@@ -8,8 +8,8 @@
from autogen import ConversableAgent, token_count_utils
warn(
- "Context handling with TransformChatHistory is deprecated. "
- "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
+ "Context handling with TransformChatHistory is deprecated and will be removed in `0.2.30`. "
+ "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
DeprecationWarning,
stacklevel=2,
)
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index 8303843e881..bc56efd74d2 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -8,6 +8,7 @@
from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
+from autogen.oai.openai_utils import filter_config
from .text_compressors import LLMLingua, TextCompressor
@@ -130,6 +131,8 @@ def __init__(
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
):
"""
Args:
@@ -140,11 +143,17 @@ def __init__(
min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to truncate.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from token truncation. If False, messages that match the filter will be truncated.
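+ For example (an illustrative filter), `filter_dict={"role": ["system"]}` with the default `exclude_filter=True` keeps system messages out of token truncation.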
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -169,10 +178,15 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
for msg in reversed(temp_messages):
# Some messages may not have content.
- if not isinstance(msg.get("content"), (str, list)):
+ if not _is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue
+ if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
+ processed_messages.insert(0, msg)
+ processed_messages_tokens += _count_tokens(msg["content"])
+ continue
+
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
# If adding this message would exceed the token limit, truncate the last message to meet the total token
@@ -282,6 +296,8 @@ def __init__(
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
):
"""
Args:
@@ -293,6 +309,10 @@ def __init__(
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from compression. If False, messages that match the filter will be compressed.
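+ For example (an illustrative filter), `filter_dict={"role": ["system"]}` with the default `exclude_filter=True` keeps system messages out of compression.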
"""
if text_compressor is None:
@@ -303,6 +323,8 @@ def __init__(
self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
self._cache = cache
# Optimizing savings calculations to optimize log generation
@@ -334,7 +356,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
- if not isinstance(message.get("content"), (str, list)):
+ if not _is_content_right_type(message.get("content")):
+ continue
+
+ if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
continue
if _is_content_text_empty(message["content"]):
@@ -397,7 +422,7 @@ def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
- value = (tokens_saved, json.dumps(compressed_content))
+ value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
@@ -427,6 +452,10 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
return token_count
+def _is_content_right_type(content: Any) -> bool:
+ return isinstance(content, (str, list))
+
+
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
@@ -434,3 +463,10 @@ def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
+
+
+def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+ if not filter_dict:
+ return True
+
+ return len(filter_config([message], filter_dict, exclude)) > 0
diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py
index 9c4e78af852..cbedb17ceed 100644
--- a/autogen/agentchat/contrib/compressible_agent.py
+++ b/autogen/agentchat/contrib/compressible_agent.py
@@ -13,8 +13,8 @@
logger = logging.getLogger(__name__)
warn(
- "Context handling with CompressibleAgent is deprecated. "
- "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/reference/agentchat/contrib/capabilities/transform_messages",
+ "Context handling with CompressibleAgent is deprecated and will be removed in `0.2.30`. "
+ "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
DeprecationWarning,
stacklevel=2,
)
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index 0f5de8adcb5..40a28bfbcfa 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -11,6 +11,7 @@
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
+from autogen.runtime_logging import log_new_agent, logging_enabled
logger = logging.getLogger(__name__)
@@ -65,6 +66,8 @@ def __init__(
super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)
+ if logging_enabled():
+ log_new_agent(self, locals())
# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
@@ -169,10 +172,11 @@ def __init__(
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
logger.warning("overwrite_tools is False. Using existing tools from assistant API.")
+ self.update_system_message(self._openai_assistant.instructions)
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
- self.register_reply(Agent, GPTAssistantAgent._invoke_assistant, position=2)
+ self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=2)
def _invoke_assistant(
self,
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index bfd38a54d60..c3394a96bb6 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -937,6 +937,7 @@ def my_summary_method(
One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
on the conversation and extract a summary when summary_method is "reflection_with_llm".
The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+ Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
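+ For example, `summary_args={"summary_role": "user"}` sends the reflection prompt to the summarizing agent as a user message instead of a system message.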
message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
- If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
@@ -1168,8 +1169,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
raise ValueError("The summary_prompt must be a string.")
msg_list = recipient.chat_messages_for_summary(sender)
agent = sender if recipient is None else recipient
+ role = summary_args.get("summary_role", None)
+ if role and not isinstance(role, str):
+ raise ValueError("The summary_role in summary_arg must be a string.")
try:
- summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
+ summary = sender._reflection_with_llm(
+ prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+ )
except BadRequestError as e:
warnings.warn(
f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
@@ -1178,7 +1184,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
return summary
def _reflection_with_llm(
- self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
+ self,
+ prompt,
+ messages,
+ llm_agent: Optional[Agent] = None,
+ cache: Optional[AbstractCache] = None,
+ role: Union[str, None] = None,
) -> str:
"""Get a chat summary using reflection with an llm client based on the conversation history.
@@ -1187,10 +1198,14 @@ def _reflection_with_llm(
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (AbstractCache or None): the cache client to be used for this conversation.
+ role (str): the role of the message, usually "system" or "user". Default is "system".
"""
+ if not role:
+ role = "system"
+
system_msg = [
{
- "role": "system",
+ "role": role,
"content": prompt,
}
]
@@ -2391,6 +2406,8 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]):
self._assert_valid_name(name)
if func is None and name not in self._function_map.keys():
warnings.warn(f"The function {name} to remove doesn't exist", name)
+ if name in self._function_map:
+ warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
self._function_map.update(function_map)
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
@@ -2427,6 +2444,9 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
+ if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]):
+ warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning)
+
self.llm_config["functions"] = [
func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]
] + [func_sig]
@@ -2466,7 +2486,9 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
)
self._assert_valid_name(tool_sig["function"]["name"])
- if "tools" in self.llm_config.keys():
+ if "tools" in self.llm_config:
+ if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]):
+ warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning)
self.llm_config["tools"] = [
tool
for tool in self.llm_config["tools"]
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 86492455080..83c426272a2 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -1,3 +1,5 @@
+import copy
+import json
import logging
import random
import re
@@ -12,6 +14,7 @@
from ..io.base import IOStream
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
+from .chat import ChatResult
from .conversable_agent import ConversableAgent
logger = logging.getLogger(__name__)
@@ -36,6 +39,7 @@ class GroupChat:
Then select the next role from {agentlist} to play. Only return the role."
- select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+ To skip this prompt entirely, set this to None. If set to None, ensure your instructions for selecting a speaker are in the select_speaker_message_template string.
- select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
@@ -98,15 +102,15 @@ def custom_speaker_selection_func(
agents: List[Agent]
messages: List[Dict]
- max_round: Optional[int] = 10
- admin_name: Optional[str] = "Admin"
- func_call_filter: Optional[bool] = True
+ max_round: int = 10
+ admin_name: str = "Admin"
+ func_call_filter: bool = True
speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
- max_retries_for_selecting_speaker: Optional[int] = 2
+ max_retries_for_selecting_speaker: int = 2
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
- enable_clear_history: Optional[bool] = False
+ enable_clear_history: bool = False
send_introductions: bool = False
select_speaker_message_template: str = """You are in a role play game. The following roles are available:
{roles}.
@@ -222,8 +226,8 @@ def __post_init__(self):
if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
raise ValueError("select_speaker_message_template cannot be empty or None.")
- if self.select_speaker_prompt_template is None or len(self.select_speaker_prompt_template) == 0:
- raise ValueError("select_speaker_prompt_template cannot be empty or None.")
+ if self.select_speaker_prompt_template is not None and len(self.select_speaker_prompt_template) == 0:
+ self.select_speaker_prompt_template = None
if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
@@ -327,7 +331,13 @@ def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
return return_msg
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
- """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
+ """Return the floating system prompt selecting the next speaker.
+ This is always the *last* message in the context.
+ Will return None if the select_speaker_prompt_template is None."""
+
+ if self.select_speaker_prompt_template is None:
+ return None
+
if agents is None:
agents = self.agents
@@ -621,23 +631,35 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
remove_other_reply_funcs=True,
)
+ # If there is no speaker prompt (select_speaker_prompt_template is None), the last message is fed in to start the nested chat
+
# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
- chat_messages={checking_agent: messages},
+ chat_messages=(
+ {checking_agent: messages}
+ if self.select_speaker_prompt_template is not None
+ else {checking_agent: messages[:-1]}
+ ),
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
)
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
# Run the speaker selection chat
result = checking_agent.initiate_chat(
speaker_selection_agent,
cache=None, # don't use caching for the speaker selection chat
- message={
- "content": self.select_speaker_prompt(agents),
- "override_role": self.role_for_select_speaker_messages,
- },
+ message=start_message,
max_turns=2
* max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
clear_history=False,
@@ -708,6 +730,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
remove_other_reply_funcs=True,
)
+ # If there is no speaker prompt (select_speaker_prompt_template is None), the last message is fed in to start the nested chat
+
# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
@@ -717,11 +741,20 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
)
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
# Run the speaker selection chat
result = await checking_agent.a_initiate_chat(
speaker_selection_agent,
cache=None, # don't use caching for the speaker selection chat
- message=self.select_speaker_prompt(agents),
+ message=start_message,
max_turns=2
* max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
clear_history=False,
@@ -914,6 +947,7 @@ def __init__(
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
human_input_mode: Optional[str] = "NEVER",
system_message: Optional[Union[str, List]] = "Group chat manager.",
+ silent: bool = False,
**kwargs,
):
if (
@@ -937,6 +971,8 @@ def __init__(
# Store groupchat
self._groupchat = groupchat
+ self._silent = silent
+
# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
@@ -989,6 +1025,7 @@ def run_chat(
speaker = sender
groupchat = config
send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
if send_introductions:
# Broadcast the intro
@@ -1043,7 +1080,7 @@ def run_chat(
reply["content"] = self.clear_agents_history(reply, groupchat)
# The speaker sends the message without requesting a reply
- speaker.send(reply, self, request_reply=False)
+ speaker.send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
@@ -1064,6 +1101,7 @@ async def a_run_chat(
speaker = sender
groupchat = config
send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
if send_introductions:
# Broadcast the intro
@@ -1108,7 +1146,7 @@ async def a_run_chat(
if reply is None:
break
# The speaker sends the message without requesting a reply
- await speaker.a_send(reply, self, request_reply=False)
+ await speaker.a_send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
@@ -1116,6 +1154,290 @@ async def a_run_chat(
a.previous_cache = None
return True, None
+ def resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Optional[str] = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
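+
+ Example (a minimal sketch; assumes the agents, group chat, and manager have been re-created as per the original chat, and that `previous_state` is a Json string previously saved with `messages_to_string`):
+
+ ```python
+ last_agent, last_message = manager.resume(messages=previous_state)
+ # Continue the conversation from where it left off
+ result = last_agent.initiate_chat(recipient=manager, message=last_message, clear_history=False)
+ ```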
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication that ChatResult messages were used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ async def a_resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Optional[str] = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication that ChatResult messages were used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ await self.a_send(
+ message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
+ )
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ def _valid_resume_messages(self, messages: List[Dict]):
+ """Validates the messages used for resuming
+
+ args:
+ messages (List[Dict]): list of messages to resume with
+
+ returns:
+ - None. Raises an exception if the messages are not valid for resuming.
+ """
+ # Must have messages to start with, otherwise they should run run_chat
+ if not messages:
+ raise Exception(
+ "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat."
+ )
+
+ # Check that all agents in the chat messages exist in the group chat
+ for message in messages:
+ if message.get("name"):
+ if (
+ not self._groupchat.agent_by_name(message["name"])
+ and not message["name"] == self._groupchat.admin_name # ignore group chat's name
+ and not message["name"] == self.name # ignore group chat manager's name
+ ):
+ raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")
+
+ def _process_resume_termination(self, remove_termination_string: str, messages: List[Dict]):
+ """Removes termination string, if required, and checks if termination may occur.
+
+ args:
+ remove_termination_string (str): termination string to remove from the last message
+ messages (List[Dict]): list of messages to check and process for termination
+
+ returns:
+ None
+ """
+
+ last_message = messages[-1]
+
+ # Replace any given termination string in the last message
+ if remove_termination_string:
+ if messages[-1].get("content") and remove_termination_string in messages[-1]["content"]:
+ messages[-1]["content"] = messages[-1]["content"].replace(remove_termination_string, "")
+
+ # Check if the last message meets termination (if it has one)
+ if self._is_termination_msg:
+ if self._is_termination_msg(last_message):
+ logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.")
+
+ def messages_from_string(self, message_string: str) -> List[Dict]:
+ """Reads the saved state of messages in Json format for resume and returns as a messages list
+
+ args:
+ - message_string: Json string, the saved state
+
+ returns:
+ - List[Dict]: List of messages
+ """
+ try:
+ state = json.loads(message_string)
+ except json.JSONDecodeError:
+ raise Exception("Messages string is not a valid JSON string")
+
+ return state
+
+ def messages_to_string(self, messages: List[Dict]) -> str:
+ """Converts the provided messages into a Json string that can be used for resuming the chat.
+ The state is made up of a list of messages
+
+ args:
+ - messages (List[Dict]): set of messages to convert to a string
+
+ returns:
+ - str: Json representation of the messages which can be persisted for resuming later
+ """
+
+ return json.dumps(messages)
+
def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
diff --git a/autogen/code_utils.py b/autogen/code_utils.py
index e1bc951f099..98ed6067066 100644
--- a/autogen/code_utils.py
+++ b/autogen/code_utils.py
@@ -6,8 +6,10 @@
import subprocess
import sys
import time
+import venv
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
+from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import docker
@@ -719,3 +721,19 @@ def implement(
# cost += metrics["gen_cost"]
# if metrics["succeed_assertions"] or i == len(configs) - 1:
# return responses[metrics["index_selected"]], cost, i
+
+
+def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace:
+ """Creates a python virtual environment and returns the context.
+
+ Args:
+ dir_path (str): Directory path where the env will be created.
+ **env_args: Any extra args to pass to the `EnvBuilder`
+
+ Returns:
+ SimpleNamespace: the virtual env context object."""
+ if not env_args:
+ env_args = {"with_pip": True}
+ env_builder = venv.EnvBuilder(**env_args)
+ env_builder.create(dir_path)
+ return env_builder.ensure_directories(dir_path)
diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py
index ed92cd527be..29172bbe922 100644
--- a/autogen/coding/local_commandline_code_executor.py
+++ b/autogen/coding/local_commandline_code_executor.py
@@ -1,4 +1,5 @@
import logging
+import os
import re
import subprocess
import sys
@@ -6,6 +7,7 @@
from hashlib import md5
from pathlib import Path
from string import Template
+from types import SimpleNamespace
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
from typing_extensions import ParamSpec
@@ -64,6 +66,7 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
def __init__(
self,
timeout: int = 60,
+ virtual_env_context: Optional[SimpleNamespace] = None,
work_dir: Union[Path, str] = Path("."),
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
functions_module: str = "functions",
@@ -82,8 +85,22 @@ def __init__(
PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript.
Execution policies determine whether each language's code blocks are executed or saved only.
+ ## Execution with a Python virtual environment
+ A Python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the
+ base environment with unwanted modules.
+ ```python
+ from autogen.code_utils import create_virtual_env
+ from autogen.coding import LocalCommandLineCodeExecutor
+
+ venv_dir = ".venv"
+ venv_context = create_virtual_env(venv_dir)
+
+ executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context)
+ ```
+
Args:
timeout (int): The timeout for code execution, default is 60 seconds.
+ virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use.
work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor.
functions_module (str): The module name under which functions are accessible.
@@ -105,6 +122,7 @@ def __init__(
self._timeout = timeout
self._work_dir: Path = work_dir
+ self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
@@ -196,7 +214,11 @@ def _setup_functions(self) -> None:
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
- cmd = [sys.executable, "-m", "pip", "install"] + required_packages
+ if self._virtual_env_context:
+ py_executable = self._virtual_env_context.env_exe
+ else:
+ py_executable = sys.executable
+ cmd = [py_executable, "-m", "pip", "install"] + required_packages
try:
result = subprocess.run(
cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
@@ -269,9 +291,18 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
program = _cmd(lang)
cmd = [program, str(written_file.absolute())]
+ env = os.environ.copy()
+
+ if self._virtual_env_context:
+ path_with_virtualenv = rf"{self._virtual_env_context.bin_path}{os.pathsep}{env['PATH']}"
+ env["PATH"] = path_with_virtualenv
+ if WIN32:
+ activation_script = os.path.join(self._virtual_env_context.bin_path, "activate.bat")
+ cmd = [activation_script, "&&", *cmd]
+
try:
result = subprocess.run(
- cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
+ cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout), env=env
)
except subprocess.TimeoutExpired:
logs_all += "\n" + TIMEOUT_MSG
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
new file mode 100644
index 00000000000..466ed62c849
--- /dev/null
+++ b/autogen/logger/file_logger.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+import threading
+import uuid
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
+
+from autogen.logger.base_logger import BaseLogger
+from autogen.logger.logger_utils import get_current_ts, to_dict
+
+from .base_logger import LLMConfig
+
+if TYPE_CHECKING:
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.gemini import GeminiClient
+
+logger = logging.getLogger(__name__)
+
+
+class FileLogger(BaseLogger):
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.session_id = str(uuid.uuid4())
+
+ curr_dir = os.getcwd()
+ self.log_dir = os.path.join(curr_dir, "autogen_logs")
+ os.makedirs(self.log_dir, exist_ok=True)
+
+ self.log_file = os.path.join(self.log_dir, self.config.get("filename", "runtime.log"))
+ try:
+ with open(self.log_file, "a"):
+ pass
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+
+ self.logger = logging.getLogger(__name__)
+ self.logger.setLevel(logging.INFO)
+ file_handler = logging.FileHandler(self.log_file)
+ self.logger.addHandler(file_handler)
+
+ def start(self) -> str:
+ """Start the logger and return the session_id."""
+ try:
+ self.logger.info(f"Started new session with Session ID: {self.session_id}")
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+ finally:
+ return self.session_id
+
+ def log_chat_completion(
+ self,
+ invocation_id: uuid.UUID,
+ client_id: int,
+ wrapper_id: int,
+ request: Dict[str, Union[float, str, List[Dict[str, str]]]],
+ response: Union[str, ChatCompletion],
+ is_cached: int,
+ cost: float,
+ start_time: str,
+ ) -> None:
+ """
+ Log a chat completion.
+ """
+ thread_id = threading.get_ident()
+ try:
+ log_data = json.dumps(
+ {
+ "invocation_id": str(invocation_id),
+ "client_id": client_id,
+ "wrapper_id": wrapper_id,
+ "request": to_dict(request),
+ "response": str(response),
+ "is_cached": is_cached,
+ "cost": cost,
+ "start_time": start_time,
+ "end_time": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log chat completion: {e}")
+
+ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) -> None:
+ """
+ Log a new agent instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "id": id(agent),
+ "agent_name": agent.name if hasattr(agent, "name") and agent.name is not None else "",
+ "wrapper_id": to_dict(
+ agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else ""
+ ),
+ "session_id": self.session_id,
+ "current_time": get_current_ts(),
+ "agent_type": type(agent).__name__,
+ "args": to_dict(init_args),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log new agent: {e}")
+
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ """
+ Log an event from an agent or a string source.
+ """
+ from autogen import Agent
+
+ # If a value in kwargs cannot be serialized, json.dumps falls back to the lambda and
+ # emits a placeholder string carrying the type's qualified name (__qualname__) instead of raising an error.
+ json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
+ thread_id = threading.get_ident()
+
+ if isinstance(source, Agent):
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "agent_module": source.__module__,
+ "agent_class": source.__class__.__name__,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+ else:
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_wrapper(
+ self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {}
+ ) -> None:
+ """
+ Log a new wrapper instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_client(
+ self, client: AzureOpenAI | OpenAI | GeminiClient, wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+ ) -> None:
+ """
+ Log a new client instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "client_id": id(client),
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "class": type(client).__name__,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def get_connection(self) -> None:
+ """Method is intentionally left blank because there is no specific connection needed for the FileLogger."""
+ pass
+
+ def stop(self) -> None:
+ """Close the file handler and remove it from the logger."""
+ for handler in self.logger.handlers:
+ if isinstance(handler, logging.FileHandler):
+ handler.close()
+ self.logger.removeHandler(handler)
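
For reference, a minimal sketch of driving the new logger directly (the `filename` key is the one read by `FileLogger.__init__` above; records are written as JSON lines under ./autogen_logs):

```python
from autogen.logger.file_logger import FileLogger

file_logger = FileLogger(config={"filename": "runtime.log"})
session_id = file_logger.start()  # writes the session-start record and returns the session UUID

# ... the runtime calls log_chat_completion / log_new_agent / log_event here ...

file_logger.stop()  # close and detach the file handler
```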
diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py
index 8073c0c07d3..ed9567977bb 100644
--- a/autogen/logger/logger_factory.py
+++ b/autogen/logger/logger_factory.py
@@ -1,6 +1,7 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional
from autogen.logger.base_logger import BaseLogger
+from autogen.logger.file_logger import FileLogger
from autogen.logger.sqlite_logger import SqliteLogger
__all__ = ("LoggerFactory",)
@@ -8,11 +9,15 @@
class LoggerFactory:
@staticmethod
- def get_logger(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> BaseLogger:
+ def get_logger(
+ logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[Dict[str, Any]] = None
+ ) -> BaseLogger:
if config is None:
config = {}
if logger_type == "sqlite":
return SqliteLogger(config)
+ elif logger_type == "file":
+ return FileLogger(config)
else:
raise ValueError(f"[logger_factory] Unknown logger type: {logger_type}")
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 6e95a571cd0..42db83d849d 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -18,6 +18,7 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.gemini import GeminiClient
logger = logging.getLogger(__name__)
lock = threading.Lock()
@@ -316,7 +317,7 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM
self._run_query(query=query, args=args)
def log_new_client(
- self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+ self, client: Union[AzureOpenAI, OpenAI, GeminiClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
) -> None:
if self.con is None:
return
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 3edfa40d4ec..982d1c0d57f 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -435,7 +435,8 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` to use Google OpenAI API.")
- self._clients.append(GeminiClient(**openai_config))
+ client = GeminiClient(**openai_config)
+ self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index fcf7e09c025..5c06a4def0c 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -5,8 +5,18 @@
llm_config={
"config_list": [{
"api_type": "google",
- "model": "models/gemini-pro",
- "api_key": os.environ.get("GOOGLE_API_KEY")
+ "model": "gemini-pro",
+ "api_key": os.environ.get("GOOGLE_API_KEY"),
+ "safety_settings": [
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
+ ],
+ "top_p":0.5,
+ "max_tokens": 2048,
+ "temperature": 1.0,
+ "top_k": 5
}
]}
@@ -47,6 +57,17 @@ class GeminiClient:
of AutoGen.
"""
+ # Mapping, where the key is a term used by AutoGen and the value is the corresponding term used by Gemini
+ PARAMS_MAPPING = {
+ "max_tokens": "max_output_tokens",
+ # "n": "candidate_count", # Gemini supports only `n=1`
+ "stop_sequences": "stop_sequences",
+ "temperature": "temperature",
+ "top_p": "top_p",
+ "top_k": "top_k",
+ "max_output_tokens": "max_output_tokens",
+ }
+
def __init__(self, **kwargs):
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
@@ -93,12 +114,15 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
- params.get("temperature", 0.5)
- params.get("top_p", 1.0)
- params.get("max_tokens", 4096)
+
+ generation_config = {
+ gemini_term: params[autogen_term]
+ for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
+ if autogen_term in params
+ }
+ safety_settings = params.get("safety_settings", {})
if stream:
- # warn user that streaming is not supported
warnings.warn(
"Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
UserWarning,
@@ -112,7 +136,9 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = oai_messages_to_gemini_messages(messages)
# we use chat model by default
- model = genai.GenerativeModel(model_name)
+ model = genai.GenerativeModel(
+ model_name, generation_config=generation_config, safety_settings=safety_settings
+ )
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
@@ -142,7 +168,9 @@ def create(self, params: Dict) -> ChatCompletion:
elif model_name == "gemini-pro-vision":
# B. handle the vision model
# Gemini's vision model does not support chat history yet
- model = genai.GenerativeModel(model_name)
+ model = genai.GenerativeModel(
+ model_name, generation_config=generation_config, safety_settings=safety_settings
+ )
genai.configure(api_key=self.api_key)
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
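
A standalone illustration of the `PARAMS_MAPPING` translation above, showing how AutoGen-style params become a Gemini `generation_config` (plain Python, no API calls):

```python
PARAMS_MAPPING = {
    "max_tokens": "max_output_tokens",
    "stop_sequences": "stop_sequences",
    "temperature": "temperature",
    "top_p": "top_p",
    "top_k": "top_k",
    "max_output_tokens": "max_output_tokens",
}

params = {"max_tokens": 2048, "temperature": 1.0, "top_k": 5, "n": 1}
generation_config = {
    gemini_term: params[autogen_term]
    for autogen_term, gemini_term in PARAMS_MAPPING.items()
    if autogen_term in params
}
print(generation_config)  # {'max_output_tokens': 2048, 'temperature': 1.0, 'top_k': 5}
# Keys with no mapping, such as "n", are dropped (Gemini only supports n=1).
```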
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index 7e738b7bd61..1ed347f6271 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -16,7 +16,10 @@
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
OAI_PRICE1K = {
- # https://openai.com/pricing
+ # https://openai.com/api/pricing/
+ # gpt-4o
+ "gpt-4o": (0.005, 0.015),
+ "gpt-4o-2024-05-13": (0.005, 0.015),
# gpt-4-turbo
"gpt-4-turbo-2024-04-09": (0.01, 0.03),
# gpt-4
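
For scale, a quick cost calculation against the new entries; the tuples are (prompt, completion) USD prices per 1K tokens:

```python
from autogen.oai.openai_utils import OAI_PRICE1K

prompt_tokens, completion_tokens = 1_000, 2_000
prompt_price, completion_price = OAI_PRICE1K["gpt-4o"]  # (0.005, 0.015)
cost = prompt_price * prompt_tokens / 1000 + completion_price * completion_tokens / 1000
print(f"${cost:.3f}")  # $0.035
```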
@@ -376,11 +379,10 @@ def config_list_gpt4_gpt35(
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
+ exclude: bool = False,
) -> List[Dict[str, Any]]:
- """
- This function filters `config_list` by checking each configuration dictionary against the
- criteria specified in `filter_dict`. A configuration dictionary is retained if for every
- key in `filter_dict`, see example below.
+ """This function filters `config_list` by checking each configuration dictionary against the criteria specified in
+ `filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
Args:
config_list (list of dict): A list of configuration dictionaries to be filtered.
@@ -391,71 +393,68 @@ def filter_config(
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
-
-
+ exclude (bool): If False (the default), configs that match the filter are included in the returned
+ list. If True, configs that match the filter are excluded from the returned list.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
in `filter_dict`.
Example:
- ```python
- # Example configuration list with various models and API types
- configs = [
- {'model': 'gpt-3.5-turbo'},
- {'model': 'gpt-4'},
- {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
- {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
- ]
-
- # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
- # that are also using the 'azure' API type
- filter_criteria = {
- 'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
- 'api_type': ['azure'] # Only accept configurations for 'azure' API type
- }
-
- # Apply the filter to the configuration list
- filtered_configs = filter_config(configs, filter_criteria)
-
- # The resulting `filtered_configs` will be:
- # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
-
-
- # Define a filter to select a given tag
- filter_criteria = {
- 'tags': ['gpt35_turbo'],
- }
-
- # Apply the filter to the configuration list
- filtered_configs = filter_config(configs, filter_criteria)
-
- # The resulting `filtered_configs` will be:
- # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
- ```
-
+ ```python
+ # Example configuration list with various models and API types
+ configs = [
+ {'model': 'gpt-3.5-turbo'},
+ {'model': 'gpt-4'},
+ {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
+ {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
+ ]
+ # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
+ # that are also using the 'azure' API type
+ filter_criteria = {
+ 'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
+ 'api_type': ['azure'] # Only accept configurations for 'azure' API type
+ }
+ # Apply the filter to the configuration list
+ filtered_configs = filter_config(configs, filter_criteria)
+ # The resulting `filtered_configs` will be:
+ # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
+ # Define a filter to select a given tag
+ filter_criteria = {
+ 'tags': ['gpt35_turbo'],
+ }
+ # Apply the filter to the configuration list
+ filtered_configs = filter_config(configs, filter_criteria)
+ # The resulting `filtered_configs` will be:
+ # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
+ ```
Note:
- If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
- If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
it is considered a non-match and is excluded from the result.
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
- """
- def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
- if isinstance(config_value, list):
- return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
- else:
- return config_value in acceptable_values
+ """
if filter_dict:
- config_list = [
- config
- for config in config_list
- if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
+ return [
+ item
+ for item in config_list
+ if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items())
]
return config_list
+def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
+ if value is None:
+ # A config without the key matches only when None is among the acceptable values (see the Note in filter_config's docstring).
+ return None in criteria_values
+
+ if isinstance(value, list):
+ return bool(set(value) & set(criteria_values)) # Non-empty intersection
+ else:
+ return value in criteria_values
+
+
def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
@@ -782,3 +781,10 @@ def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Di
assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]
return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
+
+
+def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
+ if isinstance(config_value, list):
+ return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
+ else:
+ return config_value in acceptable_values
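
A short sketch of the new `exclude` flag in action (assuming the usual top-level re-export of `filter_config`):

```python
from autogen import filter_config

configs = [
    {"model": "gpt-4"},
    {"model": "gpt-3.5-turbo", "api_type": "azure"},
]

# Default (exclude=False): keep configs that match the filter.
filter_config(configs, {"api_type": ["azure"]})
# -> [{'model': 'gpt-3.5-turbo', 'api_type': 'azure'}]

# exclude=True inverts the selection: drop configs that match the filter.
filter_config(configs, {"api_type": ["azure"]}, exclude=True)
# -> [{'model': 'gpt-4'}]
```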
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index 1b9835eaa4b..ffc741482e6 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -3,16 +3,17 @@
import logging
import sqlite3
import uuid
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
-from autogen.logger.base_logger import LLMConfig
+from autogen.logger.base_logger import BaseLogger, LLMConfig
from autogen.logger.logger_factory import LoggerFactory
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.gemini import GeminiClient
logger = logging.getLogger(__name__)
@@ -20,11 +21,27 @@
is_logging = False
-def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> str:
+def start(
+ logger: Optional[BaseLogger] = None,
+ logger_type: Literal["sqlite", "file"] = "sqlite",
+ config: Optional[Dict[str, Any]] = None,
+) -> str:
+ """
+ Start logging for the runtime.
+ Args:
+ logger (BaseLogger): A logger instance
+ logger_type (str): The type of logger to use (default: sqlite)
+ config (dict): Configuration for the logger
+ Returns:
+ session_id (str(uuid.uuid4)): a unique id for the logging session
+ """
global autogen_logger
global is_logging
- autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
+ if logger:
+ autogen_logger = logger
+ else:
+ autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
try:
session_id = autogen_logger.start()
@@ -78,7 +95,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
autogen_logger.log_new_wrapper(wrapper, init_args)
-def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
+def log_new_client(
+ client: Union[AzureOpenAI, OpenAI, GeminiClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_client: autogen logger is None")
return
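
Both entry points in one hedged sketch; `runtime_logging.stop()` is the existing counterpart to `start()` and is assumed unchanged by this diff:

```python
import autogen.runtime_logging as runtime_logging
from autogen.logger.file_logger import FileLogger

# Option 1: let LoggerFactory build the logger from logger_type.
session_id = runtime_logging.start(logger_type="file", config={"filename": "runtime.log"})
runtime_logging.stop()

# Option 2: pass a ready-made BaseLogger; it takes precedence over logger_type.
session_id = runtime_logging.start(logger=FileLogger(config={}))
runtime_logging.stop()
```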
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 589d7b404a7..b71dbc428a1 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -34,6 +34,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-vision-preview": 128000,
+ "gpt-4o": 128000,
+ "gpt-4o-2024-05-13": 128000,
}
return max_token_limit[model]
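
And the corresponding lookup for the new entries:

```python
from autogen.token_count_utils import get_max_token_limit

get_max_token_limit("gpt-4o")             # 128000
get_max_token_limit("gpt-4o-2024-05-13")  # 128000
```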
diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig
index 4da1adc5de6..5a604ce0096 100644
--- a/dotnet/.editorconfig
+++ b/dotnet/.editorconfig
@@ -141,7 +141,7 @@ csharp_preserve_single_line_statements = true
csharp_preserve_single_line_blocks = true
# Code block
-csharp_prefer_braces = false:none
+csharp_prefer_braces = true:warning
# Using statements
csharp_using_directive_placement = outside_namespace:error
@@ -173,6 +173,11 @@ dotnet_diagnostic.CS1573.severity = none
# disable CS1570: XML comment has badly formed XML
dotnet_diagnostic.CS1570.severity = none
+dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
+dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
+
+csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
+
# disable check for generated code
[*.generated.cs]
generated_code = true
\ No newline at end of file
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
index b46b8091cf5..be40e7b61b6 100644
--- a/dotnet/AutoGen.sln
+++ b/dotnet/AutoGen.sln
@@ -33,7 +33,18 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral", "src\Auto
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "test\AutoGen.Mistral.Tests\AutoGen.Mistral.Tests.csproj", "{15441693-3659-4868-B6C1-B106F52FF3BA}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama", "src\AutoGen.Ollama\AutoGen.Ollama.csproj", "{9F9E6DED-3D92-4970-909A-70FC11F1A665}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama.Tests", "test\AutoGen.Ollama.Tests\AutoGen.Ollama.Tests.csproj", "{03E31CAA-3728-48D3-B936-9F11CF6C18FE}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Ollama.Sample", "sample\AutoGen.Ollama.Sample\AutoGen.Ollama.Sample.csproj", "{93AA4D0D-6EE4-44D5-AD77-7F73A3934544}"
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.SemanticKernel.Sample", "sample\AutoGen.SemanticKernel.Sample\AutoGen.SemanticKernel.Sample.csproj", "{52958A60-3FF7-4243-9058-34A6E4F55C31}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -93,6 +104,30 @@ Global
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.Build.0 = Release|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Release|Any CPU.Build.0 = Release|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Release|Any CPU.Build.0 = Release|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Release|Any CPU.Build.0 = Release|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -111,6 +146,12 @@ Global
{6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {52958A60-3FF7-4243-9058-34A6E4F55C31} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
diff --git a/dotnet/NuGet.config b/dotnet/NuGet.config
index 2eb25136c6a..1d0cf4c2bc7 100644
--- a/dotnet/NuGet.config
+++ b/dotnet/NuGet.config
@@ -2,8 +2,6 @@
-
-
diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props
index 8aff3c60226..0444dadfd5e 100644
--- a/dotnet/eng/MetaInfo.props
+++ b/dotnet/eng/MetaInfo.props
@@ -1,7 +1,7 @@
- 0.0.13
+ 0.0.14
AutoGen
https://microsoft.github.io/autogen-for-net/
https://github.com/microsoft/autogen
diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props
index b9fc4367194..ae213015471 100644
--- a/dotnet/eng/Version.props
+++ b/dotnet/eng/Version.props
@@ -10,7 +10,7 @@
6.8.0
2.4.2
17.7.0
- 1.0.0-beta.23523.2
+ 1.0.0-beta.24229.4
8.0.0
4.0.0
diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
index 0cafff3c0d0..afc76164906 100644
--- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
+++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
@@ -4,7 +4,6 @@
Exe
$(TestTargetFramework)
enable
- enable
True
$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
index 022f7e9f984..cf045221223 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
@@ -84,7 +84,7 @@ public async Task CreateOpenAIChatAgentAsync()
new TextMessage(Role.Assistant, "Hello", from: "user"),
],
from: "user"),
- new Message(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
+ new TextMessage(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
};
foreach (var message in messages)
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
index 57b9ea76dcb..0ef8eaa48ae 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
@@ -77,20 +77,30 @@ public static async Task RunAsync()
// talk to the assistant agent
var upperCase = await agent.SendAsync("convert to upper case: hello world");
upperCase.GetContent()?.Should().Be("HELLO WORLD");
- upperCase.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ upperCase.Should().BeOfType<ToolCallAggregateMessage>();
upperCase.GetToolCalls().Should().HaveCount(1);
upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase));
var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e");
concatString.GetContent()?.Should().Be("a b c d e");
- concatString.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ concatString.Should().BeOfType<ToolCallAggregateMessage>();
concatString.GetToolCalls().Should().HaveCount(1);
concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString));
var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1");
calculateTax.GetContent().Should().Be("tax is 10");
- calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ calculateTax.Should().BeOfType<ToolCallAggregateMessage>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
+
+ // parallel function calls
+ var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
+ calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40");
+ calculateTaxes.Should().BeOfType<ToolCallAggregateMessage>();
+ calculateTaxes.GetToolCalls().Should().HaveCount(2);
+ calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
+
+ // send aggregate message back to llm to get the final result
+ var finalResult = await agent.SendAsync(calculateTaxes);
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
index c5d9a01f971..47dd8ce66c9 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
@@ -40,7 +40,8 @@ public static async Task RunAsync()
name: "groupAdmin",
systemMessage: "You are the admin of the group chat",
temperature: 0f,
- config: gptConfig);
+ config: gptConfig)
+ .RegisterPrintMessage();
var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE, humanInputMode: HumanInputMode.NEVER)
.RegisterPrintMessage();
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
index 9fccd7ab385..2d21615ef71 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example05_Dalle_And_GPT4V.cs
-using AutoGen;
using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
using FluentAssertions;
using autogen = AutoGen.LLMConfigAPI;
@@ -66,50 +67,39 @@ public static async Task RunAsync()
File.Delete(imagePath);
}
- var dalleAgent = new AssistantAgent(
- name: "dalle",
- systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = gpt35Config,
- FunctionContracts = new[]
- {
- instance.GenerateImageFunctionContract,
- },
- },
+ var generateImageFunctionMiddleware = new FunctionCallMiddleware(
+ functions: [instance.GenerateImageFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(GenerateImage), instance.GenerateImageWrapper },
- })
+ });
+ var dalleAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ modelName: "gpt-3.5-turbo",
+ name: "dalle",
+ systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(generateImageFunctionMiddleware)
.RegisterMiddleware(async (msgs, option, agent, ct) =>
{
- // if last message contains [TERMINATE], then find the last image url and terminate the conversation
- if (msgs.Last().GetContent()?.Contains("TERMINATE") is true)
+ if (msgs.Any(msg => msg.GetContent()?.ToLower().Contains("approve") is true))
{
- var lastMessageWithImage = msgs.Last(msg => msg is ImageMessage) as ImageMessage;
- var lastImageUrl = lastMessageWithImage.Url;
- Console.WriteLine($"download image from {lastImageUrl} to {imagePath}");
- var httpClient = new HttpClient();
- var imageBytes = await httpClient.GetByteArrayAsync(lastImageUrl);
- File.WriteAllBytes(imagePath, imageBytes);
-
- var messageContent = $@"{GroupChatExtension.TERMINATE}
-
-{lastImageUrl}";
- return new TextMessage(Role.Assistant, messageContent)
- {
- From = "dalle",
- };
+ return new TextMessage(Role.Assistant, $"The image satisfies the condition, conversation is terminated. {GroupChatExtension.TERMINATE}");
}
- var reply = await agent.GenerateReplyAsync(msgs, option, ct);
+ var msgsWithoutImage = msgs.Where(msg => msg is not ImageMessage).ToList();
+ var reply = await agent.GenerateReplyAsync(msgsWithoutImage, option, ct);
if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
{
var imageUrl = content.Split("\n").Last();
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From);
+ Console.WriteLine($"download image from {imageUrl} to {imagePath}");
+ var httpClient = new HttpClient();
+ var imageBytes = await httpClient.GetByteArrayAsync(imageUrl, ct);
+ File.WriteAllBytes(imagePath, imageBytes);
+
return imageMessage;
}
else
@@ -119,33 +109,25 @@ public static async Task RunAsync()
})
.RegisterPrintMessage();
- var gpt4VAgent = new AssistantAgent(
+ var gpt4VAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
name: "gpt4v",
+ modelName: "gpt-4-vision-preview",
systemMessage: @"You are a critism that provide feedback to DALL-E agent.
Carefully check the image generated by DALL-E agent and provide feedback.
-If the image satisfies the condition, then terminate the conversation by saying [TERMINATE].
+If the image satisfies the condition, then say [APPROVE].
Otherwise, provide detailed feedback to DALL-E agent so it can generate better image.
The image should satisfy the following conditions:
- There should be a cat and a mouse in the image
-- The cat should be chasing after the mouse
-",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = gpt4vConfig,
- })
+- The cat should be chasing after the mouse")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
- IEnumerable<IMessage> conversation = new List<IMessage>()
- {
- new TextMessage(Role.User, "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse")
- };
- var maxRound = 20;
await gpt4VAgent.InitiateChatAsync(
receiver: dalleAgent,
message: "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse",
- maxRound: maxRound);
+ maxRound: 10);
File.Exists(imagePath).Should().BeTrue();
}
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
new file mode 100644
index 00000000000..1dc94400869
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
@@ -0,0 +1,24 @@
+
+
+ Exe
+ $(TestTargetFramework)
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+
+
+
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+
+
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs
new file mode 100644
index 00000000000..e1af08c574c
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_LLaMA.cs
+
+using AutoGen.Core;
+using AutoGen.Ollama.Extension;
+
+namespace AutoGen.Ollama.Sample;
+
+public class Chat_With_LLaMA
+{
+ public static async Task RunAsync()
+ {
+ using var httpClient = new HttpClient()
+ {
+ BaseAddress = new Uri("https://2xbvtxd1-11434.usw2.devtunnels.ms")
+ };
+
+ var ollamaAgent = new OllamaAgent(
+ httpClient: httpClient,
+ name: "ollama",
+ modelName: "llama3:latest",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var reply = await ollamaAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ }
+}
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs
new file mode 100644
index 00000000000..b1b310e3956
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_LLaVA.cs
+
+using AutoGen.Core;
+using AutoGen.Ollama.Extension;
+
+namespace AutoGen.Ollama.Sample;
+
+public class Chat_With_LLaVA
+{
+ public static async Task RunAsync()
+ {
+ using var httpClient = new HttpClient()
+ {
+ BaseAddress = new Uri("https://2xbvtxd1-11434.usw2.devtunnels.ms")
+ };
+
+ var ollamaAgent = new OllamaAgent(
+ httpClient: httpClient,
+ name: "ollama",
+ modelName: "llava:latest",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var image = Path.Combine("images", "background.png");
+ var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
+ var imageMessage = new ImageMessage(Role.User, binaryData);
+ var textMessage = new TextMessage(Role.User, "what's in this image?");
+ var reply = await ollamaAgent.SendAsync(chatHistory: [textMessage, imageMessage]);
+
+ // You can also use MultiModalMessage to put text and image together in one message
+ // In this case, all the messages in the multi-modal message will be merged into a single message,
+ // where the text is the concatenation of all the text messages separated by \n
+ // and the images are all the images in the multi-modal message
+ var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
+
+ reply = await ollamaAgent.SendAsync(chatHistory: [multiModalMessage]);
+ }
+}
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Program.cs b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs
new file mode 100644
index 00000000000..62c92eebe7e
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.Ollama.Sample;
+
+await Chat_With_LLaVA.RunAsync();
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/images/background.png b/dotnet/sample/AutoGen.Ollama.Sample/images/background.png
new file mode 100644
index 00000000000..ca276f81f5b
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/images/background.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:300b7c9d6ba0c23a3e52fbd2e268141ddcca0434a9fb9dcf7e58e7e903d36dcf
+size 2126185
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
new file mode 100644
index 00000000000..6c226651292
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
@@ -0,0 +1,17 @@
+
+
+
+ Exe
+ $(TestTargetFramework)
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ enable
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs
new file mode 100644
index 00000000000..3333cdd9ad9
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Semantic_Kernel_Agent.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Create_Semantic_Kernel_Agent
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernel = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey)
+ .Build();
+
+ var skAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector() // register the message connector so the agent supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ await skAgent.SendAsync("Hey tell me a long tedious joke");
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs
new file mode 100644
index 00000000000..0caea6f031f
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Semantic_Kernel_Chat_Agent.cs
+
+using AutoGen.Core;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Agents;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Create_Semantic_Kernel_Chat_Agent
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernel = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey)
+ .Build();
+
+ // The built-in ChatCompletionAgent from semantic kernel.
+ var chatAgent = new ChatCompletionAgent()
+ {
+ Kernel = kernel,
+ Name = "assistant",
+ Description = "You are a helpful AI assistant",
+ };
+
+ var messageConnector = new SemanticKernelChatMessageContentConnector();
+ var skAgent = new SemanticKernelChatCompletionAgent(chatAgent)
+ .RegisterMiddleware(messageConnector) // register the message connector so the agent supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ await skAgent.SendAsync("Hey tell me a long tedious joke");
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs
new file mode 100644
index 00000000000..5032f2d4330
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.SemanticKernel.Sample;
+
+await Use_Kernel_Functions_With_Other_Agent.RunAsync();
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs
new file mode 100644
index 00000000000..4cebc88291f
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Bing_Search_With_Semantic_Kernel_Agent.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Plugins.Web;
+using Microsoft.SemanticKernel.Plugins.Web.Bing;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Use_Bing_Search_With_Semantic_Kernel_Agent
+{
+ public static async Task RunAsync()
+ {
+ var bingApiKey = Environment.GetEnvironmentVariable("BING_API_KEY") ?? throw new Exception("BING_API_KEY environment variable is not set");
+ var bingSearch = new BingConnector(bingApiKey);
+ var webSearchPlugin = new WebSearchEnginePlugin(bingSearch);
+
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernelBuilder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ kernelBuilder.Plugins.AddFromObject(webSearchPlugin);
+
+ var kernel = kernelBuilder.Build();
+
+ var skAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector() // register the message connector so the agent supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ await skAgent.SendAsync("Tell me more about gpt-4-o");
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
new file mode 100644
index 00000000000..d91d727668a
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Kernel_Functions_With_Other_Agent.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using Microsoft.SemanticKernel;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Use_Kernel_Functions_With_Other_Agent
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernelBuilder = Kernel.CreateBuilder();
+ var kernel = kernelBuilder.Build();
+ var getWeatherFunction = KernelFunctionFactory.CreateFromMethod(
+ method: (string location) => $"The weather in {location} is 75 degrees Fahrenheit.",
+ functionName: "GetWeather",
+ description: "Get the weather for a location.");
+ var plugin = kernel.CreatePluginFromFunctions("my_plugin", [getWeatherFunction]);
+
+ // Create a middleware to handle the plugin functions
+ var kernelPluginMiddleware = new KernelPluginMiddleware(kernel, plugin);
+
+ var openAIClient = new OpenAIClient(openAIKey);
+ var openAIAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: modelId)
+ .RegisterMessageConnector() // register the message connector so the agent supports AutoGen built-in message types like TextMessage.
+ .RegisterMiddleware(kernelPluginMiddleware) // register the middleware to handle the plugin functions
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ var toolAggregateMessage = await openAIAgent.SendAsync("Tell me the weather in Seattle");
+
+ // The aggregate message will be converted to [ToolCallMessage, ToolCallResultMessage] when flowing into the agent
+ // send the aggregated message to llm to generate the final response
+ var finalReply = await openAIAgent.SendAsync(toolAggregateMessage);
+ }
+}
diff --git a/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
index 47dbad55e30..3dbba9668f9 100644
--- a/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
+++ b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageExtension.cs
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -15,7 +16,9 @@ public static string FormatMessage(this IMessage message)
{
return message switch
{
+#pragma warning disable CS0618 // deprecated
Message msg => msg.FormatMessage(),
+#pragma warning restore CS0618 // deprecated
TextMessage textMessage => textMessage.FormatMessage(),
ImageMessage imageMessage => imageMessage.FormatMessage(),
ToolCallMessage toolCallMessage => toolCallMessage.FormatMessage(),
@@ -110,6 +113,8 @@ public static string FormatMessage(this AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
TextMessage textMessage => textMessage.Role == Role.System,
+#pragma warning disable CS0618 // deprecated
Message msg => msg.Role == Role.System,
+#pragma warning restore CS0618 // deprecated
_ => false,
};
}
///
/// Get the content from the message
- /// if the message is a <see cref="TextMessage"/> or <see cref="Message"/>, return the content
- /// if the message is a <see cref="ToolCallResultMessage"/> and only contains one function call, return the result of that function call
+ /// if the message implements <see cref="ICanGetTextContent"/>, return the content from the message by calling <see cref="ICanGetTextContent.GetContent"/>
/// if the message is a <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/> and the second message only contains one function call, return the result of that function call
/// for all other situations, return null.
///
@@ -166,10 +172,11 @@ public static bool IsSystemMessage(this IMessage message)
{
return message switch
{
- TextMessage textMessage => textMessage.Content,
+ ICanGetTextContent canGetTextContent => canGetTextContent.GetContent(),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)),
+#pragma warning disable CS0618 // deprecated
Message msg => msg.Content,
- ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
- AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
+#pragma warning restore CS0618 // deprecated
_ => null,
};
}
@@ -182,7 +189,9 @@ public static bool IsSystemMessage(this IMessage message)
return message switch
{
TextMessage textMessage => textMessage.Role,
+#pragma warning disable CS0618 // deprecated
Message msg => msg.Role,
+#pragma warning restore CS0618 // deprecated
ImageMessage img => img.Role,
MultiModalMessage multiModal => multiModal.Role,
_ => null,
@@ -191,8 +200,7 @@ public static bool IsSystemMessage(this IMessage message)
///
/// Return the tool calls from the message if they are available.
- /// if the message is a <see cref="ToolCallMessage"/>, return its tool calls
- /// if the message is a <see cref="Message"/> and the function name and function arguments are available, return a list of tool calls with one item
+ /// if the message implements <see cref="ICanGetToolCalls"/>, return the tool calls from the message by calling <see cref="ICanGetToolCalls.GetToolCalls"/>
/// if the message is a <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>, return the tool calls from the first message
///
///
@@ -201,11 +209,13 @@ public static bool IsSystemMessage(this IMessage message)
{
return message switch
{
- ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls,
+ ICanGetToolCalls canGetToolCalls => canGetToolCalls.GetToolCalls().ToList(),
+#pragma warning disable CS0618 // deprecated
Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null
- ? msg.Content is not null ? new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) }
- : new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments) }
+ ? msg.Content is not null ? [new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content)]
+ : new List<ToolCall> { new(msg.FunctionName, msg.FunctionArguments) }
: null,
+#pragma warning restore CS0618 // deprecated
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message1.ToolCalls,
_ => null,
};
diff --git a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
index 78d92508611..02f4da50bae 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
@@ -8,19 +8,6 @@
namespace AutoGen.Core;
-/// <summary>
-/// Obsolete: please use <see cref="Graph"/>
-/// </summary>
-[Obsolete("please use Graph")]
-public class Workflow : Graph
-{
- [Obsolete("please use Graph")]
- public Workflow(IEnumerable<Transition> transitions)
- : base(transitions)
- {
- }
-}
-
public class Graph
{
private readonly List<Transition> transitions = new List<Transition>();
diff --git a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
index 3b6288ca0a7..cd17a21f8b9 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
@@ -110,7 +110,7 @@ public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{string.Join(",", agentNames)}
Each message will start with 'From name:', e.g:
-From admin:
+From {agentNames.First()}:
//your message//.");
var conv = this.ProcessConversationsForRolePlay(this.initializeMessages, conversationHistory);
diff --git a/dotnet/src/AutoGen.Core/Message/IMessage.cs b/dotnet/src/AutoGen.Core/Message/IMessage.cs
index 7b48f4f0d63..ad215d510e3 100644
--- a/dotnet/src/AutoGen.Core/Message/IMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/IMessage.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs
+using System.Collections.Generic;
+
namespace AutoGen.Core;
///
@@ -29,7 +31,7 @@ namespace AutoGen.Core;
/// -
/// <see cref="AggregateMessage{TMessage1, TMessage2}"/>: an aggregate message type that contains two message types.
/// This type is useful when you want to combine two message types into one unique message type. One example is when invoking a tool call and you want to return both <see cref="ToolCallMessage"/> and <see cref="ToolCallResultMessage"/>.
-/// One example of how this type is used in AutoGen is <see cref="FunctionCallMiddleware"/>
+/// One example of how this type is used in AutoGen is <see cref="FunctionCallMiddleware"/> and its return message <see cref="ToolCallAggregateMessage"/>
///
///
///
@@ -41,6 +43,24 @@ public interface IMessage<out T> : IMessage, IStreamingMessage<T>
{
}
+/// <summary>
+/// The interface for messages that can get text content.
+/// This interface will be used by <see cref="MessageExtension.GetContent(IMessage)"/> to get the content from the message.
+/// </summary>
+public interface ICanGetTextContent : IMessage, IStreamingMessage
+{
+ public string? GetContent();
+}
+
+/// <summary>
+/// The interface for messages that can get a list of <see cref="ToolCall"/>
+/// </summary>
+public interface ICanGetToolCalls : IMessage, IStreamingMessage
+{
+ public IEnumerable<ToolCall> GetToolCalls();
+}
+
+
public interface IStreamingMessage
{
string? From { get; set; }
diff --git a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs
index 1239785c411..d2e2d080300 100644
--- a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs
@@ -49,7 +49,9 @@ public ImageMessage(Role role, BinaryData data, string? from = null)
public string BuildDataUri()
{
if (this.Data is null)
+ {
throw new NullReferenceException($"{nameof(Data)}");
+ }
return $"data:{this.Data.MediaType};base64,{Convert.ToBase64String(this.Data.ToArray())}";
}
diff --git a/dotnet/src/AutoGen.Core/Message/Message.cs b/dotnet/src/AutoGen.Core/Message/Message.cs
index ec4751b9344..b31b413eca7 100644
--- a/dotnet/src/AutoGen.Core/Message/Message.cs
+++ b/dotnet/src/AutoGen.Core/Message/Message.cs
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Message.cs
+using System;
using System.Collections.Generic;
namespace AutoGen.Core;
+[Obsolete("This message class is deprecated, please use a specific AutoGen built-in message type instead. For more information, please visit https://microsoft.github.io/autogen-for-net/articles/Built-in-messages.html")]
public class Message : IMessage
{
public Message(
diff --git a/dotnet/src/AutoGen.Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
index ed4d7436dde..addd8728a92 100644
--- a/dotnet/src/AutoGen.Core/Message/TextMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
@@ -3,7 +3,7 @@
namespace AutoGen.Core;
-public class TextMessage : IMessage, IStreamingMessage
+public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
{
public TextMessage(Role role, string content, string? from = null)
{
@@ -44,9 +44,14 @@ public override string ToString()
{
return $"TextMessage({this.Role}, {this.Content}, {this.From})";
}
+
+ public string? GetContent()
+ {
+ return this.Content;
+ }
}
-public class TextMessageUpdate : IStreamingMessage
+public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
@@ -60,4 +65,9 @@ public TextMessageUpdate(Role role, string? content, string? from = null)
public string? From { get; set; }
public Role Role { get; set; }
+
+ public string? GetContent()
+ {
+ return this.Content;
+ }
}
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
new file mode 100644
index 00000000000..7781b785ef8
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ToolCallAggregateMessage.cs
+
+using System.Collections.Generic;
+
+namespace AutoGen.Core;
+
+/// <summary>
+/// An aggregate message that contains a tool call message and a tool call result message.
+/// This message type is used by <see cref="FunctionCallMiddleware"/> to return both <see cref="ToolCallMessage"/> and <see cref="ToolCallResultMessage"/>.
+/// </summary>
+public class ToolCallAggregateMessage : AggregateMessage<ToolCallMessage, ToolCallResultMessage>, ICanGetTextContent, ICanGetToolCalls
+{
+ public ToolCallAggregateMessage(ToolCallMessage message1, ToolCallResultMessage message2, string? from = null)
+ : base(message1, message2, from)
+ {
+ }
+
+ public string? GetContent()
+ {
+ return this.Message2.GetContent();
+ }
+
+ public IEnumerable<ToolCall> GetToolCalls()
+ {
+ return this.Message1.GetToolCalls();
+ }
+}
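
A hypothetical round trip showing what the aggregate exposes; the `GetWeather` function and its arguments are illustrative only:

```csharp
using AutoGen.Core;

// Message1 carries the calls, Message2 the results; the capability
// interfaces surface each side without unpacking the aggregate.
var call = new ToolCallMessage("GetWeather", "{\"city\":\"Seattle\"}", from: "assistant");
var result = new ToolCallResultMessage("72F and sunny", "GetWeather", "{\"city\":\"Seattle\"}", from: "assistant");
var aggregate = new ToolCallAggregateMessage(call, result, from: "assistant");

var text = aggregate.GetContent();        // "72F and sunny", delegated to Message2
var toolCalls = aggregate.GetToolCalls(); // the ToolCall list, delegated to Message1
```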
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
index 8dcd98ea0ec..396dba3d3a1 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
@@ -26,6 +26,8 @@ public ToolCall(string functionName, string functionArgs, string result)
public string FunctionArguments { get; set; }
+ public string? ToolCallId { get; set; }
+
public string? Result { get; set; }
public override string ToString()
@@ -34,7 +36,7 @@ public override string ToString()
}
}
-public class ToolCallMessage : IMessage
+public class ToolCallMessage : IMessage, ICanGetToolCalls
{
public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
@@ -45,7 +47,7 @@ public ToolCallMessage(IEnumerable toolCalls, string? from = null)
public ToolCallMessage(string functionName, string functionArgs, string? from = null)
{
this.From = from;
- this.ToolCalls = new List<ToolCall> { new ToolCall(functionName, functionArgs) };
+ this.ToolCalls = new List<ToolCall> { new ToolCall(functionName, functionArgs) { ToolCallId = functionName } };
}
public ToolCallMessage(ToolCallMessageUpdate update)
@@ -89,6 +91,11 @@ public override string ToString()
return sb.ToString();
}
+
+ public IEnumerable<ToolCall> GetToolCalls()
+ {
+ return this.ToolCalls;
+ }
}
public class ToolCallMessageUpdate : IStreamingMessage
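
The convenience constructor now seeds `ToolCallId` with the function name, so providers that require an id (such as Mistral's `tool_call_id` below) always receive one. A small sketch of the observable behavior; the function name is illustrative:

```csharp
using System;
using AutoGen.Core;

var message = new ToolCallMessage("SearchDocs", "{\"query\":\"autogen\"}");
foreach (var toolCall in message.GetToolCalls())
{
    // Prints "SearchDocs -> SearchDocs": the id falls back to the function name.
    Console.WriteLine($"{toolCall.FunctionName} -> {toolCall.ToolCallId}");
}
```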
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs
index 99c7740849a..fa7357c941c 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs
@@ -7,7 +7,7 @@
namespace AutoGen.Core;
-public class ToolCallResultMessage : IMessage
+public class ToolCallResultMessage : IMessage, ICanGetTextContent
{
public ToolCallResultMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
@@ -18,7 +18,7 @@ public ToolCallResultMessage(IEnumerable toolCalls, string? from = nul
public ToolCallResultMessage(string result, string functionName, string functionArgs, string? from = null)
{
this.From = from;
- var toolCall = new ToolCall(functionName, functionArgs);
+ var toolCall = new ToolCall(functionName, functionArgs) { ToolCallId = functionName };
toolCall.Result = result;
this.ToolCalls = [toolCall];
}
@@ -30,6 +30,15 @@ public ToolCallResultMessage(string result, string functionName, string function
public string? From { get; set; }
+ public string? GetContent()
+ {
+ var results = this.ToolCalls
+ .Where(x => x.Result != null)
+ .Select(x => x.Result);
+
+ return string.Join("\n", results);
+ }
+
public override string ToString()
{
var sb = new StringBuilder();
@@ -41,16 +50,4 @@ public override string ToString()
return sb.ToString();
}
-
- private void Validate()
- {
- // each tool call must have a result
- foreach (var toolCall in this.ToolCalls)
- {
- if (string.IsNullOrEmpty(toolCall.Result))
- {
- throw new System.ArgumentException($"The tool call {toolCall} does not have a result");
- }
- }
- }
}
diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
index 2bc02805538..d0788077b59 100644
--- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
@@ -18,8 +18,7 @@ namespace AutoGen.Core;
/// Otherwise, the message will be sent to the inner agent. In this situation
/// if the reply from the inner agent is <see cref="ToolCallMessage"/>,
/// and the tool calls are available in this middleware's function map, the tools from the reply will be invoked,
-/// and a <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>
-/// will be returned.
+/// and a <see cref="ToolCallAggregateMessage"/> will be returned.
///
/// If the reply from the inner agent is a <see cref="ToolCallMessage"/> but the tool calls are not available in this middleware's function map,
/// or the reply from the inner agent is not a <see cref="ToolCallMessage"/>, the original reply from the inner agent will be returned.
@@ -128,13 +127,13 @@ private async Task InvokeToolCallMessagesBeforeInvokingAg
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
- toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
+ toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
}
else if (this.functionMap is not null)
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
- toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
+ toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
}
else
{
@@ -156,14 +155,14 @@ private async Task InvokeToolCallMessagesAfterInvokingAgentAsync(ToolC
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
- toolCallResult.Add(new ToolCall(fName, fArgs, result));
+ toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
}
}
if (toolCallResult.Count() > 0)
{
var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
- return new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMsg, toolCallResultMessage, from: agent.Name);
+ return new ToolCallAggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name);
}
else
{
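
Wiring sketch: with a function map registered, a `ToolCallMessage` reply from the inner agent is executed and comes back as a `ToolCallAggregateMessage`. The `GetWeather` entry is hypothetical; any `Func<string, Task<string>>` keyed by function name works, matching the map type used elsewhere in this patch:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoGen.Core;

public static class ToolWiring
{
    public static IStreamingAgent WithTools(IStreamingAgent innerAgent)
    {
        var functionMap = new Dictionary<string, Func<string, Task<string>>>
        {
            // Hypothetical tool; the middleware resolves calls by function name.
            ["GetWeather"] = args => Task.FromResult("72F and sunny"),
        };

        return innerAgent.RegisterStreamingMiddleware(new FunctionCallMiddleware(functionMap: functionMap));
    }
}
```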
diff --git a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
index 57fcb1fce16..72c67fe7801 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
+++ b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
@@ -19,7 +19,7 @@
-
+
@@ -27,14 +27,12 @@
-
-
+
-
-
+
diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
index 5587694882c..bb5504cd548 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
@@ -12,57 +12,58 @@ namespace AutoGen.DotnetInteractive;
public class DotnetInteractiveFunction : IDisposable
{
private readonly InteractiveService? _interactiveService = null;
- private string? _notebookPath;
+ private string _notebookPath;
private readonly KernelInfoCollection _kernelInfoCollection = new KernelInfoCollection();
+ /// <summary>
+ /// Create an instance of <see cref="DotnetInteractiveFunction"/>
+ /// </summary>
+ /// <param name="interactiveService">interactive service to use.</param>
+ /// <param name="notebookPath">notebook path if provided.</param>
public DotnetInteractiveFunction(InteractiveService interactiveService, string? notebookPath = null, bool continueFromExistingNotebook = false)
{
this._interactiveService = interactiveService;
- this._notebookPath = notebookPath;
+ this._notebookPath = notebookPath ?? Path.GetTempPath() + "notebook.ipynb";
this._kernelInfoCollection.Add(new KernelInfo("csharp"));
this._kernelInfoCollection.Add(new KernelInfo("markdown"));
-
- if (this._notebookPath != null)
+ if (continueFromExistingNotebook == false)
{
- if (continueFromExistingNotebook == false)
+ // remove existing notebook
+ if (File.Exists(this._notebookPath))
{
- // remove existing notebook
- if (File.Exists(this._notebookPath))
- {
- File.Delete(this._notebookPath);
- }
+ File.Delete(this._notebookPath);
+ }
- var document = new InteractiveDocument();
+ var document = new InteractiveDocument();
- using var stream = File.OpenWrite(_notebookPath);
- Notebook.Write(document, stream, this._kernelInfoCollection);
- stream.Flush();
- stream.Dispose();
- }
- else if (continueFromExistingNotebook == true && File.Exists(this._notebookPath))
+ using var stream = File.OpenWrite(_notebookPath);
+ Notebook.Write(document, stream, this._kernelInfoCollection);
+ stream.Flush();
+ stream.Dispose();
+ }
+ else if (continueFromExistingNotebook == true && File.Exists(this._notebookPath))
+ {
+ // load existing notebook
+ using var readStream = File.OpenRead(this._notebookPath);
+ var document = Notebook.Read(readStream, this._kernelInfoCollection);
+ foreach (var cell in document.Elements)
{
- // load existing notebook
- using var readStream = File.OpenRead(this._notebookPath);
- var document = Notebook.Read(readStream, this._kernelInfoCollection);
- foreach (var cell in document.Elements)
+ if (cell.KernelName == "csharp")
{
- if (cell.KernelName == "csharp")
- {
- var code = cell.Contents;
- this._interactiveService.SubmitCSharpCodeAsync(code, default).Wait();
- }
+ var code = cell.Contents;
+ this._interactiveService.SubmitCSharpCodeAsync(code, default).Wait();
}
}
- else
- {
- // create an empty notebook
- var document = new InteractiveDocument();
+ }
+ else
+ {
+ // create an empty notebook
+ var document = new InteractiveDocument();
- using var stream = File.OpenWrite(_notebookPath);
- Notebook.Write(document, stream, this._kernelInfoCollection);
- stream.Flush();
- stream.Dispose();
- }
+ using var stream = File.OpenWrite(_notebookPath);
+ Notebook.Write(document, stream, this._kernelInfoCollection);
+ stream.Flush();
+ stream.Dispose();
}
}
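
Since a notebook file is now always written, omitting `notebookPath` is equivalent to passing `Path.GetTempPath() + "notebook.ipynb"`. A minimal setup sketch, assuming the service is installed and started against the same directory:

```csharp
using System;
using System.Threading.Tasks;
using AutoGen.DotnetInteractive;

public static class NotebookSetup
{
    public static async Task<DotnetInteractiveFunction> CreateAsync(string workingDirectory)
    {
        var service = new InteractiveService(workingDirectory);
        if (!await service.StartAsync(workingDirectory))
        {
            throw new InvalidOperationException("failed to start the dotnet-interactive kernel");
        }

        // Equivalent to notebookPath: Path.GetTempPath() + "notebook.ipynb".
        return new DotnetInteractiveFunction(service);
    }
}
```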
diff --git a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
index 0dc34f24e44..7490b64e126 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
@@ -5,7 +5,6 @@
using System.Reactive.Linq;
using System.Reflection;
using Microsoft.DotNet.Interactive;
-using Microsoft.DotNet.Interactive.App.Connection;
using Microsoft.DotNet.Interactive.Commands;
using Microsoft.DotNet.Interactive.Connection;
using Microsoft.DotNet.Interactive.Events;
@@ -41,7 +40,7 @@ public InteractiveService(string installingDirectory)
public async Task<bool> StartAsync(string workingDirectory, CancellationToken ct = default)
{
- this.kernel = await this.CreateKernelAsync(workingDirectory, ct);
+ this.kernel = await this.CreateKernelAsync(workingDirectory, true, ct);
return true;
}
@@ -84,7 +83,51 @@ public async Task StartAsync(string workingDirectory, CancellationToken ct
return await this.SubmitCommandAsync(command, ct);
}
- private async Task<Kernel> CreateKernelAsync(string workingDirectory, CancellationToken ct = default)
+ public bool RestoreDotnetInteractive()
+ {
+ this.WriteLine("Restore dotnet interactive tool");
+ // write RestoreInteractive.config from embedded resource to this.installingDirectory
+ var assembly = Assembly.GetAssembly(typeof(InteractiveService))!;
+ var resourceName = "AutoGen.DotnetInteractive.RestoreInteractive.config";
+ using (var stream = assembly.GetManifestResourceStream(resourceName)!)
+ using (var fileStream = File.Create(Path.Combine(this.installingDirectory, "RestoreInteractive.config")))
+ {
+ stream.CopyTo(fileStream);
+ }
+
+ // write dotnet-tools.json from embedded resource to this.installingDirectory
+
+ resourceName = "AutoGen.DotnetInteractive.dotnet-tools.json";
+ using (var stream2 = assembly.GetManifestResourceStream(resourceName)!)
+ using (var fileStream2 = File.Create(Path.Combine(this.installingDirectory, "dotnet-tools.json")))
+ {
+ stream2.CopyTo(fileStream2);
+ }
+
+ var psi = new ProcessStartInfo
+ {
+ FileName = "dotnet",
+ Arguments = $"tool restore --configfile RestoreInteractive.config",
+ WorkingDirectory = this.installingDirectory,
+ RedirectStandardInput = true,
+ RedirectStandardOutput = true,
+ RedirectStandardError = true,
+ UseShellExecute = false,
+ CreateNoWindow = true,
+ };
+
+ using var process = new Process { StartInfo = psi };
+ process.OutputDataReceived += this.PrintProcessOutput;
+ process.ErrorDataReceived += this.PrintProcessOutput;
+ process.Start();
+ process.BeginErrorReadLine();
+ process.BeginOutputReadLine();
+ process.WaitForExit();
+
+ return process.ExitCode == 0;
+ }
+
+ private async Task<Kernel> CreateKernelAsync(string workingDirectory, bool restoreWhenFail = true, CancellationToken ct = default)
{
try
{
@@ -139,13 +182,13 @@ await rootProxyKernel.SendAsync(
return compositeKernel;
}
- catch (CommandLineInvocationException ex) when (ex.Message.Contains("Cannot find a tool in the manifest file that has a command named 'dotnet-interactive'"))
+ catch (CommandLineInvocationException) when (restoreWhenFail)
{
var success = this.RestoreDotnetInteractive();
if (success)
{
- return await this.CreateKernelAsync(workingDirectory, ct);
+ return await this.CreateKernelAsync(workingDirectory, false, ct);
}
throw;
@@ -176,50 +219,6 @@ private void WriteLine(string data)
this.Output?.Invoke(this, data);
}
- private bool RestoreDotnetInteractive()
- {
- this.WriteLine("Restore dotnet interactive tool");
- // write RestoreInteractive.config from embedded resource to this.workingDirectory
- var assembly = Assembly.GetAssembly(typeof(InteractiveService))!;
- var resourceName = "AutoGen.DotnetInteractive.RestoreInteractive.config";
- using (var stream = assembly.GetManifestResourceStream(resourceName)!)
- using (var fileStream = File.Create(Path.Combine(this.installingDirectory, "RestoreInteractive.config")))
- {
- stream.CopyTo(fileStream);
- }
-
- // write dotnet-tool.json from embedded resource to this.workingDirectory
-
- resourceName = "AutoGen.DotnetInteractive.dotnet-tools.json";
- using (var stream2 = assembly.GetManifestResourceStream(resourceName)!)
- using (var fileStream2 = File.Create(Path.Combine(this.installingDirectory, "dotnet-tools.json")))
- {
- stream2.CopyTo(fileStream2);
- }
-
- var psi = new ProcessStartInfo
- {
- FileName = "dotnet",
- Arguments = $"tool restore --configfile RestoreInteractive.config",
- WorkingDirectory = this.installingDirectory,
- RedirectStandardInput = true,
- RedirectStandardOutput = true,
- RedirectStandardError = true,
- UseShellExecute = false,
- CreateNoWindow = true,
- };
-
- using var process = new Process { StartInfo = psi };
- process.OutputDataReceived += this.PrintProcessOutput;
- process.ErrorDataReceived += this.PrintProcessOutput;
- process.Start();
- process.BeginErrorReadLine();
- process.BeginOutputReadLine();
- process.WaitForExit();
-
- return process.ExitCode == 0;
- }
-
private void PrintProcessOutput(object sender, DataReceivedEventArgs e)
{
if (!string.IsNullOrEmpty(e.Data))
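
`RestoreDotnetInteractive` is now public, so callers can pre-restore the tool instead of relying on the single retry inside `CreateKernelAsync` (which now retries at most once via `restoreWhenFail`). A usage sketch using only the members shown in this diff:

```csharp
using System;
using AutoGen.DotnetInteractive;

public static class KernelBootstrap
{
    // Restore eagerly so the first StartAsync does not pay the retry cost.
    public static InteractiveService CreateRestored(string installingDirectory)
    {
        var service = new InteractiveService(installingDirectory);
        if (!service.RestoreDotnetInteractive())
        {
            throw new InvalidOperationException("dotnet tool restore failed");
        }

        return service;
    }
}
```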
diff --git a/dotnet/src/AutoGen.DotnetInteractive/dotnet-tools.json b/dotnet/src/AutoGen.DotnetInteractive/dotnet-tools.json
index b2677b61678..12b09e61cae 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/dotnet-tools.json
+++ b/dotnet/src/AutoGen.DotnetInteractive/dotnet-tools.json
@@ -3,7 +3,7 @@
"isRoot": true,
"tools": {
"Microsoft.dotnet-interactive": {
- "version": "1.0.431302",
+ "version": "1.0.522904",
"commands": [
"dotnet-interactive"
]
diff --git a/dotnet/src/AutoGen.Mistral/DTOs/ChatMessage.cs b/dotnet/src/AutoGen.Mistral/DTOs/ChatMessage.cs
index c5dae2aa34d..b0fa1757c12 100644
--- a/dotnet/src/AutoGen.Mistral/DTOs/ChatMessage.cs
+++ b/dotnet/src/AutoGen.Mistral/DTOs/ChatMessage.cs
@@ -13,7 +13,7 @@ public class ChatMessage
///
/// role.
/// content.
- public ChatMessage(RoleEnum? role = default(RoleEnum?), string? content = null)
+ public ChatMessage(RoleEnum? role = default, string? content = null)
{
this.Role = role;
this.Content = content;
@@ -67,18 +67,25 @@ public enum RoleEnum
[JsonPropertyName("tool_calls")]
public List<FunctionContent>? ToolCalls { get; set; }
+
+ [JsonPropertyName("tool_call_id")]
+ public string? ToolCallId { get; set; }
}
public class FunctionContent
{
- public FunctionContent(FunctionCall function)
+ public FunctionContent(string id, FunctionCall function)
{
this.Function = function;
+ this.Id = id;
}
[JsonPropertyName("function")]
public FunctionCall Function { get; set; }
+ [JsonPropertyName("id")]
+ public string Id { get; set; }
+
public class FunctionCall
{
public FunctionCall(string name, string arguments)
diff --git a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
index 3ba910aa700..95592e97fcc 100644
--- a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
+++ b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
@@ -158,7 +158,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
else if (finishReason == Choice.FinishReasonEnum.ToolCalls)
{
var functionContents = choice.Message?.ToolCalls ?? throw new ArgumentNullException("choice.Message.ToolCalls");
- var toolCalls = functionContents.Select(f => new ToolCall(f.Function.Name, f.Function.Arguments)).ToList();
+ var toolCalls = functionContents.Select(f => new ToolCall(f.Function.Name, f.Function.Arguments) { ToolCallId = f.Id }).ToList();
return new ToolCallMessage(toolCalls, from: from.Name);
}
else
@@ -257,6 +257,7 @@ private IEnumerable<IMessage<ChatMessage>> ProcessToolCallResultMessage(ToolCall
var message = new ChatMessage(ChatMessage.RoleEnum.Tool, content: toolCall.Result)
{
Name = toolCall.FunctionName,
+ ToolCallId = toolCall.ToolCallId,
};
messages.Add(message);
@@ -305,10 +306,12 @@ private IEnumerable<IMessage<ChatMessage>> ProcessToolCallMessage(ToolCallMessag
// convert tool call message to chat message
var chatMessage = new ChatMessage(ChatMessage.RoleEnum.Assistant);
chatMessage.ToolCalls = new List<FunctionContent>();
- foreach (var toolCall in toolCallMessage.ToolCalls)
+ for (var i = 0; i < toolCallMessage.ToolCalls.Count; i++)
{
+ var toolCall = toolCallMessage.ToolCalls[i];
+ var toolCallId = toolCall.ToolCallId ?? $"{toolCall.FunctionName}_{i}";
var functionCall = new FunctionContent.FunctionCall(toolCall.FunctionName, toolCall.FunctionArguments);
- var functionContent = new FunctionContent(functionCall);
+ var functionContent = new FunctionContent(toolCallId, functionCall);
chatMessage.ToolCalls.Add(functionContent);
}
diff --git a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
new file mode 100644
index 00000000000..9ef68388d60
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
@@ -0,0 +1,185 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaAgent.cs
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net.Http;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Core;
+
+namespace AutoGen.Ollama;
+
+///
+/// An agent that can interact with ollama models.
+///
+public class OllamaAgent : IStreamingAgent
+{
+ private readonly HttpClient _httpClient;
+ private readonly string _modelName;
+ private readonly string _systemMessage;
+ private readonly OllamaReplyOptions? _replyOptions;
+
+ public OllamaAgent(HttpClient httpClient, string name, string modelName,
+ string systemMessage = "You are a helpful AI assistant",
+ OllamaReplyOptions? replyOptions = null)
+ {
+ Name = name;
+ _httpClient = httpClient;
+ _modelName = modelName;
+ _systemMessage = systemMessage;
+ _replyOptions = replyOptions;
+ }
+
+ public async Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default)
+ {
+ ChatRequest request = await BuildChatRequest(messages, options);
+ request.Stream = false;
+ var httpRequest = BuildRequest(request);
+ using (HttpResponseMessage? response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseContentRead, cancellation))
+ {
+ response.EnsureSuccessStatusCode();
+ Stream? streamResponse = await response.Content.ReadAsStreamAsync();
+ ChatResponse chatResponse = await JsonSerializer.DeserializeAsync<ChatResponse>(streamResponse, cancellationToken: cancellation)
+ ?? throw new Exception("Failed to deserialize response");
+ var output = new MessageEnvelope<ChatResponse>(chatResponse, from: Name);
+ return output;
+ }
+ }
+
+ public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ ChatRequest request = await BuildChatRequest(messages, options);
+ request.Stream = true;
+ HttpRequestMessage message = BuildRequest(request);
+ using (HttpResponseMessage? response = await _httpClient.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
+ {
+ response.EnsureSuccessStatusCode();
+ using Stream? stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
+ using var reader = new StreamReader(stream);
+
+ while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
+ {
+ string? line = await reader.ReadLineAsync();
+ if (string.IsNullOrWhiteSpace(line))
+ {
+ continue;
+ }
+
+ ChatResponseUpdate? update = JsonSerializer.Deserialize<ChatResponseUpdate>(line);
+ if (update is { Done: false })
+ {
+ yield return new MessageEnvelope<ChatResponseUpdate>(update, from: Name);
+ }
+ else
+ {
+ var finalUpdate = JsonSerializer.Deserialize<ChatResponse>(line) ?? throw new Exception("Failed to deserialize response");
+
+ yield return new MessageEnvelope<ChatResponse>(finalUpdate, from: Name);
+ }
+ }
+ }
+ }
+
+ public string Name { get; }
+
+ private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
+ {
+ var request = new ChatRequest
+ {
+ Model = _modelName,
+ Messages = await BuildChatHistory(messages)
+ };
+
+ if (options is OllamaReplyOptions replyOptions)
+ {
+ BuildChatRequestOptions(replyOptions, request);
+ return request;
+ }
+
+ if (_replyOptions != null)
+ {
+ BuildChatRequestOptions(_replyOptions, request);
+ return request;
+ }
+ return request;
+ }
+ private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request)
+ {
+ request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null;
+ request.Template = replyOptions.Template;
+ request.KeepAlive = replyOptions.KeepAlive;
+
+ if (replyOptions.Temperature != null
+ || replyOptions.MaxToken != null
+ || replyOptions.StopSequence != null
+ || replyOptions.Seed != null
+ || replyOptions.MiroStat != null
+ || replyOptions.MiroStatEta != null
+ || replyOptions.MiroStatTau != null
+ || replyOptions.NumCtx != null
+ || replyOptions.NumGqa != null
+ || replyOptions.NumGpu != null
+ || replyOptions.NumThread != null
+ || replyOptions.RepeatLastN != null
+ || replyOptions.RepeatPenalty != null
+ || replyOptions.TopK != null
+ || replyOptions.TopP != null
+ || replyOptions.TfsZ != null)
+ {
+ request.Options = new ModelReplyOptions
+ {
+ Temperature = replyOptions.Temperature,
+ NumPredict = replyOptions.MaxToken,
+ Stop = replyOptions.StopSequence?[0],
+ Seed = replyOptions.Seed,
+ MiroStat = replyOptions.MiroStat,
+ MiroStatEta = replyOptions.MiroStatEta,
+ MiroStatTau = replyOptions.MiroStatTau,
+ NumCtx = replyOptions.NumCtx,
+ NumGqa = replyOptions.NumGqa,
+ NumGpu = replyOptions.NumGpu,
+ NumThread = replyOptions.NumThread,
+ RepeatLastN = replyOptions.RepeatLastN,
+ RepeatPenalty = replyOptions.RepeatPenalty,
+ TopK = replyOptions.TopK,
+ TopP = replyOptions.TopP,
+ TfsZ = replyOptions.TfsZ
+ };
+ }
+ }
+ private async Task<List<Message>> BuildChatHistory(IEnumerable<IMessage> messages)
+ {
+ var history = messages.Select(m => m switch
+ {
+ IMessage<Message> chatMessage => chatMessage.Content,
+ _ => throw new ArgumentException("Invalid message type")
+ });
+
+ // if there's no system message in the history, add one to the beginning
+ if (!history.Any(m => m.Role == "system"))
+ {
+ history = new[] { new Message() { Role = "system", Value = _systemMessage } }.Concat(history);
+ }
+
+ return history.ToList();
+ }
+
+ private static HttpRequestMessage BuildRequest(ChatRequest request)
+ {
+ string serialized = JsonSerializer.Serialize(request);
+ return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint)
+ {
+ Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
+ };
+ }
+}
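
End-to-end sketch for the new agent, assuming a local Ollama server on the default port and the `SendAsync` convenience extension from AutoGen.Core; `llama3` is an illustrative model tag. `BaseAddress` matters because the agent posts to the relative `/api/chat` endpoint:

```csharp
using System;
using System.Net.Http;
using System.Threading.Tasks;
using AutoGen.Core;
using AutoGen.Ollama;
using AutoGen.Ollama.Extension;

public static class OllamaQuickStart
{
    public static async Task<IMessage> AskAsync(string question)
    {
        using var httpClient = new HttpClient { BaseAddress = new Uri("http://localhost:11434") };
        var agent = new OllamaAgent(httpClient, name: "assistant", modelName: "llama3")
            .RegisterMessageConnector(); // translate AutoGen messages to Ollama DTOs

        return await agent.SendAsync(question);
    }
}
```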
diff --git a/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj
new file mode 100644
index 00000000000..20924a476b7
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ AutoGen.Ollama
+ True
+
+
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs b/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs
new file mode 100644
index 00000000000..3b0cf04a1a0
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatRequest.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class ChatRequest
+{
+ ///
+ /// (required) the model name
+ ///
+ [JsonPropertyName("model")]
+ public string Model { get; set; } = string.Empty;
+
+ ///
+ /// the messages of the chat, this can be used to keep a chat memory
+ ///
+ [JsonPropertyName("messages")]
+ public IList<Message> Messages { get; set; } = [];
+
+ ///
+ /// the format to return a response in. Currently, the only accepted value is json
+ ///
+ [JsonPropertyName("format")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public string? Format { get; set; }
+
+ ///
+ /// additional model parameters listed in the documentation for the Modelfile such as temperature
+ ///
+ [JsonPropertyName("options")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public ModelReplyOptions? Options { get; set; }
+ ///
+ /// the prompt template to use (overrides what is defined in the Modelfile)
+ ///
+ [JsonPropertyName("template")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public string? Template { get; set; }
+ ///
+ /// if false the response will be returned as a single response object, rather than a stream of objects
+ ///
+ [JsonPropertyName("stream")]
+ public bool Stream { get; set; }
+ ///
+ /// controls how long the model will stay loaded into memory following the request (default: 5m)
+ ///
+ [JsonPropertyName("keep_alive")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public string? KeepAlive { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/ChatResponse.cs b/dotnet/src/AutoGen.Ollama/DTOs/ChatResponse.cs
new file mode 100644
index 00000000000..7d8142de785
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/ChatResponse.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatResponse.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class ChatResponse : ChatResponseUpdate
+{
+ ///
+ /// time spent generating the response
+ ///
+ [JsonPropertyName("total_duration")]
+ public long TotalDuration { get; set; }
+
+ ///
+ /// time spent in nanoseconds loading the model
+ ///
+ [JsonPropertyName("load_duration")]
+ public long LoadDuration { get; set; }
+
+ ///
+ /// number of tokens in the prompt
+ ///
+ [JsonPropertyName("prompt_eval_count")]
+ public int PromptEvalCount { get; set; }
+
+ ///
+ /// time spent in nanoseconds evaluating the prompt
+ ///
+ [JsonPropertyName("prompt_eval_duration")]
+ public long PromptEvalDuration { get; set; }
+
+ ///
+ /// number of tokens the response
+ ///
+ [JsonPropertyName("eval_count")]
+ public int EvalCount { get; set; }
+
+ ///
+ /// time in nanoseconds spent generating the response
+ ///
+ [JsonPropertyName("eval_duration")]
+ public long EvalDuration { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/ChatResponseUpdate.cs b/dotnet/src/AutoGen.Ollama/DTOs/ChatResponseUpdate.cs
new file mode 100644
index 00000000000..8b4dac194f4
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/ChatResponseUpdate.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatResponseUpdate.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class ChatResponseUpdate
+{
+ [JsonPropertyName("model")]
+ public string Model { get; set; } = string.Empty;
+
+ [JsonPropertyName("created_at")]
+ public string CreatedAt { get; set; } = string.Empty;
+
+ [JsonPropertyName("message")]
+ public Message? Message { get; set; }
+
+ [JsonPropertyName("done")]
+ public bool Done { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
new file mode 100644
index 00000000000..2e0d891cc61
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Message.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class Message
+{
+ public Message()
+ {
+ }
+
+ public Message(string role, string value)
+ {
+ Role = role;
+ Value = value;
+ }
+
+ ///
+ /// the role of the message, either system, user or assistant
+ ///
+ [JsonPropertyName("role")]
+ public string Role { get; set; } = string.Empty;
+ ///
+ /// the content of the message
+ ///
+ [JsonPropertyName("content")]
+ public string Value { get; set; } = string.Empty;
+
+ ///
+ /// (optional): a list of images to include in the message (for multimodal models such as llava)
+ ///
+ [JsonPropertyName("images")]
+ public IList<string>? Images { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/ModelReplyOptions.cs b/dotnet/src/AutoGen.Ollama/DTOs/ModelReplyOptions.cs
new file mode 100644
index 00000000000..9d54a1bb83b
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/ModelReplyOptions.cs
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ModelReplyOptions.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+//https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+public class ModelReplyOptions
+{
+ ///
+ /// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+ ///
+ [JsonPropertyName("mirostat")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? MiroStat { get; set; }
+
+ ///
+ /// Influences how quickly the algorithm responds to feedback from the generated text.
+ /// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1)
+ ///
+ [JsonPropertyName("mirostat_eta")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? MiroStatEta { get; set; }
+
+ ///
+ /// Controls the balance between coherence and diversity of the output.
+ /// A lower value will result in more focused and coherent text. (Default: 5.0)
+ ///
+ [JsonPropertyName("mirostat_tau")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? MiroStatTau { get; set; }
+
+ ///
+ /// Sets the size of the context window used to generate the next token. (Default: 2048)
+ ///
+ [JsonPropertyName("num_ctx")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? NumCtx { get; set; }
+
+ ///
+ /// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b
+ ///
+ [JsonPropertyName("num_gqa")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? NumGqa { get; set; }
+
+ ///
+ /// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.
+ ///
+ [JsonPropertyName("num_gpu")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? NumGpu { get; set; }
+
+ ///
+ /// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance.
+ /// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).
+ ///
+ [JsonPropertyName("num_thread")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? NumThread { get; set; }
+
+ ///
+ /// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
+ ///
+ [JsonPropertyName("repeat_last_n")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? RepeatLastN { get; set; }
+
+ ///
+ /// Sets how strongly to penalize repetitions.
+ /// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
+ ///
+ [JsonPropertyName("repeat_penalty")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? RepeatPenalty { get; set; }
+
+ ///
+ /// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
+ ///
+ [JsonPropertyName("temperature")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? Temperature { get; set; }
+
+ ///
+ /// Sets the random number seed to use for generation.
+ /// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)
+ ///
+ [JsonPropertyName("seed")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? Seed { get; set; }
+
+ ///
+ /// Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.
+ /// Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
+ ///
+ [JsonPropertyName("stop")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public string? Stop { get; set; }
+
+ ///
+ /// Tail free sampling is used to reduce the impact of less probable tokens from the output.
+ /// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)
+ ///
+ [JsonPropertyName("tfs_z")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? TfsZ { get; set; }
+
+ ///
+ /// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
+ ///
+ [JsonPropertyName("num_predict")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? NumPredict { get; set; }
+
+ ///
+ /// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
+ ///
+ [JsonPropertyName("top_k")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public int? TopK { get; set; }
+
+ ///
+ /// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
+ ///
+ [JsonPropertyName("top_p")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public float? TopP { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/OllamaReplyOptions.cs b/dotnet/src/AutoGen.Ollama/DTOs/OllamaReplyOptions.cs
new file mode 100644
index 00000000000..c7c77d1db25
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/DTOs/OllamaReplyOptions.cs
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaReplyOptions.cs
+
+using AutoGen.Core;
+
+namespace AutoGen.Ollama;
+
+public enum FormatType
+{
+ None,
+ Json,
+}
+
+public class OllamaReplyOptions : GenerateReplyOptions
+{
+ ///
+ /// the format to return a response in. Currently, the only accepted value is json
+ ///
+ public FormatType Format { get; set; } = FormatType.None;
+
+ ///
+ /// the prompt template to use (overrides what is defined in the Modelfile)
+ ///
+ public string? Template { get; set; }
+
+ ///
+ /// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
+ ///
+ public new float? Temperature { get; set; }
+
+ ///
+ /// controls how long the model will stay loaded into memory following the request (default: 5m)
+ ///
+ public string? KeepAlive { get; set; }
+
+ ///
+ /// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+ ///
+ public int? MiroStat { get; set; }
+
+ ///
+ /// Influences how quickly the algorithm responds to feedback from the generated text.
+ /// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1)
+ ///
+ public float? MiroStatEta { get; set; }
+
+ ///
+ /// Controls the balance between coherence and diversity of the output.
+ /// A lower value will result in more focused and coherent text. (Default: 5.0)
+ ///
+ public float? MiroStatTau { get; set; }
+
+ ///
+ /// Sets the size of the context window used to generate the next token. (Default: 2048)
+ ///
+ public int? NumCtx { get; set; }
+
+ ///
+ /// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b
+ ///
+ public int? NumGqa { get; set; }
+
+ ///
+ /// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.
+ ///
+ public int? NumGpu { get; set; }
+
+ ///
+ /// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance.
+ /// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).
+ ///
+ public int? NumThread { get; set; }
+
+ ///
+ /// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
+ ///
+ public int? RepeatLastN { get; set; }
+
+ ///
+ /// Sets how strongly to penalize repetitions.
+ /// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
+ ///
+ public float? RepeatPenalty { get; set; }
+
+ ///
+ /// Sets the random number seed to use for generation.
+ /// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)
+ ///
+ public int? Seed { get; set; }
+
+ ///
+ /// Tail free sampling is used to reduce the impact of less probable tokens from the output.
+ /// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)
+ ///
+ public float? TfsZ { get; set; }
+
+ ///
+ /// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
+ ///
+ public new int? MaxToken { get; set; }
+
+ ///
+ /// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
+ ///
+ public int? TopK { get; set; }
+
+ ///
+ /// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
+ ///
+ public float? TopP { get; set; }
+}
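
Per-reply options are serialized into the `ModelReplyOptions` payload only when at least one tuning field is set (see `BuildChatRequestOptions` above), and `StopSequence` currently maps only its first element to the `stop` parameter. A sketch of typical values; all numbers are illustrative:

```csharp
using AutoGen.Ollama;

var options = new OllamaReplyOptions
{
    Format = FormatType.Json, // serialized as "format": "json"
    Temperature = 0.2f,
    MaxToken = 256,           // serialized as num_predict
    KeepAlive = "10m",
};
```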
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
new file mode 100644
index 00000000000..5ce0dc8cc40
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ITextEmbeddingService.cs
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Ollama;
+
+public interface ITextEmbeddingService
+{
+ public Task<TextEmbeddingsResponse> GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken);
+}
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
new file mode 100644
index 00000000000..2e431e7bcb8
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaTextEmbeddingService.cs
+
+using System;
+using System.IO;
+using System.Net.Http;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Ollama;
+
+public class OllamaTextEmbeddingService : ITextEmbeddingService
+{
+ private readonly HttpClient _client;
+
+ public OllamaTextEmbeddingService(HttpClient client)
+ {
+ _client = client;
+ }
+ public async Task<TextEmbeddingsResponse> GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken = default)
+ {
+ using (HttpResponseMessage? response = await _client
+ .SendAsync(BuildPostRequest(request), HttpCompletionOption.ResponseContentRead, cancellationToken))
+ {
+ response.EnsureSuccessStatusCode();
+
+ Stream? streamResponse = await response.Content.ReadAsStreamAsync();
+ TextEmbeddingsResponse output = await JsonSerializer
+ .DeserializeAsync<TextEmbeddingsResponse>(streamResponse, cancellationToken: cancellationToken)
+ ?? throw new Exception("Failed to deserialize response");
+ return output;
+ }
+ }
+ private static HttpRequestMessage BuildPostRequest(TextEmbeddingsRequest request)
+ {
+ string serialized = JsonSerializer.Serialize(request);
+ return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.EmbeddingsEndpoint)
+ {
+ Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
+ };
+ }
+}
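
Usage sketch for the embedding service, assuming a local Ollama server; `all-minilm` is an illustrative model tag. As with the chat agent, `BaseAddress` is required because the service posts to the relative `/api/embeddings` endpoint:

```csharp
using System;
using System.Net.Http;
using System.Threading.Tasks;
using AutoGen.Ollama;

public static class EmbeddingQuickStart
{
    public static async Task<int> EmbedAsync(string text)
    {
        using var client = new HttpClient { BaseAddress = new Uri("http://localhost:11434") };
        var service = new OllamaTextEmbeddingService(client);

        var response = await service.GenerateAsync(new TextEmbeddingsRequest
        {
            Model = "all-minilm",
            Prompt = text,
        });

        return response.Embedding?.Length ?? 0; // embedding dimensionality
    }
}
```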
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
new file mode 100644
index 00000000000..7f2531c522a
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// TextEmbeddingsRequest.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class TextEmbeddingsRequest
+{
+ ///
+ /// name of model to generate embeddings from
+ ///
+ [JsonPropertyName("model")]
+ public string Model { get; set; } = string.Empty;
+ ///
+ /// text to generate embeddings for
+ ///
+ [JsonPropertyName("prompt")]
+ public string Prompt { get; set; } = string.Empty;
+ ///
+ /// additional model parameters listed in the documentation for the Modelfile such as temperature
+ ///
+ [JsonPropertyName("options")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public ModelReplyOptions? Options { get; set; }
+ ///
+ /// controls how long the model will stay loaded into memory following the request (default: 5m)
+ ///
+ [JsonPropertyName("keep_alive")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public string? KeepAlive { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
new file mode 100644
index 00000000000..580059c033b
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// TextEmbeddingsResponse.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Ollama;
+
+public class TextEmbeddingsResponse
+{
+ [JsonPropertyName("embedding")]
+ public double[]? Embedding { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Ollama/Extension/OllamaAgentExtension.cs b/dotnet/src/AutoGen.Ollama/Extension/OllamaAgentExtension.cs
new file mode 100644
index 00000000000..4c0df513ef8
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Extension/OllamaAgentExtension.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaAgentExtension.cs
+
+using AutoGen.Core;
+
+namespace AutoGen.Ollama.Extension;
+
+public static class OllamaAgentExtension
+{
+ /// <summary>
+ /// Register an <see cref="OllamaMessageConnector"/> to the <see cref="OllamaAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="OllamaMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<OllamaAgent> RegisterMessageConnector(
+ this OllamaAgent agent, OllamaMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new OllamaMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+
+ /// <summary>
+ /// Register an <see cref="OllamaMessageConnector"/> to the <see cref="MiddlewareStreamingAgent{T}"/> where T is <see cref="OllamaAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="OllamaMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<MiddlewareStreamingAgent<OllamaAgent>> RegisterMessageConnector(
+ this MiddlewareStreamingAgent<OllamaAgent> agent, OllamaMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new OllamaMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+}
diff --git a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
new file mode 100644
index 00000000000..a21ec3a1c99
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
@@ -0,0 +1,186 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaMessageConnector.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Core;
+
+namespace AutoGen.Ollama;
+
+public class OllamaMessageConnector : IStreamingMiddleware
+{
+ public string Name => nameof(OllamaMessageConnector);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
+ CancellationToken cancellationToken = default)
+ {
+ var messages = ProcessMessage(context.Messages, agent);
+ IMessage reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken);
+
+ return reply switch
+ {
+ IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From),
+ IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"),
+ _ => reply
+ };
+ }
+
+ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var messages = ProcessMessage(context.Messages, agent);
+ var chunks = new List<ChatResponseUpdate>();
+ await foreach (var update in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
+ {
+ if (update is IStreamingMessage<ChatResponseUpdate> chatResponseUpdate)
+ {
+ var response = chatResponseUpdate.Content switch
+ {
+ _ when chatResponseUpdate.Content.Message?.Value is string content => new TextMessageUpdate(Role.Assistant, content, chatResponseUpdate.From),
+ _ => null,
+ };
+
+ if (response != null)
+ {
+ chunks.Add(chatResponseUpdate.Content);
+ yield return response;
+ }
+ }
+ else
+ {
+ yield return update;
+ }
+ }
+
+ if (chunks.Count == 0)
+ {
+ yield break;
+ }
+
+ // if the chunks are not empty, aggregate them into a single message
+ var messageContent = string.Join(string.Empty, chunks.Select(c => c.Message?.Value));
+ var message = new TextMessage(Role.Assistant, messageContent, agent.Name);
+
+ yield return message;
+ }
+
+ private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
+ {
+ return messages.SelectMany(m =>
+ {
+ if (m is IMessage<Message> messageEnvelope)
+ {
+ return [m];
+ }
+ else
+ {
+ return m switch
+ {
+ TextMessage textMessage => ProcessTextMessage(textMessage, agent),
+ ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
+ MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
+ _ => [m],
+ };
+ }
+ });
+ }
+
+ private IEnumerable<IMessage> ProcessMultiModalMessage(MultiModalMessage multiModalMessage, IAgent agent)
+ {
+ var textMessages = multiModalMessage.Content.Where(m => m is TextMessage textMessage && textMessage.GetContent() is not null);
+ var imageMessages = multiModalMessage.Content.Where(m => m is ImageMessage);
+
+ // aggregate the text messages into one message
+ // by concatenating the content using newline
+ var textContent = string.Join("\n", textMessages.Select(m => ((TextMessage)m).Content));
+
+ // collect all the images
+ var images = imageMessages.SelectMany(m => ProcessImageMessage((ImageMessage)m, agent)
+ .SelectMany(m => (m as IMessage<Message>)?.Content.Images));
+
+ var message = new Message()
+ {
+ Role = "user",
+ Value = textContent,
+ Images = images.ToList(),
+ };
+
+ return [MessageEnvelope.Create(message, agent.Name)];
+ }
+
+ private IEnumerable<IMessage> ProcessImageMessage(ImageMessage imageMessage, IAgent agent)
+ {
+ byte[]? data = imageMessage.Data?.ToArray();
+ if (data is null)
+ {
+ if (imageMessage.Url is null)
+ {
+ throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
+ }
+
+ var uri = new Uri(imageMessage.Url);
+ // download the image from the URL
+ using var client = new HttpClient();
+ var response = client.GetAsync(uri).Result;
+ if (!response.IsSuccessStatusCode)
+ {
+ throw new HttpRequestException($"Failed to download the image from {uri}");
+ }
+
+ data = response.Content.ReadAsByteArrayAsync().Result;
+ }
+
+ var base64Image = Convert.ToBase64String(data);
+ var message = imageMessage.From switch
+ {
+ null when imageMessage.Role == Role.User => new Message { Role = "user", Images = [base64Image] },
+ null => throw new InvalidOperationException("Invalid Role, the role must be user"),
+ _ when imageMessage.From != agent.Name => new Message { Role = "user", Images = [base64Image] },
+ _ => throw new InvalidOperationException("The from field must be null or the agent name"),
+ };
+
+ return [MessageEnvelope.Create(message, agent.Name)];
+ }
+
+ private IEnumerable<IMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
+ {
+ if (textMessage.Role == Role.System)
+ {
+ var message = new Message
+ {
+ Role = "system",
+ Value = textMessage.Content
+ };
+
+ return [MessageEnvelope.Create(message, agent.Name)];
+ }
+ else if (textMessage.From == agent.Name)
+ {
+ var message = new Message
+ {
+ Role = "assistant",
+ Value = textMessage.Content
+ };
+
+ return [MessageEnvelope.Create(message, agent.Name)];
+ }
+ else
+ {
+ var message = textMessage.From switch
+ {
+ null when textMessage.Role == Role.User => new Message { Role = "user", Value = textMessage.Content },
+ null when textMessage.Role == Role.Assistant => new Message { Role = "assistant", Value = textMessage.Content },
+ null => throw new InvalidOperationException("Invalid Role"),
+ _ when textMessage.From != agent.Name => new Message { Role = "user", Value = textMessage.Content },
+ _ => throw new InvalidOperationException("The from field must be null or the agent name"),
+ };
+
+ return [MessageEnvelope.Create(message, agent.Name)];
+ }
+ }
+}
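
Streaming through the connector yields `TextMessageUpdate` chunks and then one aggregated `TextMessage` assembled from the chunk list above. A consumer sketch, assuming an agent already wired with `RegisterMessageConnector`:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoGen.Core;

public static class StreamingDemo
{
    public static async Task PrintStreamAsync(IStreamingAgent agent, IEnumerable<IMessage> messages)
    {
        await foreach (var reply in agent.GenerateStreamingReplyAsync(messages))
        {
            switch (reply)
            {
                case TextMessageUpdate chunk:
                    Console.Write(chunk.Content); // incremental tokens
                    break;
                case TextMessage final:
                    Console.WriteLine($"\n[final] {final.Content}"); // aggregated reply
                    break;
            }
        }
    }
}
```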
diff --git a/dotnet/src/AutoGen.Ollama/OllamaConsts.cs b/dotnet/src/AutoGen.Ollama/OllamaConsts.cs
new file mode 100644
index 00000000000..f305446a9aa
--- /dev/null
+++ b/dotnet/src/AutoGen.Ollama/OllamaConsts.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaConsts.cs
+
+namespace AutoGen.Ollama;
+
+public class OllamaConsts
+{
+ public const string JsonFormatType = "json";
+ public const string JsonMediaType = "application/json";
+ public const string ChatCompletionEndpoint = "/api/chat";
+ public const string EmbeddingsEndpoint = "/api/embeddings";
+}
diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
index 52070788e34..cdc6cc464d1 100644
--- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
+++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
@@ -29,10 +29,8 @@ namespace AutoGen.OpenAI;
///
public class GPTAgent : IStreamingAgent
{
- private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
private readonly OpenAIClient openAIClient;
- private readonly string? modelName;
- private readonly OpenAIChatAgent _innerAgent;
+ private readonly IStreamingAgent _innerAgent;
public GPTAgent(
string name,
@@ -52,16 +50,23 @@ public GPTAgent(
_ => throw new ArgumentException($"Unsupported config type {config.GetType()}"),
};
- modelName = config switch
+ var modelName = config switch
{
AzureOpenAIConfig azureConfig => azureConfig.DeploymentName,
OpenAIConfig openAIConfig => openAIConfig.ModelId,
_ => throw new ArgumentException($"Unsupported config type {config.GetType()}"),
};
- _innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions);
+ _innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions)
+ .RegisterMessageConnector();
+
+ if (functionMap is not null)
+ {
+ var functionMapMiddleware = new FunctionCallMiddleware(functionMap: functionMap);
+ _innerAgent = _innerAgent.RegisterStreamingMiddleware(functionMapMiddleware);
+ }
+
Name = name;
- this.functionMap = functionMap;
}
public GPTAgent(
@@ -77,10 +82,16 @@ public GPTAgent(
IDictionary<string, Func<string, Task<string>>>? functionMap = null)
{
this.openAIClient = openAIClient;
- this.modelName = modelName;
Name = name;
- this.functionMap = functionMap;
- _innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions);
+
+ _innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions)
+ .RegisterMessageConnector();
+
+ if (functionMap is not null)
+ {
+ var functionMapMiddleware = new FunctionCallMiddleware(functionMap: functionMap);
+ _innerAgent = _innerAgent.RegisterStreamingMiddleware(functionMapMiddleware);
+ }
}
public string Name { get; }
@@ -90,14 +101,7 @@ public async Task GenerateReplyAsync(
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
- var agent = this._innerAgent.RegisterMessageConnector();
- if (this.functionMap is not null)
- {
- var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
- agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
- }
-
- return await agent.GenerateReplyAsync(messages, options, cancellationToken);
+ return await _innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
@@ -105,13 +109,6 @@ public IAsyncEnumerable GenerateStreamingReplyAsync(
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
- var agent = this._innerAgent.RegisterMessageConnector();
- if (this.functionMap is not null)
- {
- var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
- agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
- }
-
- return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
+ return _innerAgent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
}
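
The connector and optional function-call middleware are now attached once in the constructor, so repeated `GenerateReplyAsync` calls reuse the same pipeline instead of rebuilding it per call. A construction sketch; the `OpenAIConfig(apiKey, modelId)` shape, model name, and `GetWeather` tool are assumptions for illustration:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoGen;
using AutoGen.OpenAI;

public static class GptAgentSetup
{
    public static GPTAgent Create()
    {
        var config = new OpenAIConfig(Environment.GetEnvironmentVariable("OPENAI_API_KEY")!, "gpt-4");
        return new GPTAgent(
            name: "assistant",
            systemMessage: "You are a helpful AI assistant",
            config: config,
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                ["GetWeather"] = _ => Task.FromResult("72F and sunny"), // hypothetical tool
            });
    }
}
```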
diff --git a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs
index b3dfb1e8668..ed795e5e8ed 100644
--- a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs
+++ b/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs
@@ -12,6 +12,8 @@ public static class MessageExtension
{
public static string TEXT_CONTENT_TYPE = "text";
public static string IMAGE_CONTENT_TYPE = "image";
+
+ [Obsolete("This method is deprecated, please replace Message with one of the built-in message types.")]
public static ChatRequestUserMessage ToChatRequestUserMessage(this Message message)
{
if (message.Value is ChatRequestUserMessage message1)
@@ -50,6 +52,7 @@ public static ChatRequestUserMessage ToChatRequestUserMessage(this Message messa
throw new ArgumentException("Content is null and metadata is null");
}
+ [Obsolete("This method is deprecated")]
public static IEnumerable<ChatRequestMessage> ToOpenAIChatRequestMessage(this IAgent agent, IMessage message)
{
if (message is IMessage<ChatRequestMessage> oaiMessage)
diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
index 2bd9470ffa7..8f1825e2fa0 100644
--- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
+++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
@@ -19,7 +19,6 @@ namespace AutoGen.OpenAI;
/// -
/// -
/// -
-/// -
/// - <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/>
/// - <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>
///
@@ -27,6 +26,11 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
{
private bool strictMode = false;
+ /// <summary>
+ /// Create a new instance of <see cref="OpenAIChatRequestMessageConnector"/>.
+ /// </summary>
+ /// <param name="strictMode">If true, will throw an <see cref="InvalidOperationException"/>
+ /// when the message type is not supported. If false, it will ignore the unsupported message type.</param>
public OpenAIChatRequestMessageConnector(bool strictMode = false)
{
this.strictMode = strictMode;
@@ -36,8 +40,7 @@ public OpenAIChatRequestMessageConnector(bool strictMode = false)
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
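+ // Convert incoming AutoGen messages into ChatRequestMessage envelopes before invoking the inner agent.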
- var chatMessages = ProcessIncomingMessages(agent, context.Messages)
- .Select(m => new MessageEnvelope<ChatRequestMessage>(m));
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
@@ -49,8 +52,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var chatMessages = ProcessIncomingMessages(agent, context.Messages)
- .Select(m => new MessageEnvelope<ChatRequestMessage>(m));
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
string? currentToolName = null;
await foreach (var reply in streamingReply)
@@ -73,7 +75,14 @@ public async IAsyncEnumerable InvokeAsync(
}
else
{
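+ // Streaming updates of an unrecognized type are rejected in strict mode and passed through unchanged otherwise.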
- yield return reply;
+ if (this.strictMode)
+ {
+ throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}");
+ }
+ else
+ {
+ yield return reply;
+ }
}
}
}
@@ -82,16 +91,10 @@ public IMessage PostProcessMessage(IMessage message)
{
return message switch
{
- TextMessage => message,
- ImageMessage => message,
- MultiModalMessage => message,
- ToolCallMessage => message,
- ToolCallResultMessage => message,
- Message => message,
- AggregateMessage<ToolCallMessage, ToolCallResultMessage> => message,
- IMessage<ChatResponseMessage> m => PostProcessMessage(m),
- IMessage<ChatCompletions> m => PostProcessMessage(m),
- _ => throw new InvalidOperationException("The type of message is not supported. Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage, AggregateMessage"),
+ IMessage<ChatResponseMessage> m => PostProcessChatResponseMessage(m.Content, m.From),
+ IMessage<ChatCompletions> m => PostProcessChatCompletions(m),
+ _ when strictMode is false => message,
+ _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"),
};
}
@@ -120,12 +123,7 @@ public IMessage PostProcessMessage(IMessage message)
}
}
- private IMessage PostProcessMessage(IMessage<ChatResponseMessage> message)
- {
- return PostProcessMessage(message.Content, message.From);
- }
-
- private IMessage PostProcessMessage(IMessage<ChatCompletions> message)
+ private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)
{
// throw exception if prompt filter results is not null
if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
@@ -133,12 +131,12 @@ private IMessage PostProcessMessage(IMessage message)
throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
}
- return PostProcessMessage(message.Content.Choices[0].Message, message.From);
+ return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From);
}
- private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, string? from)
+ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
{
- if (chatResponseMessage.Content is string content)
+ if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content))
{
return new TextMessage(Role.Assistant, content, from);
}
@@ -154,7 +152,7 @@ private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, str
.Where(tc => tc is ChatCompletionsFunctionToolCall)
.Select(tc => (ChatCompletionsFunctionToolCall)tc);
- var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
+ var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
return new ToolCallMessage(toolCalls, from);
}
@@ -162,112 +160,44 @@ private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, str
throw new InvalidOperationException("Invalid ChatResponseMessage");
}
- public IEnumerable<ChatRequestMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
+ public IEnumerable<IMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
{
- return messages.SelectMany(m =>
+ return messages.SelectMany<IMessage, IMessage>(m =>
{
- if (m.From == null)
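+ // Messages that already wrap a ChatRequestMessage pass through untouched.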
+ if (m is IMessage<ChatRequestMessage> crm)
{
- return ProcessIncomingMessagesWithEmptyFrom(m);
- }
- else if (m.From == agent.Name)
- {
- return ProcessIncomingMessagesForSelf(m);
+ return [crm];
}
else
{
- return ProcessIncomingMessagesForOther(m);
+ var chatRequestMessages = m switch
+ {
+ TextMessage textMessage => ProcessTextMessage(agent, textMessage),
+ ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage),
+ MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage),
+ ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage),
+#pragma warning disable CS0618 // deprecated
+ Message msg => ProcessMessage(agent, msg),
+#pragma warning restore CS0618 // deprecated
+ _ when strictMode is false => [],
+ _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"),
+ };
+
+ if (chatRequestMessages.Any())
+ {
+ return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From));
+ }
+ else
+ {
+ return [m];
+ }
}
});
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(IMessage message)
- {
- return message switch
- {
- TextMessage textMessage => ProcessIncomingMessagesForSelf(textMessage),
- ImageMessage imageMessage => ProcessIncomingMessagesForSelf(imageMessage),
- MultiModalMessage multiModalMessage => ProcessIncomingMessagesForSelf(multiModalMessage),
- ToolCallMessage toolCallMessage => ProcessIncomingMessagesForSelf(toolCallMessage),
- ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForSelf(toolCallResultMessage),
- Message msg => ProcessIncomingMessagesForSelf(msg),
- IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesForSelf(crm),
- AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesForSelf(aggregateMessage),
- _ => throw new NotImplementedException(),
- };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(IMessage message)
- {
- return message switch
- {
- TextMessage textMessage => ProcessIncomingMessagesWithEmptyFrom(textMessage),
- ImageMessage imageMessage => ProcessIncomingMessagesWithEmptyFrom(imageMessage),
- MultiModalMessage multiModalMessage => ProcessIncomingMessagesWithEmptyFrom(multiModalMessage),
- ToolCallMessage toolCallMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallMessage),
- ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallResultMessage),
- Message msg => ProcessIncomingMessagesWithEmptyFrom(msg),
- IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesWithEmptyFrom(crm),
- AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesWithEmptyFrom(aggregateMessage),
- _ => throw new NotImplementedException(),
- };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(IMessage message)
- {
- return message switch
- {
- TextMessage textMessage => ProcessIncomingMessagesForOther(textMessage),
- ImageMessage imageMessage => ProcessIncomingMessagesForOther(imageMessage),
- MultiModalMessage multiModalMessage => ProcessIncomingMessagesForOther(multiModalMessage),
- ToolCallMessage toolCallMessage => ProcessIncomingMessagesForOther(toolCallMessage),
- ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForOther(toolCallResultMessage),
- Message msg => ProcessIncomingMessagesForOther(msg),
- IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesForOther(crm),
- AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesForOther(aggregateMessage),
- _ => throw new NotImplementedException(),
- };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(TextMessage message)
- {
- if (message.Role == Role.System)
- {
- return new[] { new ChatRequestSystemMessage(message.Content) };
- }
- else
- {
- return new[] { new ChatRequestAssistantMessage(message.Content) };
- }
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ImageMessage _)
- {
- return [new ChatRequestAssistantMessage("// Image Message is not supported")];
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(MultiModalMessage _)
- {
- return [new ChatRequestAssistantMessage("// MultiModal Message is not supported")];
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ToolCallMessage message)
- {
- var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
- var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty);
- foreach (var tc in toolCall)
- {
- chatRequestMessage.ToolCalls.Add(tc);
- }
-
- return new[] { chatRequestMessage };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ToolCallResultMessage message)
- {
- return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
- }
-
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
private IEnumerable ProcessIncomingMessagesForSelf(Message message)
{
if (message.Role == Role.System)
@@ -303,151 +233,147 @@ private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(Message m
}
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(IMessage<ChatRequestMessage> message)
- {
- return new[] { message.Content };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages) instead.")]
+ private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(Message message)
{
- var toolCallMessage1 = aggregateMessage.Message1;
- var toolCallResultMessage = aggregateMessage.Message2;
-
- var assistantMessage = new ChatRequestAssistantMessage(string.Empty);
- var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
- foreach (var tc in toolCalls)
+ if (message.Role == Role.System)
{
- assistantMessage.ToolCalls.Add(tc);
+ return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
}
+ else if (message.Content is string content && content is { Length: > 0 })
+ {
+ if (message.FunctionName is not null)
+ {
+ return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
+ }
- var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
-
- // return assistantMessage and tool call result messages
- var messages = new List<ChatRequestMessage> { assistantMessage };
- messages.AddRange(toolCallResults);
-
- return messages;
+ return [new ChatRequestUserMessage(message.Content) { Name = message.From }];
+ }
+ else if (message.FunctionName is string _)
+ {
+ return [new ChatRequestUserMessage("// Message type is not supported") { Name = message.From }];
+ }
+ else
+ {
+ throw new InvalidOperationException("Invalid Message as message from other.");
+ }
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(TextMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessTextMessage(IAgent agent, TextMessage message)
{
if (message.Role == Role.System)
{
- return new[] { new ChatRequestSystemMessage(message.Content) };
+ return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
+ }
+
+ if (agent.Name == message.From)
+ {
+ return [new ChatRequestAssistantMessage(message.Content) { Name = agent.Name }];
}
else
{
- return new[] { new ChatRequestUserMessage(message.Content) };
+ return message.From switch
+ {
+ null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)],
+ null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage(message.Content)],
+ null => throw new InvalidOperationException("Invalid Role"),
+ _ => [new ChatRequestUserMessage(message.Content) { Name = message.From }]
+ };
}
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ImageMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessImageMessage(IAgent agent, ImageMessage message)
{
- return new[] { new ChatRequestUserMessage([
- new ChatMessageImageContentItem(new Uri(message.Url ?? message.BuildDataUri())),
- ])};
+ if (agent.Name == message.From)
+ {
+ // image message from assistant is not supported
+ throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent");
+ }
+
+ var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message);
+ return [new ChatRequestUserMessage([imageContentItem]) { Name = message.From }];
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(MultiModalMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessMultiModalMessage(IAgent agent, MultiModalMessage message)
{
+ if (agent.Name == message.From)
+ {
+ // image message from assistant is not supported
+ throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent");
+ }
+
IEnumerable items = message.Content.Select(ci => ci switch
{
TextMessage text => new ChatMessageTextContentItem(text.Content),
- ImageMessage image => new ChatMessageImageContentItem(new Uri(image.Url ?? image.BuildDataUri())),
+ ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image),
_ => throw new NotImplementedException(),
});
- return new[] { new ChatRequestUserMessage(items) };
+ return [new ChatRequestUserMessage(items) { Name = message.From }];
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ToolCallMessage msg)
+ private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message)
{
- throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
+ return message.Data is null
+ ? new ChatMessageImageContentItem(new Uri(message.Url))
+ : new ChatMessageImageContentItem(message.Data, message.Data.MediaType);
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ToolCallResultMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, ToolCallMessage message)
{
- return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(Message message)
- {
- if (message.Role == Role.System)
+ if (message.From is not null && message.From != agent.Name)
{
- return new[] { new ChatRequestSystemMessage(message.Content) };
+ throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}
- else if (message.Content is string content && content is { Length: > 0 })
- {
- if (message.FunctionName is not null)
- {
- return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
- }
- return new[] { new ChatRequestUserMessage(message.Content) };
- }
- else if (message.FunctionName is string _)
- {
- return new[]
- {
- new ChatRequestUserMessage("// Message type is not supported"),
- };
- }
- else
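+ // OpenAI requires each tool call to carry an id; fall back to a deterministic "{FunctionName}_{index}" when none was recorded, mirroring ProcessToolCallResultMessage.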
+ var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
+ var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
+ foreach (var tc in toolCall)
{
- throw new InvalidOperationException("Invalid Message as message from other.");
+ chatRequestMessage.ToolCalls.Add(tc);
}
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(IMessage<ChatRequestMessage> message)
- {
- return new[] { message.Content };
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
- {
- // convert as user message
- var resultMessage = aggregateMessage.Message2;
-
- return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result));
- }
-
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(TextMessage message)
- {
- return ProcessIncomingMessagesForOther(message);
- }
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ImageMessage message)
- {
- return ProcessIncomingMessagesForOther(message);
+ return [chatRequestMessage];
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(MultiModalMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallResultMessage message)
{
- return ProcessIncomingMessagesForOther(message);
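+ // Only tool calls that actually produced a result are converted; ids use the same "{FunctionName}_{index}" fallback as ProcessToolCallMessage.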
+ return message.ToolCalls
+ .Where(tc => tc.Result is not null)
+ .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ToolCallMessage message)
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages) instead.")]
+ private IEnumerable<ChatRequestMessage> ProcessMessage(IAgent agent, Message message)
{
- return ProcessIncomingMessagesForSelf(message);
+ if (message.From is not null && message.From != agent.Name)
+ {
+ return ProcessIncomingMessagesForOther(message);
+ }
+ else
+ {
+ return ProcessIncomingMessagesForSelf(message);
+ }
}
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ToolCallResultMessage message)
+ private IEnumerable<ChatRequestMessage> ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
- return ProcessIncomingMessagesForOther(message);
- }
+ if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name)
+ {
+ // convert as user message
+ var resultMessage = aggregateMessage.Message2;
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(Message message)
- {
- return ProcessIncomingMessagesForOther(message);
- }
+ return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result) { Name = aggregateMessage.From });
+ }
+ else
+ {
+ var toolCallMessage1 = aggregateMessage.Message1;
+ var toolCallResultMessage = aggregateMessage.Message2;
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(IMessage<ChatRequestMessage> message)
- {
- return new[] { message.Content };
- }
+ var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1);
+ var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage);
- private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
- {
- return ProcessIncomingMessagesForOther(aggregateMessage);
+ return assistantMessage.Concat(toolCallResults);
+ }
}
}
diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
index 6a8395ef22e..6ce242eb1ab 100644
--- a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
@@ -133,7 +133,9 @@ private IEnumerable<ChatMessageContent> ProcessMessageForSelf(IMessage message)
{
TextMessage textMessage => ProcessMessageForSelf(textMessage),
MultiModalMessage multiModalMessage => ProcessMessageForSelf(multiModalMessage),
+#pragma warning disable CS0618 // deprecated
Message m => ProcessMessageForSelf(m),
+#pragma warning restore CS0618 // deprecated
_ => throw new System.NotImplementedException(),
};
}
@@ -145,7 +147,9 @@ private IEnumerable<ChatMessageContent> ProcessMessageForOthers(IMessage message
TextMessage textMessage => ProcessMessageForOthers(textMessage),
MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage),
ImageMessage imageMessage => ProcessMessageForOthers(imageMessage),
+#pragma warning disable CS0618 // deprecated
Message m => ProcessMessageForOthers(m),
+#pragma warning restore CS0618 // deprecated
_ => throw new InvalidOperationException("unsupported message type, only support TextMessage, ImageMessage, MultiModalMessage and Message."),
};
}
@@ -208,7 +212,7 @@ private IEnumerable<ChatMessageContent> ProcessMessageForOthers(MultiModalMessag
return [new ChatMessageContent(AuthorRole.User, collections)];
}
-
+ [Obsolete("This method is deprecated, please use the specific method instead.")]
private IEnumerable<ChatMessageContent> ProcessMessageForSelf(Message message)
{
if (message.Role == Role.System)
@@ -229,6 +233,7 @@ private IEnumerable<ChatMessageContent> ProcessMessageForSelf(Message message)
}
}
+ [Obsolete("This method is deprecated, please use the specific method instead.")]
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(Message message)
{
if (message.Role == Role.System)
diff --git a/dotnet/src/AutoGen.SourceGenerator/AutoGen.SourceGenerator.csproj b/dotnet/src/AutoGen.SourceGenerator/AutoGen.SourceGenerator.csproj
index 4558160722d..37f344ed11e 100644
--- a/dotnet/src/AutoGen.SourceGenerator/AutoGen.SourceGenerator.csproj
+++ b/dotnet/src/AutoGen.SourceGenerator/AutoGen.SourceGenerator.csproj
@@ -14,7 +14,7 @@
-
+
AutoGen.SourceGenerator
@@ -50,6 +50,10 @@
+
+
+
+
True
diff --git a/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs b/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs
index 50bdc03f0af..cd01416182b 100644
--- a/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs
@@ -144,7 +144,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
private class PartialClassOutput
{
- public PartialClassOutput(string fullClassName, ClassDeclarationSyntax classDeclarationSyntax, IEnumerable<FunctionContract> functionContracts)
+ public PartialClassOutput(string fullClassName, ClassDeclarationSyntax classDeclarationSyntax, IEnumerable<SourceGeneratorFunctionContract> functionContracts)
{
FullClassName = fullClassName;
ClassDeclarationSyntax = classDeclarationSyntax;
@@ -155,10 +155,10 @@ public PartialClassOutput(string fullClassName, ClassDeclarationSyntax classDecl
public ClassDeclarationSyntax ClassDeclarationSyntax { get; }
- public IEnumerable<FunctionContract> FunctionContracts { get; }
+ public IEnumerable<SourceGeneratorFunctionContract> FunctionContracts { get; }
}
- private FunctionContract CreateFunctionContract(MethodDeclarationSyntax method, string? className, string? namespaceName)
+ private SourceGeneratorFunctionContract CreateFunctionContract(MethodDeclarationSyntax method, string? className, string? namespaceName)
{
// get function_call attribute
var functionCallAttribute = method.AttributeLists.SelectMany(attributeList => attributeList.Attributes)
@@ -208,7 +208,7 @@ private FunctionContract CreateFunctionContract(MethodDeclarationSyntax method,
description = System.Text.RegularExpressions.Regex.Replace(description, @"[^\S\r\n]+\/[\/]+\s*", string.Empty);
}
var jsonItemType = parameter.Type!.ToString().EndsWith("[]") ? parameter.Type!.ToString().Substring(0, parameter.Type!.ToString().Length - 2) : null;
- return new ParameterContract
+ return new SourceGeneratorParameterContract
{
Name = parameter.Identifier.ToString(),
JsonType = parameter.Type!.ToString() switch
@@ -234,7 +234,7 @@ private FunctionContract CreateFunctionContract(MethodDeclarationSyntax method,
};
});
- return new FunctionContract
+ return new SourceGeneratorFunctionContract
{
ClassName = className,
Namespace = namespaceName,
diff --git a/dotnet/src/AutoGen.SourceGenerator/FunctionExtension.cs b/dotnet/src/AutoGen.SourceGenerator/FunctionExtension.cs
index a56e4cb54f4..cfb77d26a2b 100644
--- a/dotnet/src/AutoGen.SourceGenerator/FunctionExtension.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/FunctionExtension.cs
@@ -5,27 +5,27 @@
internal static class FunctionExtension
{
- public static string GetFunctionName(this FunctionContract function)
+ public static string GetFunctionName(this SourceGeneratorFunctionContract function)
{
return function.Name ?? string.Empty;
}
- public static string GetFunctionSchemaClassName(this FunctionContract function)
+ public static string GetFunctionSchemaClassName(this SourceGeneratorFunctionContract function)
{
return $"{function.GetFunctionName()}Schema";
}
- public static string GetFunctionDefinitionName(this FunctionContract function)
+ public static string GetFunctionDefinitionName(this SourceGeneratorFunctionContract function)
{
return $"{function.GetFunctionName()}Function";
}
- public static string GetFunctionWrapperName(this FunctionContract function)
+ public static string GetFunctionWrapperName(this SourceGeneratorFunctionContract function)
{
return $"{function.GetFunctionName()}Wrapper";
}
- public static string GetFunctionContractName(this FunctionContract function)
+ public static string GetFunctionContractName(this SourceGeneratorFunctionContract function)
{
return $"{function.GetFunctionName()}FunctionContract";
}
diff --git a/dotnet/src/AutoGen.SourceGenerator/FunctionContract.cs b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
similarity index 81%
rename from dotnet/src/AutoGen.SourceGenerator/FunctionContract.cs
rename to dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
index 2f26352173d..24e42affa3b 100644
--- a/dotnet/src/AutoGen.SourceGenerator/FunctionContract.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
@@ -3,7 +3,7 @@
namespace AutoGen.SourceGenerator
{
- internal class FunctionContract
+ internal class SourceGeneratorFunctionContract
{
public string? Namespace { get; set; }
@@ -15,12 +15,12 @@ internal class FunctionContract
public string? ReturnDescription { get; set; }
- public ParameterContract[]? Parameters { get; set; }
+ public SourceGeneratorParameterContract[]? Parameters { get; set; }
public string? ReturnType { get; set; }
}
- internal class ParameterContract
+ internal class SourceGeneratorParameterContract
{
public string? Name { get; set; }
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
index 1d455bd3041..e56db112eb7 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
@@ -31,7 +31,6 @@ public virtual string TransformText()
// This code was generated by a tool.
//
//----------------------
-using Azure.AI.OpenAI;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
@@ -152,7 +151,8 @@ public virtual string TransformText()
}
this.Write(" },\r\n");
}
- this.Write(" };\r\n }\r\n\r\n public Azure.AI.OpenAI.FunctionDefinition ");
+ this.Write(" };\r\n }\r\n\r\n public global::Azure.AI.OpenAI.FunctionDefin" +
+ "ition ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionDefinitionName()));
this.Write("\r\n {\r\n get => this.");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionContractName()));
@@ -168,7 +168,7 @@ public virtual string TransformText()
public string NameSpace {get; set;}
public string ClassName {get; set;}
-public IEnumerable<FunctionContract> FunctionContracts {get; set;}
+public IEnumerable<SourceGeneratorFunctionContract> FunctionContracts {get; set;}
public bool IsStatic {get; set;} = false;
}
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
index baa2a680fe2..526dfe400ce 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
@@ -8,7 +8,6 @@
// This code was generated by a tool.
//
//----------------------
-using Azure.AI.OpenAI;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
@@ -98,7 +97,7 @@ namespace <#=NameSpace#>
};
}
- public Azure.AI.OpenAI.FunctionDefinition <#=functionContract.GetFunctionDefinitionName()#>
+ public global::Azure.AI.OpenAI.FunctionDefinition <#=functionContract.GetFunctionDefinitionName()#>
{
get => this.<#=functionContract.GetFunctionContractName()#>.ToOpenAIFunctionDefinition();
}
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj
new file mode 100644
index 00000000000..cf2c24eaf78
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj
@@ -0,0 +1,24 @@
+
+
+
+ $(TestTargetFramework)
+ enable
+ false
+ True
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
new file mode 100644
index 00000000000..0e36053c45e
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// DotnetInteractiveServiceTest.cs
+
+using FluentAssertions;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace AutoGen.DotnetInteractive.Tests;
+
+public class DotnetInteractiveServiceTest : IDisposable
+{
+ private ITestOutputHelper _output;
+ private InteractiveService _interactiveService;
+ private string _workingDir;
+
+ public DotnetInteractiveServiceTest(ITestOutputHelper output)
+ {
+ _output = output;
+ _workingDir = Path.Combine(Path.GetTempPath(), "test", Path.GetRandomFileName());
+ if (!Directory.Exists(_workingDir))
+ {
+ Directory.CreateDirectory(_workingDir);
+ }
+
+ _interactiveService = new InteractiveService(_workingDir);
+ _interactiveService.StartAsync(_workingDir, default).Wait();
+ }
+
+ public void Dispose()
+ {
+ _interactiveService.Dispose();
+ }
+
+ [Fact]
+ public async Task ItRunCSharpCodeSnippetTestsAsync()
+ {
+ var cts = new CancellationTokenSource();
+ var isRunning = await _interactiveService.StartAsync(_workingDir, cts.Token);
+
+ isRunning.Should().BeTrue();
+
+ _interactiveService.IsRunning().Should().BeTrue();
+
+ // test code snippet
+ var hello_world = @"
+Console.WriteLine(""hello world"");
+";
+
+ await this.TestCSharpCodeSnippet(_interactiveService, hello_world, "hello world");
+ await this.TestCSharpCodeSnippet(
+ _interactiveService,
+ code: @"
+Console.WriteLine(""hello world""
+",
+ expectedOutput: "Error: (2,32): error CS1026: ) expected");
+
+ await this.TestCSharpCodeSnippet(
+ service: _interactiveService,
+ code: "throw new Exception();",
+ expectedOutput: "Error: System.Exception: Exception of type 'System.Exception' was thrown");
+ }
+
+ [Fact]
+ public async Task ItRunPowershellScriptTestsAsync()
+ {
+ // test power shell
+ var ps = @"Write-Output ""hello world""";
+ await this.TestPowershellCodeSnippet(_interactiveService, ps, "hello world");
+ }
+
+ private async Task TestPowershellCodeSnippet(InteractiveService service, string code, string expectedOutput)
+ {
+ var result = await service.SubmitPowershellCodeAsync(code, CancellationToken.None);
+ result.Should().StartWith(expectedOutput);
+ }
+
+ private async Task TestCSharpCodeSnippet(InteractiveService service, string code, string expectedOutput)
+ {
+ var result = await service.SubmitCSharpCodeAsync(code, CancellationToken.None);
+ result.Should().StartWith(expectedOutput);
+ }
+}
diff --git a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs
index 2b6839dd0ef..3aa61a7a71d 100644
--- a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs
+++ b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs
@@ -87,11 +87,15 @@ public async Task MistralAgentFunctionCallMessageTest()
}
""";
var functionCallResult = await this.GetWeatherWrapper(weatherFunctionArgumets);
-
+ var toolCall = new ToolCall(this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets)
+ {
+ ToolCallId = "012345678", // Mistral AI requires the tool call id to be a length of 9
+ Result = functionCallResult,
+ };
IMessage[] chatHistory = [
new TextMessage(Role.User, "what's the weather in Seattle?"),
- new ToolCallMessage(this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets, from: agent.Name),
- new ToolCallResultMessage(functionCallResult, this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets),
+ new ToolCallMessage([toolCall], from: agent.Name),
+ new ToolCallResultMessage([toolCall], weatherFunctionArgumets),
];
var reply = await agent.SendAsync(chatHistory: chatHistory);
@@ -152,7 +156,7 @@ public async Task MistralAgentFunctionCallMiddlewareMessageTest()
var question = new TextMessage(Role.User, "what's the weather in Seattle?");
var reply = await functionCallAgent.SendAsync(question);
- reply.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
// resend the reply to the same agent so it can generate the final response
// because the reply's from is the agent's name
diff --git a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj
new file mode 100644
index 00000000000..27f80716f1c
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj
@@ -0,0 +1,33 @@
+
+
+
+ $(TestTargetFramework)
+ enable
+ false
+ True
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+ PreserveNewest
+
+
+
+
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
new file mode 100644
index 00000000000..c1fb466f0b0
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
@@ -0,0 +1,224 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaAgentTests.cs
+
+using System.Text.Json;
+using AutoGen.Core;
+using AutoGen.Ollama.Extension;
+using AutoGen.Tests;
+using FluentAssertions;
+
+namespace AutoGen.Ollama.Tests;
+
+public class OllamaAgentTests
+{
+ [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
+ public async Task GenerateReplyAsync_ReturnsValidMessage_WhenCalled()
+ {
+ string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
+ ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
+ OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
+
+ var message = new Message("user", "hey how are you");
+ var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) };
+ IMessage result = await ollamaAgent.GenerateReplyAsync(messages);
+
+ result.Should().NotBeNull();
+ result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
+ result.From.Should().Be(ollamaAgent.Name);
+ }
+
+ [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
+ public async Task GenerateReplyAsync_ReturnsValidJsonMessageContent_WhenCalled()
+ {
+ string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
+ ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
+ OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
+
+ var message = new Message("user", "What color is the sky at different times of the day? Respond using JSON");
+ var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) };
+ IMessage result = await ollamaAgent.GenerateReplyAsync(messages, new OllamaReplyOptions
+ {
+ Format = FormatType.Json
+ });
+
+ result.Should().NotBeNull();
+ result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
+ result.From.Should().Be(ollamaAgent.Name);
+
+ string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value;
+ bool isValidJson = IsValidJsonMessage(jsonContent);
+ isValidJson.Should().BeTrue();
+ }
+
+ [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
+ public async Task GenerateStreamingReplyAsync_ReturnsValidMessages_WhenCalled()
+ {
+ string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
+ ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
+ OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
+
+ var msg = new Message("user", "hey how are you");
+ var messages = new IMessage[] { MessageEnvelope.Create(msg, from: modelName) };
+ IStreamingMessage? finalReply = default;
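+ // Drain the stream until the final update (Done == true) arrives.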
+ await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ {
+ message.Should().NotBeNull();
+ message.From.Should().Be(ollamaAgent.Name);
+ var streamingMessage = (IMessage<ChatResponseUpdate>)message;
+ if (streamingMessage.Content.Done)
+ {
+ finalReply = message;
+ break;
+ }
+ else
+ {
+ streamingMessage.Content.Message.Should().NotBeNull();
+ streamingMessage.Content.Done.Should().BeFalse();
+ }
+ }
+
+ finalReply.Should().BeOfType<MessageEnvelope<ChatResponseUpdate>>();
+ var update = ((MessageEnvelope<ChatResponseUpdate>)finalReply!).Content;
+ update.Done.Should().BeTrue();
+ update.TotalDuration.Should().BeGreaterThan(0);
+ }
+
+ [ApiKeyFact("OLLAMA_HOST")]
+ public async Task ItReturnValidMessageUsingLLavaAsync()
+ {
+ var host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ var modelName = "llava:latest";
+ var ollamaAgent = BuildOllamaAgent(host, modelName);
+ var imagePath = Path.Combine("images", "image.png");
+ var base64Image = Convert.ToBase64String(File.ReadAllBytes(imagePath));
+ var message = new Message()
+ {
+ Role = "user",
+ Value = "What's the color of the background in this image",
+ Images = [base64Image],
+ };
+
+ var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) };
+ var reply = await ollamaAgent.GenerateReplyAsync(messages);
+
+ reply.Should().BeOfType<MessageEnvelope<ChatResponse>>();
+ var chatResponse = ((MessageEnvelope<ChatResponse>)reply).Content;
+ chatResponse.Message.Should().NotBeNull();
+ }
+
+ [ApiKeyFact("OLLAMA_HOST")]
+ public async Task ItCanProcessMultiModalMessageUsingLLavaAsync()
+ {
+ var host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ var modelName = "llava:latest";
+ var ollamaAgent = BuildOllamaAgent(host, modelName)
+ .RegisterMessageConnector();
+ var image = Path.Combine("images", "image.png");
+ var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
+ var imageMessage = new ImageMessage(Role.User, binaryData);
+ var textMessage = new TextMessage(Role.User, "What's in this image?");
+ var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
+
+ var reply = await ollamaAgent.SendAsync(multiModalMessage);
+ reply.Should().BeOfType();
+ reply.GetRole().Should().Be(Role.Assistant);
+ reply.GetContent().Should().NotBeNullOrEmpty();
+ reply.From.Should().Be(ollamaAgent.Name);
+ }
+
+ [ApiKeyFact("OLLAMA_HOST")]
+ public async Task ItCanProcessImageMessageUsingLLavaAsync()
+ {
+ var host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ var modelName = "llava:latest";
+ var ollamaAgent = BuildOllamaAgent(host, modelName)
+ .RegisterMessageConnector();
+ var image = Path.Combine("images", "image.png");
+ var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
+ var imageMessage = new ImageMessage(Role.User, binaryData);
+
+ var reply = await ollamaAgent.SendAsync(imageMessage);
+ reply.Should().BeOfType();
+ reply.GetRole().Should().Be(Role.Assistant);
+ reply.GetContent().Should().NotBeNullOrEmpty();
+ reply.From.Should().Be(ollamaAgent.Name);
+ }
+
+ [ApiKeyFact("OLLAMA_HOST")]
+ public async Task ItReturnValidStreamingMessageUsingLLavaAsync()
+ {
+ var host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ var modelName = "llava:latest";
+ var ollamaAgent = BuildOllamaAgent(host, modelName);
+ var squareImagePath = Path.Combine("images", "square.png");
+ var base64Image = Convert.ToBase64String(File.ReadAllBytes(squareImagePath));
+ var imageMessage = new Message()
+ {
+ Role = "user",
+ Value = "What's in this image?",
+ Images = [base64Image],
+ };
+
+ var messages = new IMessage[] { MessageEnvelope.Create(imageMessage, from: modelName) };
+
+ IStreamingMessage? finalReply = default;
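+ // As above, read streamed chunks until the terminal update with Done == true.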
+ await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ {
+ message.Should().NotBeNull();
+ message.From.Should().Be(ollamaAgent.Name);
+ var streamingMessage = (IMessage<ChatResponseUpdate>)message;
+ if (streamingMessage.Content.Done)
+ {
+ finalReply = message;
+ break;
+ }
+ else
+ {
+ streamingMessage.Content.Message.Should().NotBeNull();
+ streamingMessage.Content.Done.Should().BeFalse();
+ }
+ }
+
+ finalReply.Should().BeOfType<MessageEnvelope<ChatResponseUpdate>>();
+ var update = ((MessageEnvelope<ChatResponseUpdate>)finalReply!).Content;
+ update.Done.Should().BeTrue();
+ update.TotalDuration.Should().BeGreaterThan(0);
+ }
+
+ private static bool IsValidJsonMessage(string input)
+ {
+ try
+ {
+ JsonDocument.Parse(input);
+ return true;
+ }
+ catch (JsonException)
+ {
+ return false;
+ }
+ catch (Exception ex)
+ {
+ Console.WriteLine("An unexpected exception occurred: " + ex.Message);
+ return false;
+ }
+ }
+
+ private static OllamaAgent BuildOllamaAgent(string host, string modelName)
+ {
+ var httpClient = new HttpClient
+ {
+ BaseAddress = new Uri(host)
+ };
+ return new OllamaAgent(httpClient, "TestAgent", modelName);
+ }
+}
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
new file mode 100644
index 00000000000..b19291e9767
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
@@ -0,0 +1,176 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaMessageTests.cs
+
+using AutoGen.Core;
+using AutoGen.Tests;
+using FluentAssertions;
+using Xunit;
+namespace AutoGen.Ollama.Tests;
+
+public class OllamaMessageTests
+{
+ [Fact]
+ public async Task ItProcessUserTextMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.First();
+ innerMessage.Should().BeOfType<MessageEnvelope<Message>>();
+ var message = (IMessage<Message>)innerMessage;
+ message.Content.Value.Should().Be("Hello");
+ message.Content.Images.Should().BeNullOrEmpty();
+ message.Content.Role.Should().Be("user");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(messageConnector);
+
+ // when from is null and role is user
+ await agent.SendAsync("Hello");
+
+ // when from is user and role is user
+ var userMessage = new TextMessage(Role.User, "Hello", from: "user");
+ await agent.SendAsync(userMessage);
+
+ // when from is user but role is assistant
+ userMessage = new TextMessage(Role.Assistant, "Hello", from: "user");
+ await agent.SendAsync(userMessage);
+ }
+
+ [Fact]
+ public async Task ItProcessStreamingTextMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterStreamingMiddleware(messageConnector);
+
+ var messageChunks = Enumerable.Range(0, 10)
+ .Select(i => new ChatResponseUpdate()
+ {
+ Message = new Message()
+ {
+ Value = i.ToString(),
+ Role = "assistant",
+ }
+ })
+ .Select(m => MessageEnvelope.Create(m));
+
+ IStreamingMessage? finalReply = null;
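+ // The connector should fold the streamed chunks into one final TextMessage ("0123456789").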
+ await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
+ {
+ reply.Should().BeAssignableTo<IStreamingMessage>();
+ finalReply = reply;
+ }
+
+ finalReply.Should().BeOfType<TextMessage>();
+ var textMessage = (TextMessage)finalReply!;
+ textMessage.GetContent().Should().Be("0123456789");
+ }
+
+ [Fact]
+ public async Task ItProcessAssistantTextMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.First();
+ innerMessage.Should().BeOfType<MessageEnvelope<Message>>();
+ var message = (IMessage<Message>)innerMessage;
+ message.Content.Value.Should().Be("Hello");
+ message.Content.Images.Should().BeNullOrEmpty();
+ message.Content.Role.Should().Be("assistant");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(messageConnector);
+
+ // when from is null and role is assistant
+ var assistantMessage = new TextMessage(Role.Assistant, "Hello");
+ await agent.SendAsync(assistantMessage);
+
+ // when from is assistant and role is assistant
+ assistantMessage = new TextMessage(Role.Assistant, "Hello", from: "assistant");
+ await agent.SendAsync(assistantMessage);
+
+ // when from is assistant but role is user
+ assistantMessage = new TextMessage(Role.User, "Hello", from: "assistant");
+ await agent.SendAsync(assistantMessage);
+ }
+
+ [Fact]
+ public async Task ItProcessSystemTextMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.First();
+ innerMessage.Should().BeOfType<MessageEnvelope<Message>>();
+ var message = (IMessage<Message>)innerMessage;
+ message.Content.Value.Should().Be("Hello");
+ message.Content.Images.Should().BeNullOrEmpty();
+ message.Content.Role.Should().Be("system");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(messageConnector);
+
+ // when role is system
+ var systemMessage = new TextMessage(Role.System, "Hello");
+ await agent.SendAsync(systemMessage);
+ }
+
+ [Fact]
+ public async Task ItProcessImageMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.First();
+ innerMessage.Should().BeOfType<MessageEnvelope<Message>>();
+ var message = (IMessage<Message>)innerMessage;
+ message.Content.Images!.Count.Should().Be(1);
+ message.Content.Role.Should().Be("user");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(messageConnector);
+
+ var square = Path.Combine("images", "square.png");
+ BinaryData imageBinaryData = BinaryData.FromBytes(File.ReadAllBytes(square), "image/png");
+ var imageMessage = new ImageMessage(Role.User, imageBinaryData);
+ await agent.SendAsync(imageMessage);
+ }
+
+ [Fact]
+ public async Task ItProcessMultiModalMessageAsync()
+ {
+ var messageConnector = new OllamaMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, ct) =>
+ {
+ msgs.Count().Should().Be(1);
+ var message = msgs.First();
+ message.Should().BeOfType<MessageEnvelope<Message>>();
+
+ var multiModalMessage = (IMessage<Message>)message;
+ multiModalMessage.Content.Images!.Count.Should().Be(1);
+ multiModalMessage.Content.Value.Should().Be("Hello");
+
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(messageConnector);
+
+ var square = Path.Combine("images", "square.png");
+ BinaryData imageBinaryData = BinaryData.FromBytes(File.ReadAllBytes(square), "image/png");
+ var imageMessage = new ImageMessage(Role.User, imageBinaryData);
+ var textMessage = new TextMessage(Role.User, "Hello");
+ var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
+
+ await agent.SendAsync(multiModalMessage);
+ }
+}
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
new file mode 100644
index 00000000000..06522bdd823
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OllamaTextEmbeddingServiceTests.cs
+
+using AutoGen.Tests;
+using FluentAssertions;
+
+namespace AutoGen.Ollama.Tests;
+
+public class OllamaTextEmbeddingServiceTests
+{
+ [ApiKeyFact("OLLAMA_HOST", "OLLAMA_EMBEDDING_MODEL_NAME")]
+ public async Task GenerateAsync_ReturnsEmbeddings_WhenApiResponseIsSuccessful()
+ {
+ string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
+ ?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
+ string embeddingModelName = Environment.GetEnvironmentVariable("OLLAMA_EMBEDDING_MODEL_NAME")
+ ?? throw new InvalidOperationException("OLLAMA_EMBEDDING_MODEL_NAME is not set.");
+ var httpClient = new HttpClient
+ {
+ BaseAddress = new Uri(host)
+ };
+ var request = new TextEmbeddingsRequest { Model = embeddingModelName, Prompt = "Llamas are members of the camelid family", };
+ var service = new OllamaTextEmbeddingService(httpClient);
+ TextEmbeddingsResponse response = await service.GenerateAsync(request);
+ response.Should().NotBeNull();
+ }
+}
diff --git a/dotnet/test/AutoGen.Ollama.Tests/images/image.png b/dotnet/test/AutoGen.Ollama.Tests/images/image.png
new file mode 100644
index 00000000000..ca276f81f5b
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/images/image.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:300b7c9d6ba0c23a3e52fbd2e268141ddcca0434a9fb9dcf7e58e7e903d36dcf
+size 2126185
diff --git a/dotnet/test/AutoGen.Ollama.Tests/images/square.png b/dotnet/test/AutoGen.Ollama.Tests/images/square.png
new file mode 100644
index 00000000000..afb4f4cd4df
--- /dev/null
+++ b/dotnet/test/AutoGen.Ollama.Tests/images/square.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918
+size 491
diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
similarity index 73%
rename from dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
rename to dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
index 2cb58f4d88c..e8e9af84dbd 100644
--- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
+++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
@@ -3,6 +3,7 @@
"OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )",
"ConvertedMessages": [
{
+ "Name": null,
"Role": "system",
"Content": "You are a helpful AI assistant"
}
@@ -14,6 +15,7 @@
{
"Role": "user",
"Content": "Hello",
+ "Name": "user",
"MultiModaItem": null
}
]
@@ -24,71 +26,20 @@
{
"Role": "assistant",
"Content": "How can I help you?",
+ "Name": "assistant",
"TooCall": [],
"FunctionCallName": null,
"FunctionCallArguments": null
}
]
},
- {
- "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )",
- "ConvertedMessages": [
- {
- "Role": "system",
- "Content": "You are a helpful AI assistant"
- }
- ]
- },
- {
- "OriginalMessage": "Message(user, Hello, user, , )",
- "ConvertedMessages": [
- {
- "Role": "user",
- "Content": "Hello",
- "MultiModaItem": null
- }
- ]
- },
- {
- "OriginalMessage": "Message(assistant, How can I help you?, assistant, , )",
- "ConvertedMessages": [
- {
- "Role": "assistant",
- "Content": "How can I help you?",
- "TooCall": [],
- "FunctionCallName": null,
- "FunctionCallArguments": null
- }
- ]
- },
- {
- "OriginalMessage": "Message(function, result, user, , )",
- "ConvertedMessages": [
- {
- "Role": "user",
- "Content": "result",
- "MultiModaItem": null
- }
- ]
- },
- {
- "OriginalMessage": "Message(assistant, , assistant, functionName, functionArguments)",
- "ConvertedMessages": [
- {
- "Role": "assistant",
- "Content": null,
- "TooCall": [],
- "FunctionCallName": "functionName",
- "FunctionCallArguments": "functionArguments"
- }
- ]
- },
{
"OriginalMessage": "ImageMessage(user, https://example.com/image.png, user)",
"ConvertedMessages": [
{
"Role": "user",
"Content": null,
+ "Name": "user",
"MultiModaItem": [
{
"Type": "Image",
@@ -107,6 +58,7 @@
{
"Role": "user",
"Content": null,
+ "Name": "user",
"MultiModaItem": [
{
"Type": "Text",
@@ -129,6 +81,7 @@
{
"Role": "assistant",
"Content": "",
+ "Name": "assistant",
"TooCall": [
{
"Type": "Function",
@@ -158,12 +111,12 @@
{
"Role": "tool",
"Content": "test",
- "ToolCallId": "result"
+ "ToolCallId": "result_0"
},
{
"Role": "tool",
"Content": "test",
- "ToolCallId": "result"
+ "ToolCallId": "result_1"
}
]
},
@@ -173,18 +126,19 @@
{
"Role": "assistant",
"Content": "",
+ "Name": "assistant",
"TooCall": [
{
"Type": "Function",
"Name": "test",
"Arguments": "test",
- "Id": "test"
+ "Id": "test_0"
},
{
"Type": "Function",
"Name": "test",
"Arguments": "test",
- "Id": "test"
+ "Id": "test_1"
}
],
"FunctionCallName": null,
@@ -198,6 +152,7 @@
{
"Role": "assistant",
"Content": "",
+ "Name": "assistant",
"TooCall": [
{
"Type": "Function",
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
new file mode 100644
index 00000000000..044975354b8
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
@@ -0,0 +1,32 @@
+
+
+
+ $(TestTargetFramework)
+ false
+ True
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ $([System.String]::Copy('%(FileName)').Split('.')[0])
+ $(ProjectExt.Replace('proj', ''))
+ %(ParentFile)%(ParentExtension)
+
+
+
+
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs
new file mode 100644
index 00000000000..d66bf001ed5
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs
@@ -0,0 +1,4 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GlobalUsing.cs
+
+global using AutoGen.Core;
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
new file mode 100644
index 00000000000..87fc0767020
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MathClassTest.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.Extension;
+using AutoGen.Tests;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using Xunit.Abstractions;
+
+namespace AutoGen.OpenAI.Tests
+{
+ public partial class MathClassTest
+ {
+ private readonly ITestOutputHelper _output;
+
+ // as of 2024-05-20, aoai returns a 500 error when round > 1
+ // I'm pretty sure that round > 5 was supported before,
+ // so this is probably some weird regression on the aoai side.
+ // I'll keep this test case here for now, setting round to 1
+ // so the test can still pass.
+ // In the future, we should restore this test case to round > 1 (previously 5)
+ private int round = 1;
+ public MathClassTest(ITestOutputHelper output)
+ {
+ _output = output;
+ }
+
+ private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
+ {
+ try
+ {
+ var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
+
+ _output.WriteLine(reply.FormatMessage());
+ return Task.FromResult(reply);
+ }
+ catch (Exception)
+ {
+ _output.WriteLine("Request failed");
+ _output.WriteLine($"agent name: {agent.Name}");
+ foreach (var message in messages)
+ {
+ _output.WriteLine(message.FormatMessage());
+ }
+
+ throw;
+ }
+
+ }
+
+ [FunctionAttribute]
+ public async Task<string> CreateMathQuestion(string question, int question_index)
+ {
+ return $@"[MATH_QUESTION]
+Question {question_index}:
+{question}
+
+Student, please answer";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> AnswerQuestion(string answer)
+ {
+ return $@"[MATH_ANSWER]
+The answer is {answer}
+teacher please check answer";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> AnswerIsCorrect(string message)
+ {
+ return $@"[ANSWER_IS_CORRECT]
+{message}
+please update progress";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> UpdateProgress(int correctAnswerCount)
+ {
+ if (correctAnswerCount >= this.round)
+ {
+ return $@"[UPDATE_PROGRESS]
+{GroupChatExtension.TERMINATE}";
+ }
+ else
+ {
+ return $@"[UPDATE_PROGRESS]
+the number of resolved question is {correctAnswerCount}
+teacher, please create the next math question";
+ }
+ }
+
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")]
+ public async Task OpenAIAgentMathChatTestAsync()
+ {
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
+ var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
+
+ var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
+ var model = "gpt-35-turbo-16k";
+ var teacher = await CreateTeacherAgentAsync(openaiClient, model);
+ var student = await CreateStudentAssistantAgentAsync(openaiClient, model);
+
+ var adminFunctionMiddleware = new FunctionCallMiddleware(
+ functions: [this.UpdateProgressFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.UpdateProgressFunction.Name!, this.UpdateProgressWrapper },
+ });
+ var admin = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ modelName: model,
+ name: "Admin",
+ systemMessage: $@"You are admin. You update progress after each question is answered.")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(adminFunctionMiddleware)
+ .RegisterMiddleware(Print);
+
+ var groupAdmin = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ modelName: model,
+ name: "GroupAdmin",
+ systemMessage: "You are group admin. You manage the group chat.")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(Print);
+ await RunMathChatAsync(teacher, student, admin, groupAdmin);
+ }
+
+ private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
+ {
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
+ { this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
+ });
+
+ var teacher = new OpenAIChatAgent(
+ openAIClient: client,
+ name: "Teacher",
+ systemMessage: @"You are a preschool math teacher.
+You create math question and ask student to answer it.
+Then you check if the answer is correct.
+If the answer is wrong, you ask student to fix it",
+ modelName: model)
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware)
+ .RegisterMiddleware(Print);
+
+ return teacher;
+ }
+
+ private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
+ {
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.AnswerQuestionFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
+ });
+ var student = new OpenAIChatAgent(
+ openAIClient: client,
+ name: "Student",
+ modelName: model,
+ systemMessage: @"You are a student. You answer math question from teacher.")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware)
+ .RegisterMiddleware(Print);
+
+ return student;
+ }
+
+ private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
+ {
+ var teacher2Student = Transition.Create(teacher, student);
+ var student2Teacher = Transition.Create(student, teacher);
+ var teacher2Admin = Transition.Create(teacher, admin);
+ var admin2Teacher = Transition.Create(admin, teacher);
+ var workflow = new Graph(
+ [
+ teacher2Student,
+ student2Teacher,
+ teacher2Admin,
+ admin2Teacher,
+ ]);
+ var group = new GroupChat(
+ workflow: workflow,
+ members: [
+ admin,
+ teacher,
+ student,
+ ],
+ admin: groupAdmin);
+
+ var groupChatManager = new GroupChatManager(group);
+ var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
+
+ chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ // check if there's terminate chat message from admin
+ chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
+ .Count()
+ .Should().Be(1);
+ }
+ }
+}
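In the workflow above, each `Transition` is an allowed speaker hand-off and the `Graph` acts as an allow-list: the group admin can only pick the next speaker along a defined edge (teacher ⇄ student, teacher ⇄ admin). A stripped-down sketch of the same wiring with two agents, using the same AutoGen.Core types as the test (agent construction elided; `teacher`, `student`, and `groupAdmin` are assumed to be existing IAgent instances):

    var teacher2Student = Transition.Create(teacher, student); // teacher -> student
    var student2Teacher = Transition.Create(student, teacher); // student -> teacher
    var workflow = new Graph([teacher2Student, student2Teacher]);
    var group = new GroupChat(
        workflow: workflow,
        members: [teacher, student],
        admin: groupAdmin); // admin picks among the allowed next speakers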
diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
similarity index 93%
rename from dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs
rename to dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
index c504eb06a18..284cd3856bb 100644
--- a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
@@ -5,12 +5,12 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
-using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
+using AutoGen.Tests;
using Azure.AI.OpenAI;
using FluentAssertions;
-namespace AutoGen.Tests;
+namespace AutoGen.OpenAI.Tests;
public partial class OpenAIChatAgentTest
{
@@ -79,7 +79,6 @@ public async Task OpenAIChatMessageContentConnectorTestAsync()
new TextMessage(Role.Assistant, "Hello", from: "user"),
],
from: "user"),
- new Message(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
};
foreach (var message in messages)
@@ -133,7 +132,6 @@ public async Task OpenAIChatAgentToolCallTestAsync()
new TextMessage(Role.Assistant, question, from: "user"),
],
from: "user"),
- new Message(Role.Assistant, question, from: "user"), // Message type is going to be deprecated, please use TextMessage instead
};
foreach (var message in messages)
@@ -202,14 +200,13 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
new TextMessage(Role.Assistant, question, from: "user"),
],
from: "user"),
- new Message(Role.Assistant, question, from: "user"), // Message type is going to be deprecated, please use TextMessage instead
};
foreach (var message in messages)
{
var reply = await functionCallAgent.SendAsync(message);
- reply.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
reply.From.Should().Be("assistant");
reply.GetToolCalls()!.Count().Should().Be(1);
reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
@@ -229,7 +226,7 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
}
else
{
- streamingMessage.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ streamingMessage.Should().BeOfType<ToolCallAggregateMessage>();
streamingMessage.As<IMessage>().GetContent()!.ToLower().Should().Contain("seattle");
}
}
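The deletions above remove the legacy `Message` type from the test matrix: `TextMessage` is its replacement, and assertions on `AggregateMessage<ToolCallMessage, ToolCallResultMessage>` move to the new `ToolCallAggregateMessage`. A migration sketch based on the removed lines:

    // Before (deprecated, removed above):
    //   new Message(Role.Assistant, "Hello", from: "user");
    // After:
    var reply = new TextMessage(Role.Assistant, "Hello", from: "user");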
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
new file mode 100644
index 00000000000..81581d068ee
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
@@ -0,0 +1,720 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIMessageTests.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Threading.Tasks;
+using ApprovalTests;
+using ApprovalTests.Namers;
+using ApprovalTests.Reporters;
+using AutoGen.Tests;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.OpenAI.Tests;
+
+public class OpenAIMessageTests
+{
+ private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
+ {
+ WriteIndented = true,
+ IgnoreReadOnlyProperties = false,
+ };
+
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("ApprovalTests")]
+ public void BasicMessageTest()
+ {
+ IMessage[] messages = [
+ new TextMessage(Role.System, "You are a helpful AI assistant"),
+ new TextMessage(Role.User, "Hello", "user"),
+ new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.User, "Hello", "user"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "user"),
+ ], "user"),
+ new ToolCallMessage("test", "test", "assistant"),
+ new ToolCallResultMessage("result", "test", "test", "user"),
+ new ToolCallResultMessage(
+ [
+ new ToolCall("result", "test", "test"),
+ new ToolCall("result", "test", "test"),
+ ], "user"),
+ new ToolCallMessage(
+ [
+ new ToolCall("test", "test"),
+ new ToolCall("test", "test"),
+ ], "assistant"),
+ new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(
+ message1: new ToolCallMessage("test", "test", "assistant"),
+ message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"),
+ ];
+ var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant");
+
+ var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m])));
+ VerifyOAIMessages(oaiMessages);
+ }
+
+ [Fact]
+ public async Task ItProcessUserTextMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("Hello");
+ chatRequestMessage.Name.Should().Be("user");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new TextMessage(Role.User, "Hello", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItShortcutChatRequestMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
+
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("hello");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = new ChatRequestUserMessage("hello");
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ await agent.GenerateReplyAsync([chatRequestMessage]);
+ }
+
+ [Fact]
+ public async Task ItShortcutMessageWhenStrictModelIsFalseAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<string>>();
+
+ var chatRequestMessage = ((MessageEnvelope<string>)innerMessage!).Content;
+ chatRequestMessage.Should().Be("hello");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = "hello";
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ await agent.GenerateReplyAsync([chatRequestMessage]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenStrictModeIsTrueAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = "hello";
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ Func<Task> action = async () => await agent.GenerateReplyAsync([chatRequestMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MessageEnvelope`1");
+ }
+
+ [Fact]
+ public async Task ItProcessAssistantTextMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("How can I help you?");
+ chatRequestMessage.Name.Should().Be("assistant");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // assistant message
+ IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessSystemTextMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("You are a helpful AI assistant");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // system message
+ IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessImageMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.Name.Should().Be("user");
+ chatRequestMessage.MultimodalContentItems.Count().Should().Be(1);
+ chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageImageContentItem>();
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant");
+ Func<Task> action = async () => await agent.GenerateReplyAsync([imageMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ImageMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessMultiModalMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.Name.Should().Be("user");
+ chatRequestMessage.MultimodalContentItems.Count().Should().Be(2);
+ chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageTextContentItem>();
+ chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType<ChatMessageImageContentItem>();
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new MultiModalMessage(
+ Role.User,
+ [
+ new TextMessage(Role.User, "Hello", "user"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "user"),
+ ], "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var multiModalMessage = new MultiModalMessage(
+ Role.Assistant,
+ [
+ new TextMessage(Role.User, "Hello", "assistant"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "assistant"),
+ ], "assistant");
+
+ Func<Task> action = async () => await agent.GenerateReplyAsync([multiModalMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MultiModalMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessToolCallMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.Name.Should().Be("assistant");
+ chatRequestMessage.ToolCalls.Count().Should().Be(1);
+ chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be("test");
+ functionToolCall.Arguments.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ToolCallMessage("test", "test", "assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelToolCallMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.Name.Should().Be("assistant");
+ chatRequestMessage.ToolCalls.Count().Should().Be(2);
+ for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
+ {
+ chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be($"test_{i}");
+ functionToolCall.Arguments.Should().Be("test");
+ }
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test"),
+ new ToolCall("test", "test"),
+ };
+ IMessage message = new ToolCallMessage(toolCalls, "assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector(strictMode: true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var toolCallMessage = new ToolCallMessage("test", "test", "user");
+ Func<Task> action = async () => await agent.GenerateReplyAsync([toolCallMessage]);
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ToolCallMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessToolCallResultMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ToolCallResultMessage("result", "test", "test", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelToolCallResultMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(2);
+
+ for (int i = 0; i < msgs.Count(); i++)
+ {
+ var innerMessage = msgs.ElementAt(i);
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
+ }
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test", "result"),
+ new ToolCall("test", "test", "result"),
+ };
+ IMessage message = new ToolCallResultMessage(toolCalls, "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.Name.Should().Be("user");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCallMessage = new ToolCallMessage("test", "test", "user");
+ var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user");
+ var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "user");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(2);
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be("test");
+
+ var toolCallMessage = msgs.First();
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+ toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+ toolCallRequestMessage.ToolCalls.Count().Should().Be(1);
+ toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be("test");
+ functionToolCall.Arguments.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCallMessage = new ToolCallMessage("test", "test", "assistant");
+ var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant");
+ var aggregateMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "assistant");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(3);
+ var toolCallMessage = msgs.First();
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+ toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+ toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
+
+ for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
+ {
+ toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be($"test_{i}");
+ functionToolCall.Arguments.Should().Be("test");
+ }
+
+ for (int i = 1; i < msgs.Count(); i++)
+ {
+ var toolCallResultMessage = msgs.ElementAt(i);
+ toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
+ toolCallResultRequestMessage.Content.Should().Be("result");
+ toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
+ }
+
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test", "result"),
+ new ToolCall("test", "test", "result"),
+ };
+ var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
+ var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
+ var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItConvertChatResponseMessageToTextMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = CreateInstance(ChatRole.Assistant, "hello");
+ var chatRequestMessage = MessageEnvelope.Create(textMessage);
+
+ var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+ message.Should().BeOfType<TextMessage>();
+ message.GetContent().Should().Be("hello");
+ message.GetRole().Should().Be(Role.Assistant);
+ }
+
+ [Fact]
+ public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // tool call message
+ var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, string>());
+ var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
+ var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+ message.Should().BeOfType<ToolCallMessage>();
+ message.GetToolCalls()!.Count().Should().Be(1);
+ message.GetToolCalls()!.First().FunctionName.Should().Be("test");
+ message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
+ }
+
+ [Fact]
+ public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = "hello";
+ var messageToSend = MessageEnvelope.Create(textMessage);
+
+ var message = await agent.GenerateReplyAsync([messageToSend]);
+ message.Should().BeOfType<MessageEnvelope<string>>();
+ }
+
+ [Fact]
+ public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync()
+ {
+ var middleware = new OpenAIChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = new ChatRequestUserMessage("hello");
+ var messageToSend = MessageEnvelope.Create(textMessage);
+ Func<Task> action = async () => await agent.GenerateReplyAsync([messageToSend]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid return message type MessageEnvelope`1");
+ }
+
+ [Fact]
+ public void ToOpenAIChatRequestMessageShortCircuitTest()
+ {
+ var agent = new EchoAgent("assistant");
+ var middleware = new OpenAIChatRequestMessageConnector();
+ ChatRequestMessage[] messages =
+ [
+ new ChatRequestUserMessage("Hello"),
+ new ChatRequestAssistantMessage("How can I help you?"),
+ new ChatRequestSystemMessage("You are a helpful AI assistant"),
+ new ChatRequestFunctionMessage("result", "functionName"),
+ new ChatRequestToolMessage("test", "test"),
+ ];
+
+ foreach (var oaiMessage in messages)
+ {
+ IMessage message = new MessageEnvelope<ChatRequestMessage>(oaiMessage);
+ var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]);
+ oaiMessages.Count().Should().Be(1);
+ //oaiMessages.First().Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ if (oaiMessages.First() is IMessage<ChatRequestMessage> chatRequestMessage)
+ {
+ chatRequestMessage.Content.Should().Be(oaiMessage);
+ }
+ else
+ {
+ // fail the test
+ Assert.True(false);
+ }
+ }
+ }
+ private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable<IMessage>)> messages)
+ {
+ var jsonObjects = messages.Select(pair =>
+ {
+ var (originalMessage, ms) = pair;
+ var objs = new List<object>