Finish gemini, anthropic, llama3.1, and gpt-4o
tsunghan-wu committed Nov 19, 2024
1 parent 8cefdd1 commit 63e2193
Showing 6 changed files with 234 additions and 61 deletions.
135 changes: 125 additions & 10 deletions fastchat/serve/api_provider.py
@@ -10,6 +10,7 @@
import requests

from fastchat.utils import build_logger
+from fastchat.tools import api_tools_loading


logger = build_logger("gradio_web_server", "gradio_web_server.log")
@@ -29,6 +30,7 @@ def get_api_provider_stream_iter(
# use our own streaming implementation for agent mode
if model_api_dict.get("agent-mode", False):
prompt = conv.to_openai_api_messages()
+            tools = api_tools_loading(tools, model_api_dict["api_type"])
stream_iter = openai_api_stream_iter_agent(
model_api_dict["model_name"],
prompt,
@@ -98,6 +100,7 @@ def get_api_provider_stream_iter(
elif model_api_dict["api_type"] == "anthropic_message":
if model_api_dict.get("agent-mode", False):
prompt = conv.to_openai_api_messages()
+            tools = api_tools_loading(tools, model_api_dict["api_type"])
stream_iter = anthropic_message_api_stream_iter_agent(
model_api_dict["model_name"],
prompt,
@@ -129,14 +132,26 @@
)
elif model_api_dict["api_type"] == "gemini":
prompt = conv.to_gemini_api_messages()
-        stream_iter = gemini_api_stream_iter(
-            model_api_dict["model_name"],
-            prompt,
-            temperature,
-            top_p,
-            max_new_tokens,
-            api_key=model_api_dict["api_key"],
-        )
+        if model_api_dict.get("agent-mode", False):
+            tools = api_tools_loading(tools, model_api_dict["api_type"])
+            stream_iter = gemini_api_stream_iter_agent(
+                model_api_dict["model_name"],
+                prompt,
+                temperature,
+                top_p,
+                max_new_tokens,
+                api_key=model_api_dict["api_key"],
+                tools=tools,
+            )
+        else:
+            stream_iter = gemini_api_stream_iter(
+                model_api_dict["model_name"],
+                prompt,
+                temperature,
+                top_p,
+                max_new_tokens,
+                api_key=model_api_dict["api_key"],
+            )
elif model_api_dict["api_type"] == "gemini_no_stream":
prompt = conv.to_gemini_api_messages()
stream_iter = gemini_api_stream_iter(
@@ -185,6 +200,10 @@ def get_api_provider_stream_iter(
)
elif model_api_dict["api_type"] == "nvidia_llama31":
prompt = conv.to_openai_api_messages()
+        if model_api_dict.get("agent-mode", False):
+            tools = api_tools_loading(tools, model_api_dict["api_type"])
+        else:
+            tools = None
stream_iter = nvidia_llama31_api_stream_iter_agent(
model_api_dict["model_name"],
prompt,
@@ -885,8 +904,6 @@ def anthropic_message_api_stream_iter_agent(
messages = messages[1:]

# Convert it to the format that the API expects
-    if tools is None:
-        tools = []
res = client.messages.create(
temperature=temperature,
top_p=top_p,
@@ -1014,6 +1031,104 @@ def gemini_api_stream_iter(
}


+def gemini_api_stream_iter_agent(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_key=None,
+    use_stream=False,
+    tools=None,
+):
+    assert use_stream == False, "Streaming is not supported in agent mode yet"
+    import google.generativeai as genai  # pip install google-generativeai
+
+    if api_key is None:
+        api_key = os.environ["GEMINI_API_KEY"]
+    genai.configure(api_key=api_key)
+
+    generation_config = {
+        "temperature": temperature,
+        "max_output_tokens": max_new_tokens,
+        "top_p": top_p,
+    }
+    params = {
+        "model": model_name,
+        "prompt": messages,
+    }
+    params.update(generation_config)
+    logger.info(f"==== request ====\n{params}")
+
+    safety_settings = [
+        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+    ]
+
+    history = []
+    system_prompt = None
+    for message in messages[:-1]:
+        if message["role"] == "system":
+            system_prompt = message["content"]
+            continue
+        history.append({"role": message["role"], "parts": message["content"]})
+
+    model = genai.GenerativeModel(
+        model_name=model_name,
+        system_instruction=system_prompt,
+        generation_config=generation_config,
+        safety_settings=safety_settings,
+    )
+    convo = model.start_chat(history=history)
+
+    try:
+        if tools is None:
+            response = convo.send_message(messages[-1]["content"], stream=False)
+            text = response.candidates[0].content.parts[0].text
+            data = {
+                "function_name": None,
+                "arguments": None,
+                "text": text,
+                "error_code": 0,
+            }
+            yield data
+        else:
+            # Automatically predict function calls
+            tool_config = genai.protos.ToolConfig(
+                function_calling_config=genai.protos.FunctionCallingConfig(
+                    # AUTO (the default): the model decides between predicting a
+                    # function call and a natural language response.
+                    mode=genai.protos.FunctionCallingConfig.Mode.AUTO,
+                )
+            )
+            response = convo.send_message(
+                messages[-1]["content"], stream=False, tools=tools, tool_config=tool_config
+            )
+            if "function_call" in response.candidates[0].content.parts[0]:
+                function_call = response.candidates[0].content.parts[0].function_call
+                function_name = function_call.name
+                arguments = function_call.args
+                text = f"\n\n**Function Call:** {function_name}\n**Arguments:** {arguments}"
+            else:
+                text = response.candidates[0].content.parts[0].text
+                function_name = None
+                arguments = None
+            data = {
+                "function_name": function_name,
+                "arguments": arguments,
+                "text": text,
+                "error_code": 0,
+            }
+            yield data
+    except Exception as e:
+        logger.error(f"==== error ====\n{e}")
+        yield {
+            "text": f"**API REQUEST ERROR** Reason: {e}.",
+            "error_code": 1,
+            "function_name": None,
+            "arguments": None,
+        }
+
+
def ai2_api_stream_iter(
model_name,
model_id,
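Note: each agent iterator yields plain dicts with `function_name`, `arguments`, `text`, and `error_code` keys, matching the shape the non-agent iterators already yield. A minimal sketch of a caller draining the Gemini variant; the model id, sampling values, and `tools` below are hypothetical placeholders, not part of this commit:

```python
# Hypothetical driver for gemini_api_stream_iter_agent; all literal values
# here are for illustration only.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Berkeley today?"},
]

for chunk in gemini_api_stream_iter_agent(
    model_name="gemini-1.5-pro",  # assumed model id
    messages=messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=1024,
    tools=tools,  # e.g. the output of api_tools_loading(raw_tools, "gemini")
):
    if chunk["error_code"] != 0:
        print("request failed:", chunk["text"])
    elif chunk["function_name"] is not None:
        # The model chose a tool call; arguments maps parameter names to values.
        print("tool call:", chunk["function_name"], dict(chunk["arguments"]))
    else:
        print(chunk["text"])
```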
3 changes: 1 addition & 2 deletions fastchat/serve/gradio_web_server_agent.py
@@ -46,7 +46,6 @@
parse_json_from_string,
)
from fastchat.tools.search import web_search
-from fastchat.tools import general_tools_loading

logger = build_logger("gradio_web_server", "gradio_web_server.log")

@@ -537,7 +536,7 @@ def bot_response(
# Agent mode --> load tools first
tool_config_file = model_api_dict.get("tool_config_file", "")
try:
-            tools = general_tools_loading(tool_config_file, model_api_dict)
+            tools = json.load(open(tool_config_file))
except Exception as e:
conv.update_last_message(f"No tools are available for this model for agent mode. Provided tool_config_file {tool_config_file} is invalid.")
yield (
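Note: `bot_response` now loads the tool config as raw JSON and leaves all provider-specific conversion to `api_tools_loading`. A plausible `tool_config_file` in the OpenAI tools format, which `fastchat/tools/__init__.py` treats as the default; the file actually shipped with the repo is not shown in this diff, so the `web_search` entry below (modeled on `fastchat/tools/search.py`) is an assumption:

```json
[
  {
    "type": "function",
    "function": {
      "name": "web_search",
      "description": "Search the web and summarize the top results.",
      "parameters": {
        "type": "object",
        "properties": {
          "key_words": {"type": "string", "description": "Search query keywords."},
          "topk": {"type": "integer", "description": "Number of results to return."}
        },
        "required": ["key_words", "topk"]
      }
    }
  }
]
```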
70 changes: 48 additions & 22 deletions fastchat/tools/__init__.py
@@ -1,28 +1,54 @@
# import all the tools
import json
from . import *


-# Load the tools from the tool_config_file and convert them to the format required by the model API
-def general_tools_loading(tool_config_file, model_api_dict):
-    tools = json.load(open(tool_config_file))
-
-    if model_api_dict['api_type'] in ['openai', 'nvidia_llama31']:
-        return tools
-    elif model_api_dict['api_type'] == 'anthropic_message':
-        return_tools = []
-        for tool in tools:
-            if tool.get('type') == 'function':
-                function_instance = tool.get('function')
-                new_tool = {
-                    'name': function_instance.get('name'),
-                    'description': function_instance.get('description'),
-                    'input_schema': {}
-                }
-                for key, value in function_instance.get('parameters').items():
-                    if key in ['type', 'properties', 'required']:
-                        new_tool['input_schema'][key] = value
-                return_tools.append(new_tool)
-        return return_tools
-    else:
-        raise ValueError(f"model_type {model_api_dict['model_type']} not supported")
+def api_tools_loading(tools, api_type):
+    if tools is None:
+        if api_type in ['openai', 'nvidia_llama31']:
+            return None
+        elif api_type == 'anthropic_message':
+            return []
+        elif api_type == 'gemini':
+            return None
+    # We use OpenAI's tools format as the default format
+    if api_type in ['openai', 'nvidia_llama31']:
+        return tools
+    elif api_type == 'anthropic_message':
+        return_tools = []
+        for tool in tools:
+            if tool.get('type') == 'function':
+                function_instance = tool.get('function')
+                new_tool = {
+                    'name': function_instance.get('name'),
+                    'description': function_instance.get('description'),
+                    'input_schema': {}
+                }
+                for key, value in function_instance.get('parameters').items():
+                    if key in ['type', 'properties', 'required']:
+                        new_tool['input_schema'][key] = value
+                return_tools.append(new_tool)
+        return return_tools
+    elif api_type == 'gemini':
+        import google.generativeai as genai  # pip install google-generativeai
+        return_tools = []
+        for tool in tools:
+            if tool.get('type') == 'function':
+                function_instance = tool.get('function')
+                function_name = function_instance.get('name')
+                description = function_instance.get('description')
+                parameters = function_instance.get('parameters')
+
+                parameters['type'] = genai.protos.Type[parameters['type'].upper()]
+                for prop in parameters['properties'].values():
+                    prop['type'] = genai.protos.Type[prop['type'].upper()]
+                new_tool = genai.protos.FunctionDeclaration(
+                    name=function_name,
+                    description=description,
+                    parameters=parameters
+                )
+                return_tools.append(new_tool)
+        return return_tools
+    else:
+        raise ValueError(f"model_type {api_type} not supported")
3 changes: 1 addition & 2 deletions fastchat/tools/search.py
@@ -39,8 +39,7 @@ def formulate_web_summary(results: List[Dict[str, Any]], query: str, topk: int =
for result in results:
search_summary += f"- [{result['title']}]({result['url']})\n"
# add the snippets to the summary
-        for snippet in result['text']:
-            search_summary += f"  - {snippet}\n"
+        search_summary += f"Description: {result['text']}\n"
return search_summary

def web_search(key_words: str, topk: int) -> str:
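Note: each search result now contributes a single `Description:` line instead of one indented bullet per snippet. For a hypothetical result entry (fields inferred from the keys used above):

```python
result = {
    "title": "FastChat",
    "url": "https://github.com/lm-sys/FastChat",
    "text": "An open platform for training, serving, and evaluating LLM-based chatbots.",
}
# The corresponding summary lines are now:
# - [FastChat](https://github.com/lm-sys/FastChat)
# Description: An open platform for training, serving, and evaluating LLM-based chatbots.
```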
