feat: add gemini support (#953)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
yihong0618 and fangyinc authored Dec 23, 2023
1 parent e1ace14 commit 12234ae
Showing 8 changed files with 243 additions and 42 deletions.
1 change: 1 addition & 0 deletions README.zh.md
@@ -123,6 +123,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
 - [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
 - [x] [讯飞·星火](https://xinghuo.xfyun.cn/)
 - [x] [Google·Bard](https://bard.google.com/)
+- [x] [Google·Gemini](https://makersuite.google.com/app/apikey)

 - **隐私安全**
10 changes: 9 additions & 1 deletion dbgpt/_private/config.py
@@ -61,7 +61,7 @@ def __init__(self) -> None:
         if self.zhipu_proxy_api_key:
             os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
             os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
-                "ZHIPU_MODEL_VERSION", "chatglm_pro"
+                "ZHIPU_MODEL_VERSION"
             )

         # wenxin
@@ -95,6 +95,14 @@ def __init__(self) -> None:
             os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
             os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version

+        # gemini proxy
+        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
+        if self.gemini_proxy_api_key:
+            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
+            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
+                "GEMINI_MODEL_VERSION", "gemini-pro"
+            )
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
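With this block in place, Gemini is configured the same way as the other proxy providers. A plausible .env fragment (the GEMINI_* names come from the diff above; LLM_MODEL=gemini_proxyllm is an assumption based on DB-GPT's existing proxy-model convention):

    # .env (illustrative)
    LLM_MODEL=gemini_proxyllm            # assumed selector, matching model_config.py below
    GEMINI_PROXY_API_KEY=<your-api-key>  # read by the code above
    GEMINI_MODEL_VERSION=gemini-pro      # optional; defaults to gemini-pro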
1 change: 1 addition & 0 deletions dbgpt/configs/model_config.py
@@ -60,6 +60,7 @@ def get_device() -> str:
     "wenxin_proxyllm": "wenxin_proxyllm",
     "tongyi_proxyllm": "tongyi_proxyllm",
     "zhipu_proxyllm": "zhipu_proxyllm",
+    "gemini_proxyllm": "gemini_proxyllm",
     "bc_proxyllm": "bc_proxyllm",
     "spark_proxyllm": "spark_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
62 changes: 54 additions & 8 deletions dbgpt/core/interface/message.py
@@ -202,19 +202,65 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]


-def _parse_model_messages(
+def parse_model_messages(
     messages: List[ModelMessage],
 ) -> Tuple[str, List[str], List[List[str]]]:
     """
-    Parameters:
-        messages: List of message from base chat.
+    Parse model messages to extract the user prompt, system messages, and a history of conversation.
+
+    This function analyzes a list of ModelMessage objects, identifying the role of each message
+    (e.g., human, system, ai) and categorizing them accordingly. The last message is expected to be
+    from the user (human), and it is treated as the current user prompt. System messages are
+    extracted separately, and the conversation history is compiled into pairs of human and AI
+    messages.
+
+    Args:
+        messages (List[ModelMessage]): List of messages from a chat conversation.
+
     Returns:
-        A tuple contains user prompt, system message list and history message list
-        str: user prompt
-        List[str]: system messages
-        List[List[str]]: history message of user and assistant
+        tuple: A tuple containing the user prompt, the list of system messages, and the
+        conversation history. The conversation history is a list of message pairs, each
+        containing a user message and the corresponding AI response.
+
+    Examples:
+        .. code-block:: python
+
+            # Example 1: Single round of conversation
+            messages = [
+                ModelMessage(role="human", content="Hello"),
+                ModelMessage(role="ai", content="Hi there!"),
+                ModelMessage(role="human", content="How are you?"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "How are you?"
+            # system_messages: []
+            # history: [["Hello", "Hi there!"]]
+
+            # Example 2: Conversation with system messages
+            messages = [
+                ModelMessage(role="system", content="System initializing..."),
+                ModelMessage(role="human", content="Is it sunny today?"),
+                ModelMessage(role="ai", content="Yes, it's sunny."),
+                ModelMessage(role="human", content="Great!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Great!"
+            # system_messages: ["System initializing..."]
+            # history: [["Is it sunny today?", "Yes, it's sunny."]]
+
+            # Example 3: Multiple rounds with a system message
+            messages = [
+                ModelMessage(role="human", content="Hi"),
+                ModelMessage(role="ai", content="Hello!"),
+                ModelMessage(role="system", content="Error 404"),
+                ModelMessage(role="human", content="What's the error?"),
+                ModelMessage(role="ai", content="Just a joke."),
+                ModelMessage(role="human", content="Funny!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Funny!"
+            # system_messages: ["Error 404"]
+            # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
     """
     user_prompt = ""

     system_messages: List[str] = []
     history_messages: List[List[str]] = [[]]

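The function body is truncated in this view. A minimal sketch of the parsing logic, inferred from the docstring above and the tests below rather than copied from the commit, might look like:

    from typing import List, Tuple

    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType


    def parse_model_messages_sketch(
        messages: List[ModelMessage],
    ) -> Tuple[str, List[str], List[List[str]]]:
        # The last message must be the current human prompt.
        if not messages or messages[-1].role != ModelMessageRoleType.HUMAN:
            raise ValueError("The last message must be a human message")
        system_messages: List[str] = []
        history_messages: List[List[str]] = [[]]
        for message in messages[:-1]:
            if message.role == ModelMessageRoleType.SYSTEM:
                system_messages.append(message.content)
            elif message.role == ModelMessageRoleType.HUMAN:
                history_messages[-1].append(message.content)
            elif message.role == ModelMessageRoleType.AI:
                # An AI reply closes the current [human, ai] pair.
                history_messages[-1].append(message.content)
                history_messages.append([])
        # Keep only complete [human, ai] pairs, as the tests below expect.
        history_messages = [pair for pair in history_messages if len(pair) == 2]
        return messages[-1].content, system_messages, history_messages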
65 changes: 65 additions & 0 deletions dbgpt/core/interface/tests/test_message.py
@@ -324,6 +324,71 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
     assert isinstance(new_conversation.messages[1], AIMessage)


+def test_parse_model_messages_no_history_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello"
+    assert system_messages == []
+    assert history_messages == []
+
+
+def test_parse_model_messages_single_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hi there!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello again"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello again"
+    assert system_messages == []
+    assert history_messages == [["Hello", "Hi there!"]]
+
+
+def test_parse_model_messages_two_round_conversation_with_system_message():
+    messages = [
+        ModelMessage(
+            role=ModelMessageRoleType.SYSTEM, content="System initializing..."
+        ),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How's the weather?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="It's sunny!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Great to hear!"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Great to hear!"
+    assert system_messages == ["System initializing..."]
+    assert history_messages == [["How's the weather?", "It's sunny!"]]
+
+
+def test_parse_model_messages_three_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What's up?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Not much, you?"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Same here."),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Same here."
+    assert system_messages == []
+    assert history_messages == [["Hi", "Hello!"], ["What's up?", "Not much, you?"]]
+
+
+def test_parse_model_messages_multiple_system_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System start"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hey"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System check"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "How are you?"
+    assert system_messages == ["System start", "System check"]
+    assert history_messages == [["Hey", "Hello!"]]
+
+
 def test_to_openai_messages(
     human_model_message, ai_model_message, system_model_message
 ):
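The new tests can be run on their own with pytest (assuming a development install of DB-GPT):

    pytest dbgpt/core/interface/tests/test_message.py -v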
2 changes: 2 additions & 0 deletions dbgpt/model/llm_out/proxy_llm.py
@@ -8,6 +8,7 @@
 from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
 from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
 from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
+from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
 from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
 from dbgpt.model.proxy.llms.spark import spark_generate_stream
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -25,6 +26,7 @@ def proxyllm_generate_stream(
         "wenxin_proxyllm": wenxin_generate_stream,
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
+        "gemini_proxyllm": gemini_generate_stream,
         "bc_proxyllm": baichuan_generate_stream,
         "spark_proxyllm": spark_generate_stream,
     }
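The surrounding proxyllm_generate_stream function is truncated here; it looks up the generator for the configured model name in this table. A simplified sketch of that dispatch pattern (proxyllm_generate_stream_sketch and the model_name attribute are illustrative, not the verbatim DB-GPT code):

    from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
    from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
    from dbgpt.model.proxy.llms.proxy_model import ProxyModel


    def proxyllm_generate_stream_sketch(
        model: ProxyModel, tokenizer, params, device, context_len=2048
    ):
        # Map proxy model names to streaming generators (other providers elided).
        generate_stream_function_map = {
            "zhipu_proxyllm": zhipu_generate_stream,
            "gemini_proxyllm": gemini_generate_stream,
        }
        model_params = model.get_params()
        model_name = model_params.model_name  # assumed attribute name
        generate_stream_function = generate_stream_function_map.get(model_name)
        if generate_stream_function is None:
            raise ValueError(f"Unsupported proxy model: {model_name}")
        yield from generate_stream_function(model, tokenizer, params, device, context_len)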
109 changes: 109 additions & 0 deletions dbgpt/model/proxy/llms/gemini.py
@@ -0,0 +1,109 @@
+from typing import List, Tuple, Dict, Any
+
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+from dbgpt.core.interface.message import ModelMessage, parse_model_messages
+
+GEMINI_DEFAULT_MODEL = "gemini-pro"
+
+
+def gemini_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    """Google Gemini, see: https://makersuite.google.com/app/apikey"""
+    model_params = model.get_params()
+    print(f"Model: {model}, model_params: {model_params}")
+
+    # TODO proxy model use unified config?
+    proxy_api_key = model_params.proxy_api_key
+    proxyllm_backend = model_params.proxyllm_backend or GEMINI_DEFAULT_MODEL
+
+    generation_config = {
+        "temperature": 0.7,
+        "top_p": 1,
+        "top_k": 1,
+        "max_output_tokens": 2048,
+    }
+
+    safety_settings = [
+        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
+        {
+            "category": "HARM_CATEGORY_HATE_SPEECH",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+    ]
+
+    import google.generativeai as genai
+
+    if model_params.proxy_api_base:
+        from google.api_core import client_options
+
+        client_opts = client_options.ClientOptions(
+            api_endpoint=model_params.proxy_api_base
+        )
+        genai.configure(
+            api_key=proxy_api_key, transport="rest", client_options=client_opts
+        )
+    else:
+        genai.configure(api_key=proxy_api_key)
+    model = genai.GenerativeModel(
+        model_name=proxyllm_backend,
+        generation_config=generation_config,
+        safety_settings=safety_settings,
+    )
+    messages: List[ModelMessage] = params["messages"]
+    user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
+    chat = model.start_chat(history=gemini_hist)
+    response = chat.send_message(user_prompt, stream=True)
+    text = ""
+    for chunk in response:
+        text += chunk.text
+        print(text)
+        yield text
+
+
+def _transform_to_gemini_messages(
+    messages: List[ModelMessage],
+) -> Tuple[str, List[Dict[str, Any]]]:
+    """Transform messages to gemini format
+
+    See https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
+
+    Args:
+        messages (List[ModelMessage]): messages
+
+    Returns:
+        Tuple[str, List[Dict[str, Any]]]: user_prompt, gemini_hist
+
+    Examples:
+        .. code-block:: python
+
+            messages = [
+                ModelMessage(role="human", content="Hello"),
+                ModelMessage(role="ai", content="Hi there!"),
+                ModelMessage(role="human", content="How are you?"),
+            ]
+            user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
+            assert user_prompt == "How are you?"
+            assert gemini_hist == [
+                {"role": "user", "parts": {"text": "Hello"}},
+                {"role": "model", "parts": {"text": "Hi there!"}},
+            ]
+    """
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    if system_messages:
+        user_prompt = "".join(system_messages) + "\n" + user_prompt
+    gemini_hist = []
+    if history_messages:
+        for user_message, model_message in history_messages:
+            gemini_hist.append({"role": "user", "parts": {"text": user_message}})
+            gemini_hist.append({"role": "model", "parts": {"text": model_message}})
+    return user_prompt, gemini_hist
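Taken together, the new module folds system messages into the prompt and replays prior turns as alternating user/model history, since the Gemini chat API as used here has no separate system role. A minimal standalone sketch of the same google-generativeai flow (the API key and messages are placeholders; the history format mirrors _transform_to_gemini_messages above):

    import google.generativeai as genai

    genai.configure(api_key="YOUR_GEMINI_API_KEY")  # placeholder
    model = genai.GenerativeModel(model_name="gemini-pro")

    # History must alternate user/model turns, matching the transform above.
    history = [
        {"role": "user", "parts": {"text": "Hello"}},
        {"role": "model", "parts": {"text": "Hi there!"}},
    ]
    chat = model.start_chat(history=history)
    response = chat.send_message("How are you?", stream=True)
    for chunk in response:
        print(chunk.text, end="", flush=True)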
35 changes: 2 additions & 33 deletions dbgpt/model/proxy/llms/zhipu.py
@@ -6,7 +6,7 @@
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"


-def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+def __convert_2_zhipu_messages(messages: List[ModelMessage]):
     chat_round = 0
     wenxin_messages = []

@@ -57,38 +57,7 @@ def zhipu_generate_stream(
     zhipuai.api_key = proxy_api_key

     messages: List[ModelMessage] = params["messages"]
-    # Add history conversation
-    # system = ""
-    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-    #     role_define = messages.pop(0)
-    #     system = role_define.content
-    # else:
-    #     message = messages.pop(0)
-    #     if message.role == ModelMessageRoleType.HUMAN:
-    #         history.append({"role": "user", "content": message.content})
-    # for message in messages:
-    #     if message.role == ModelMessageRoleType.SYSTEM:
-    #         history.append({"role": "user", "content": message.content})
-    #     # elif message.role == ModelMessageRoleType.HUMAN:
-    #     #     history.append({"role": "user", "content": message.content})
-    #     elif message.role == ModelMessageRoleType.AI:
-    #         history.append({"role": "assistant", "content": message.content})
-    #     else:
-    #         pass
-    #
-    # # temp_his = history[::-1]
-    # temp_his = history
-    # last_user_input = None
-    # for m in temp_his:
-    #     if m["role"] == "user":
-    #         last_user_input = m
-    #         break
-    #
-    # if last_user_input:
-    #     history.remove(last_user_input)
-    #     history.append(last_user_input)
-
-    history, systems = __convert_2_wenxin_messages(messages)
+    history, systems = __convert_2_zhipu_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
         prompt=history,
