Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance vertexai integration (safety settings, authentication...) #3067

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order


def _post_process_carryover_item(carryover_item):
if isinstance(carryover_item, str):
return carryover_item
elif isinstance(carryover_item, dict) and "content" in carryover_item:
return str(carryover_item["content"])
else:
return str(carryover_item)


def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()

Expand All @@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
Expand Down
3 changes: 2 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from openai import BadRequestError

from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import model_dump
Expand Down Expand Up @@ -2364,7 +2365,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
Expand Down
96 changes: 80 additions & 16 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from __future__ import annotations

import base64
import logging
import os
import random
import re
Expand All @@ -45,13 +46,19 @@
import vertexai
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
Expand All @@ -77,27 +84,35 @@ def _initialize_vartexai(self, **params):
# Path to JSON Keyfile
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
vertexai_init_args = {}
if "project_id" in params:
vertexai_init_args["project"] = params["project_id"]
if "project" in params:
vertexai_init_args["project"] = params["project"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
if "credentials" in params:
assert isinstance(
params["credentials"], Credentials
), "Object type google.auth.credentials.Credentials is expected!"
vertexai_init_args["credentials"] = params["credentials"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)

def __init__(self, **kwargs):
"""Uses either either api_key for authentication from the LLM config
(specifying the GOOGLE_API_KEY environment variable also works),
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
where project_id and location can also be passed as parameters. Service account key file can also be used.
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
where project and location can also be passed as parameters. Previously created credentials object can be provided,
or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
then the default credentials will be used, which could be a personal account if the user is already authenticated in,
like in Google Cloud Shell.

Args:
api_key (str): The API key for using Gemini.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.

project (str): Google Cloud project, which is only valid in case no API key is specified.
location (str): Compute region to be used, like 'us-west1'.
This parameter is only valid in case no API key is specified.
"""
Expand All @@ -112,7 +127,7 @@ def __init__(self, **kwargs):
else:
self.use_vertexai = False
if not self.use_vertexai:
assert ("project_id" not in kwargs) and (
assert ("project" not in kwargs) and (
"location" not in kwargs
), "Google Cloud project and compute location cannot be set when using an API Key!"

Expand Down Expand Up @@ -144,7 +159,7 @@ def create(self, params: Dict) -> ChatCompletion:
if self.use_vertexai:
self._initialize_vartexai(**params)
else:
assert ("project_id" not in params) and (
assert ("project" not in params) and (
"location" not in params
), "Google Cloud project and compute location cannot be set when using an API Key!"
model_name = params.get("model", "gemini-pro")
Expand All @@ -159,13 +174,18 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
response_validation = params.get("response_validation", True)

generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
else:
safety_settings = params.get("safety_settings", {})

if stream:
warnings.warn(
Expand All @@ -181,20 +201,29 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1], stream=stream)
response = chat.send_message(
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
Expand All @@ -218,16 +247,22 @@ def create(self, params: Dict) -> ChatCompletion:
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
# response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
Expand Down Expand Up @@ -372,6 +407,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li

return rst

@staticmethod
def _to_vertexai_safety_settings(safety_settings):
"""Convert safety settings to VertexAI format if needed,
like when specifying them in the OAI_CONFIG_LIST
"""
if isinstance(safety_settings, list) and all(
[
isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
for safety_setting in safety_settings
]
):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings
else:
return safety_settings


def _to_pil(data: str) -> Image.Image:
"""
Expand Down
10 changes: 9 additions & 1 deletion autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from openai.types.beta.assistant import Assistant
from packaging.version import parse

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
NON_CACHE_KEY = [
"api_key",
"base_url",
"api_type",
"api_version",
"azure_ad_token",
"azure_ad_token_provider",
"credentials",
]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
Expand Down
Loading
Loading