Enhance vertexai integration #3086

Merged: 18 commits, Jul 23, 2024
11 changes: 10 additions & 1 deletion autogen/agentchat/chat.py
@@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order


def _post_process_carryover_item(carryover_item):
if isinstance(carryover_item, str):
return carryover_item
elif isinstance(carryover_item, dict) and "content" in carryover_item:
return str(carryover_item["content"])
else:
return str(carryover_item)


def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()

@@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
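
A minimal sketch of the new helper's behavior (assuming a checkout with this patch applied; the values are illustrative):

from autogen.agentchat.chat import _post_process_carryover_item

# Plain strings pass through unchanged.
assert _post_process_carryover_item("prior summary") == "prior summary"
# Gemini-style message dicts are reduced to their "content" field.
assert _post_process_carryover_item({"content": "How can I help you?", "role": "model"}) == "How can I help you?"
# Anything else is coerced with str(), so the "\n".join() above cannot fail.
assert _post_process_carryover_item(42) == "42"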
3 changes: 2 additions & 1 deletion autogen/agentchat/conversable_agent.py
@@ -11,6 +11,7 @@

from openai import BadRequestError

from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import model_dump
@@ -2364,7 +2365,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
99 changes: 82 additions & 17 deletions autogen/oai/gemini.py
@@ -6,7 +6,7 @@
"config_list": [{
"api_type": "google",
"model": "gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY"),
"api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
@@ -32,6 +32,7 @@
from __future__ import annotations

import base64
import logging
import os
import random
import re
@@ -45,13 +46,19 @@
import vertexai
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
@@ -81,29 +88,36 @@ def _initialize_vartexai(self, **params):
vertexai_init_args["project"] = params["project_id"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
if "credentials" in params:
assert isinstance(
params["credentials"], Credentials
), "Object type google.auth.credentials.Credentials is expected!"
vertexai_init_args["credentials"] = params["credentials"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)

def __init__(self, **kwargs):
"""Uses either either api_key for authentication from the LLM config
(specifying the GOOGLE_API_KEY environment variable also works),
(specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
where project_id and location can also be passed as parameters. Service account key file can also be used.
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
then the default credentials will be used, which could be a personal account if the user is already authenticated in,
like in Google Cloud Shell.

Args:
api_key (str): The API key for using Gemini.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
location (str): Compute region to be used, like 'us-west1'.
This parameter is only valid in case no API key is specified.
This parameter is only valid in case no API key is specified.
"""
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GOOGLE_API_KEY")
self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
if self.api_key is None:
self.use_vertexai = True
self._initialize_vartexai(**kwargs)
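
A sketch of how a pre-built credentials object might be supplied through the config (the key-file path, project id, and model name are placeholders; google-auth's service_account module is the standard way to build such an object):

from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_file(
    "/path/to/service-account-key.json",  # placeholder path
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
config_list = [
    {
        "api_type": "google",
        "model": "gemini-pro",
        "project_id": "my-gcp-project",  # placeholder project
        "location": "us-west1",
        # Must be a google.auth.credentials.Credentials instance,
        # per the assert in _initialize_vartexai above.
        "credentials": credentials,
    }
]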
@@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
response_validation = params.get("response_validation", True)

generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
else:
safety_settings = params.get("safety_settings", {})

if stream:
warnings.warn(
@@ -181,20 +200,29 @@
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1], stream=stream)
response = chat.send_message(
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
@@ -218,16 +246,22 @@
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
# response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
@@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
if content == "":
content = "empty" # Empty content is not allowed.
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
else:
@@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li

return rst

@staticmethod
def _to_vertexai_safety_settings(safety_settings):
"""Convert safety settings to VertexAI format if needed,
like when specifying them in the OAI_CONFIG_LIST
"""
if isinstance(safety_settings, list) and all(
[
isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
for safety_setting in safety_settings
]
):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings
else:
return safety_settings


def _to_pil(data: str) -> Image.Image:
"""
Expand Down
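
Taken together, a hedged usage sketch of the new create() options (assuming extra llm_config keys are forwarded to GeminiClient.create as client-specific parameters, which is how autogen passes such options; the project id and instruction text are placeholders):

import autogen

llm_config = {
    "config_list": [
        {
            "api_type": "google",
            "model": "gemini-pro",
            "project_id": "my-gcp-project",  # placeholder
            "location": "us-west1",
            # Plain dicts are accepted here; on the vertexai path,
            # _to_vertexai_safety_settings converts them to SafetySetting objects.
            "safety_settings": [
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            ],
        }
    ],
    # Forwarded to GenerativeModel(..., system_instruction=...).
    "system_instruction": "You are a concise assistant.",
    # Passed to start_chat() on the vertexai path; defaults to True.
    "response_validation": True,
}
assistant = autogen.AssistantAgent("assistant", llm_config=llm_config)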
10 changes: 9 additions & 1 deletion autogen/oai/openai_utils.py
@@ -13,7 +13,15 @@
from openai.types.beta.assistant import Assistant
from packaging.version import parse

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
NON_CACHE_KEY = [
"api_key",
"base_url",
"api_type",
"api_version",
"azure_ad_token",
"azure_ad_token_provider",
"credentials",
]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
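
Adding "credentials" here keeps the new, unhashable credentials object out of cache-key computation. A sketch of the idea (illustrative only, not the exact cache code):

from autogen.oai.openai_utils import NON_CACHE_KEY

config = {"model": "gemini-pro", "api_type": "google", "credentials": object()}
# Fields listed in NON_CACHE_KEY are dropped before a config is hashed, so
# configs that differ only in secrets or auth objects share a cache entry.
cacheable = {k: v for k, v in config.items() if k not in NON_CACHE_KEY}
assert "credentials" not in cacheable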
11 changes: 11 additions & 0 deletions test/agentchat/test_chats.py
@@ -10,6 +10,7 @@

import autogen
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
from autogen.agentchat.chat import _post_process_carryover_item

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import reason, skip_openai # noqa: E402
@@ -620,6 +621,15 @@ def my_writing_task(sender, recipient, context):
print(chat_results[1].summary, chat_results[1].cost)


def test_post_process_carryover_item():
gemini_carryover_item = {"content": "How can I help you?", "role": "model"}
assert (
_post_process_carryover_item(gemini_carryover_item) == gemini_carryover_item["content"]
), "Incorrect carryover postprocessing"
carryover_item = "How can I help you?"
assert _post_process_carryover_item(carryover_item) == carryover_item, "Incorrect carryover postprocessing"


if __name__ == "__main__":
test_chats()
# test_chats_general()
@@ -628,3 +638,4 @@ def my_writing_task(sender, recipient, context):
# test_chats_w_func()
# test_chat_messages_for_summary()
# test_udf_message_in_chats()
test_post_process_carryover_item()
56 changes: 56 additions & 0 deletions test/agentchat/test_conversable_agent.py
@@ -1463,6 +1463,58 @@ def sample_function():
)


def test_process_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"


def test_process_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

proc_content_empty_carryover = dummy_agent_1._process_carryover(content=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


def test_handle_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant"
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"


def test_handle_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

proc_content_empty_carryover = dummy_agent_1._handle_carryover(message=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


if __name__ == "__main__":
# test_trigger()
# test_context()
@@ -1473,6 +1525,10 @@
# test_max_turn()
# test_process_before_send()
# test_message_func()

test_summary()
test_adding_duplicate_function_warning()
# test_function_registration_e2e_sync()

test_process_gemini_carryover()
test_process_carryover()