From 0920f61a402c49839578348fdbd689d1443b9f16 Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Fri, 29 Dec 2023 17:10:25 +0000 Subject: [PATCH 1/6] add mistral and refactor --- fastchat/model/model_registry.py | 2 +- fastchat/serve/api_provider.py | 42 +++++++ fastchat/serve/gradio_block_arena_anony.py | 17 ++- fastchat/serve/gradio_web_server.py | 137 ++++++++------------- fastchat/serve/gradio_web_server_multi.py | 87 ++----------- fastchat/utils.py | 3 + 6 files changed, 120 insertions(+), 168 deletions(-) diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py index 1c244b50a..fb07d3d8e 100644 --- a/fastchat/model/model_registry.py +++ b/fastchat/model/model_registry.py @@ -29,7 +29,7 @@ def get_model_info(name: str) -> ModelInfo: register_model_info( - ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"], + ["mistral-medium", "mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"], "Mixtral of experts", "https://mistral.ai/news/mixtral-of-experts/", "A Mixture-of-Experts model by Mistral AI", diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index 6f3ad7dda..acd7ba5de 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -239,3 +239,45 @@ def ai2_api_stream_iter( "error_code": 0, } yield data + + +def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage + + api_key = os.environ["MISTRAL_API_KEY"] + + client = MistralClient(api_key=api_key) + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + new_messages = [] + for message in messages: + new_messages.append(ChatMessage(role=message["role"], content=message["content"])) + new_messages = [ChatMessage(role=message["role"], content=message["content"]) for message in messages] + + res = client.chat_stream( + model=model_name, + temperature=temperature, + messages=new_messages, + max_tokens=max_new_tokens, + top_p=top_p, + ) + + text = "" + for chunk in res: + if chunk.choices[0].delta.content is not None: + text += chunk.choices[0].delta.content + data = { + "text": text, + "error_code": 0, + } + yield data diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index 6797290ba..44128fb94 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -162,6 +162,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re # tier 0 "gpt-4": 4, "gpt-4-0314": 4, + "gpt-4-0613": 4, "gpt-4-turbo": 4, "gpt-3.5-turbo-0613": 2, "gpt-3.5-turbo-1106": 2, @@ -174,6 +175,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re "pplx-70b-online": 4, "solar-10.7b-instruct-v1.0": 2, "mixtral-8x7b-instruct-v0.1": 4, + "mistral-medium": 8, "openhermes-2.5-mistral-7b": 2, "dolphin-2.2.1-mistral-7b": 2, "wizardlm-70b": 2, @@ -235,6 +237,12 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re "gpt-3.5-turbo-0613", "llama-2-70b-chat", }, + "mistral-medium": { + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0613", + "gpt-4-turbo", + "mixtral-8x7b-instruct-v0.1", + }, "mixtral-8x7b-instruct-v0.1": { "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0613", @@ -292,15 +300,16 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re # "tulu-2-dpo-70b", # "yi-34b-chat", "claude-2.1", - "claude-1", + # "claude-1", "gpt-4-0613", # "gpt-3.5-turbo-1106", # "gpt-4-0314", "gpt-4-turbo", # "dolphin-2.2.1-mistral-7b", - "mixtral-8x7b-instruct-v0.1", - "gemini-pro", - "solar-10.7b-instruct-v1.0", + # "mixtral-8x7b-instruct-v0.1", + "mistral-medium", + # "gemini-pro", + # "solar-10.7b-instruct-v1.0", ] # outage models won't be sampled. diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 9642ce1b4..c0d1a488f 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -36,6 +36,7 @@ anthropic_api_stream_iter, openai_api_stream_iter, palm_api_stream_iter, + mistral_api_stream_iter, init_palm_chat, ) from fastchat.utils import ( @@ -121,9 +122,7 @@ def get_conv_log_filename(): return name -def get_model_list( - controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm -): +def get_model_list(controller_url, register_openai_compatible_models): if controller_url: ret = requests.post(controller_url + "/refresh_all_workers") assert ret.status_code == 200 @@ -140,27 +139,21 @@ def get_model_list( ) models += list(openai_compatible_models_info.keys()) - if add_chatgpt: - models += [ - "gpt-4-0314", - "gpt-4-0613", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - ] - if add_claude: - models += ["claude-2.1", "claude-2.0", "claude-instant-1"] - if add_palm: - models += ["gemini-pro"] models = list(set(models)) - - hidden_models = ["gpt-4-0314", "gpt-4-0613"] - for hm in hidden_models: - del models[models.index(hm)] + visible_models = models.copy() + for mdl in visible_models: + if mdl not in openai_compatible_models_info: + continue + mdl_dict = openai_compatible_models_info[mdl] + if mdl_dict["anony_only"]: + visible_models.remove(mdl) priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)} models.sort(key=lambda x: priority.get(x, x)) - logger.info(f"Models: {models}") - return models + visible_models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"All models: {models}") + logger.info(f"Visible models: {visible_models}") + return visible_models, models def load_demo_single(models, url_params): @@ -186,12 +179,9 @@ def load_demo(url_params, request: gr.Request): ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME if args.model_list_mode == "reload": - models = get_model_list( + models, all_models = get_model_list( controller_url, args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, ) return load_demo_single(models, url_params) @@ -376,49 +366,10 @@ def bot_response( return conv, model_name = state.conv, state.model_name - if model_name in openai_compatible_models_info: - model_info = openai_compatible_models_info[model_name] - prompt = conv.to_openai_api_messages() - stream_iter = openai_api_stream_iter( - model_info["model_name"], - prompt, - temperature, - top_p, - max_new_tokens, - api_base=model_info["api_base"], - api_key=model_info["api_key"], - ) - elif model_name in [ - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-turbo", - ]: - # avoid conflict with Azure OpenAI - assert model_name not in openai_compatible_models_info - prompt = conv.to_openai_api_messages() - stream_iter = openai_api_stream_iter( - model_name, prompt, temperature, top_p, max_new_tokens - ) - elif model_name in ANTHROPIC_MODEL_LIST: - prompt = conv.get_prompt() - stream_iter = anthropic_api_stream_iter( - model_name, prompt, temperature, top_p, max_new_tokens - ) - elif model_name in ["palm-2", "gemini-pro"]: - stream_iter = palm_api_stream_iter( - model_name, - state.palm_chat, - conv.messages[-2][1], - temperature, - top_p, - max_new_tokens, - ) - else: + model_api_dict = (openai_compatible_models_info[model_name] + if model_name in openai_compatible_models_info else None) + + if model_api_dict is None: # Query worker address ret = requests.post( controller_url + "/get_worker_address", json={"model": model_name} @@ -460,6 +411,38 @@ def bot_response( top_p, max_new_tokens, ) + elif model_api_dict["api_type"] == "openai": + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "anthropic": + prompt = conv.get_prompt() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "palm": + stream_iter = palm_api_stream_iter( + model_name, + state.palm_chat, + conv.messages[-2][1], + temperature, + top_p, + max_new_tokens, + ) + elif model_api_dict["api_type"] == "mistral": + prompt = conv.to_openai_api_messages() + stream_iter = mistral_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + else: + raise NotImplementedError conv.update_last_message("▌") yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 @@ -860,21 +843,6 @@ def build_demo(models): action="store_true", help="Shows term of use before loading the demo", ) - parser.add_argument( - "--add-chatgpt", - action="store_true", - help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", - ) - parser.add_argument( - "--add-claude", - action="store_true", - help="Add Anthropic's Claude models (claude-2, claude-instant-1)", - ) - parser.add_argument( - "--add-palm", - action="store_true", - help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", - ) parser.add_argument( "--register-openai-compatible-models", type=str, @@ -895,12 +863,9 @@ def build_demo(models): # Set global variables set_global_vars(args.controller_url, args.moderate) - models = get_model_list( + models, all_models = get_model_list( args.controller_url, args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, ) # Set authorization credentials diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index 0009c02ad..a2304c55b 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -44,7 +44,7 @@ def load_demo(url_params, request: gr.Request): - global models + global models, all_models ip = get_ip(request) logger.info(f"load_demo. ip: {ip}. params: {url_params}") @@ -61,49 +61,14 @@ def load_demo(url_params, request: gr.Request): selected = 3 if args.model_list_mode == "reload": - if args.anony_only_for_proprietary_model: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - False, - False, - False, - ) - else: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, - ) + models, all_models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + ) single_updates = load_demo_single(models, url_params) - models_anony = list(models) - if args.anony_only_for_proprietary_model: - # Only enable these models in anony battles. - if args.add_chatgpt: - models_anony += [ - "gpt-4-0314", - "gpt-4-0613", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - ] - if args.add_claude: - models_anony += ["claude-2.1", "claude-2.0", "claude-1", "claude-instant-1"] - if args.add_palm: - models_anony += ["gemini-pro"] - anony_only_models = [ - "claude-1", - "gpt-4-0314", - "gpt-4-0613", - ] - for mdl in anony_only_models: - models_anony.append(mdl) - models_anony = list(set(models_anony)) - - side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params) + side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params) side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) return ( (gr.Tabs.update(selected=selected),) @@ -198,26 +163,6 @@ def build_demo(models, elo_results_file, leaderboard_table_file): action="store_true", help="Shows term of use before loading the demo", ) - parser.add_argument( - "--add-chatgpt", - action="store_true", - help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", - ) - parser.add_argument( - "--add-claude", - action="store_true", - help="Add Anthropic's Claude models (claude-2, claude-instant-1)", - ) - parser.add_argument( - "--add-palm", - action="store_true", - help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", - ) - parser.add_argument( - "--anony-only-for-proprietary-model", - action="store_true", - help="Only add ChatGPT, Claude, Bard under anony battle tab", - ) parser.add_argument( "--register-openai-compatible-models", type=str, @@ -247,22 +192,10 @@ def build_demo(models, elo_results_file, leaderboard_table_file): set_global_vars(args.controller_url, args.moderate) set_global_vars_named(args.moderate) set_global_vars_anony(args.moderate) - if args.anony_only_for_proprietary_model: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - False, - False, - False, - ) - else: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, - ) + models, all_models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + ) # Set authorization credentials auth = None diff --git a/fastchat/utils.py b/fastchat/utils.py index 01f1f4783..053d84aa1 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -57,6 +57,9 @@ def build_logger(logger_name, logger_filename): logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) + # Avoid httpx flooding POST logs + logging.getLogger('httpx').setLevel(logging.WARNING) + # if LOGDIR is empty, then don't try output log to local file if LOGDIR != "": os.makedirs(LOGDIR, exist_ok=True) From 3365cbc0fd6507f185af294741c263062372c8cc Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Fri, 29 Dec 2023 17:12:03 +0000 Subject: [PATCH 2/6] format --- fastchat/serve/api_provider.py | 8 ++++---- fastchat/serve/gradio_web_server.py | 7 +++++-- fastchat/utils.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index acd7ba5de..d88011113 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -259,10 +259,10 @@ def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_to } logger.info(f"==== request ====\n{gen_params}") - new_messages = [] - for message in messages: - new_messages.append(ChatMessage(role=message["role"], content=message["content"])) - new_messages = [ChatMessage(role=message["role"], content=message["content"]) for message in messages] + new_messages = [ + ChatMessage(role=message["role"], content=message["content"]) + for message in messages + ] res = client.chat_stream( model=model_name, diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index c0d1a488f..4b72c87e2 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -366,8 +366,11 @@ def bot_response( return conv, model_name = state.conv, state.model_name - model_api_dict = (openai_compatible_models_info[model_name] - if model_name in openai_compatible_models_info else None) + model_api_dict = ( + openai_compatible_models_info[model_name] + if model_name in openai_compatible_models_info + else None + ) if model_api_dict is None: # Query worker address diff --git a/fastchat/utils.py b/fastchat/utils.py index 053d84aa1..c7feb353a 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -58,7 +58,7 @@ def build_logger(logger_name, logger_filename): logger.setLevel(logging.INFO) # Avoid httpx flooding POST logs - logging.getLogger('httpx').setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) # if LOGDIR is empty, then don't try output log to local file if LOGDIR != "": From 8b60a00006146546b86a1c6d44fff12cb0d4829c Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Wed, 3 Jan 2024 13:16:08 +0000 Subject: [PATCH 3/6] fix --- fastchat/model/model_registry.py | 2 +- fastchat/serve/gradio_block_arena_anony.py | 1 - fastchat/serve/gradio_block_arena_named.py | 1 - fastchat/serve/gradio_web_server.py | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py index fb07d3d8e..858c1445e 100644 --- a/fastchat/model/model_registry.py +++ b/fastchat/model/model_registry.py @@ -29,7 +29,7 @@ def get_model_info(name: str) -> ModelInfo: register_model_info( - ["mistral-medium", "mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"], + ["mixtral-8x7b-instruct-v0.1", "mistral-medium", "mistral-7b-instruct"], "Mixtral of experts", "https://mistral.ai/news/mixtral-of-experts/", "A Mixture-of-Experts model by Mistral AI", diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index 44128fb94..4dc924300 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -553,7 +553,6 @@ def build_side_by_side_ui_anony(models): textbox = gr.Textbox( show_label=False, placeholder="👉 Enter your prompt and press ENTER", - container=False, elem_id="input_box", ) send_btn = gr.Button(value="Send", variant="primary", scale=0) diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py index 60823f06b..861112627 100644 --- a/fastchat/serve/gradio_block_arena_named.py +++ b/fastchat/serve/gradio_block_arena_named.py @@ -326,7 +326,6 @@ def build_side_by_side_ui_named(models): textbox = gr.Textbox( show_label=False, placeholder="👉 Enter your prompt and press ENTER", - container=False, elem_id="input_box", ) send_btn = gr.Button(value="Send", variant="primary", scale=0) diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 4b72c87e2..7c63a609d 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -690,7 +690,6 @@ def build_single_model_ui(models, add_promotion_links=False): textbox = gr.Textbox( show_label=False, placeholder="👉 Enter your prompt and press ENTER", - container=False, elem_id="input_box", ) send_btn = gr.Button(value="Send", variant="primary", scale=0) From 14b4ae7e4d8bc8f6e207ae2f5ac2909cfe3daa6b Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Wed, 3 Jan 2024 13:33:14 +0000 Subject: [PATCH 4/6] rename --- fastchat/serve/gradio_web_server.py | 37 +++++++++++------------ fastchat/serve/gradio_web_server_multi.py | 8 ++--- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 7c63a609d..802006be3 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -74,16 +74,17 @@ ip_expiration_dict = defaultdict(lambda: 0) -# Information about custom OpenAI compatible API models. -# JSON file format: +# JSON file format of API-based models: # { # "vicuna-7b": { # "model_name": "vicuna-7b-v1.5", # "api_base": "http://8.8.8.55:5555/v1", -# "api_key": "password" +# "api_key": "password", +# "api_type": "openai", # openai, anthropic, palm, mistral +# "anony_only": false, # whether to show this model in anonymous mode only # }, # } -openai_compatible_models_info = {} +api_endpoint_info = {} class State: @@ -122,7 +123,7 @@ def get_conv_log_filename(): return name -def get_model_list(controller_url, register_openai_compatible_models): +def get_model_list(controller_url, register_api_endpoint_file): if controller_url: ret = requests.post(controller_url + "/refresh_all_workers") assert ret.status_code == 200 @@ -132,19 +133,17 @@ def get_model_list(controller_url, register_openai_compatible_models): models = [] # Add API providers - if register_openai_compatible_models: - global openai_compatible_models_info - openai_compatible_models_info = json.load( - open(register_openai_compatible_models) - ) - models += list(openai_compatible_models_info.keys()) + if register_api_endpoint_file: + global api_endpoint_info + api_endpoint_info = json.load(open(register_api_endpoint_file)) + models += list(api_endpoint_info.keys()) models = list(set(models)) visible_models = models.copy() for mdl in visible_models: - if mdl not in openai_compatible_models_info: + if mdl not in api_endpoint_info: continue - mdl_dict = openai_compatible_models_info[mdl] + mdl_dict = api_endpoint_info[mdl] if mdl_dict["anony_only"]: visible_models.remove(mdl) @@ -181,7 +180,7 @@ def load_demo(url_params, request: gr.Request): if args.model_list_mode == "reload": models, all_models = get_model_list( controller_url, - args.register_openai_compatible_models, + args.register_api_endpoint_file, ) return load_demo_single(models, url_params) @@ -367,9 +366,7 @@ def bot_response( conv, model_name = state.conv, state.model_name model_api_dict = ( - openai_compatible_models_info[model_name] - if model_name in openai_compatible_models_info - else None + api_endpoint_info[model_name] if model_name in api_endpoint_info else None ) if model_api_dict is None: @@ -846,9 +843,9 @@ def build_demo(models): help="Shows term of use before loading the demo", ) parser.add_argument( - "--register-openai-compatible-models", + "--register-api-endpoint-file", type=str, - help="Register custom OpenAI API compatible models by loading them from a JSON file", + help="Register API-based model endpoints from a JSON file", ) parser.add_argument( "--gradio-auth-path", @@ -867,7 +864,7 @@ def build_demo(models): set_global_vars(args.controller_url, args.moderate) models, all_models = get_model_list( args.controller_url, - args.register_openai_compatible_models, + args.register_api_endpoint_file, ) # Set authorization credentials diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index a2304c55b..b97588fe8 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -63,7 +63,7 @@ def load_demo(url_params, request: gr.Request): if args.model_list_mode == "reload": models, all_models = get_model_list( args.controller_url, - args.register_openai_compatible_models, + args.register_api_endpoint_file, ) single_updates = load_demo_single(models, url_params) @@ -164,9 +164,9 @@ def build_demo(models, elo_results_file, leaderboard_table_file): help="Shows term of use before loading the demo", ) parser.add_argument( - "--register-openai-compatible-models", + "--register-api-endpoint-file", type=str, - help="Register custom OpenAI API compatible models by loading them from a JSON file", + help="Register API-based model endpoints from a JSON file", ) parser.add_argument( "--gradio-auth-path", @@ -194,7 +194,7 @@ def build_demo(models, elo_results_file, leaderboard_table_file): set_global_vars_anony(args.moderate) models, all_models = get_model_list( args.controller_url, - args.register_openai_compatible_models, + args.register_api_endpoint_file, ) # Set authorization credentials From 939c115bee5dea01ea348fcfb5cda1f3b4f130f7 Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Wed, 3 Jan 2024 18:26:19 +0000 Subject: [PATCH 5/6] add gemini dev api --- fastchat/conversation.py | 9 ++++++ fastchat/model/model_adapter.py | 4 +-- fastchat/model/model_registry.py | 2 +- fastchat/serve/api_provider.py | 48 +++++++++++++++++++++++++++++ fastchat/serve/gradio_web_server.py | 5 +++ 5 files changed, 65 insertions(+), 3 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index ef6e316d1..40082b7c5 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -789,6 +789,15 @@ def get_conv_template(name: str) -> Conversation: ) ) +register_conv_template( + Conversation( + name="gemini", + roles=("user", "model"), + sep_style=None, + sep=None, + ) +) + # BiLLa default template register_conv_template( Conversation( diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index 6578f8441..24bc8f4cd 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -1126,13 +1126,13 @@ class GeminiAdapter(BaseModelAdapter): """The model adapter for Gemini""" def match(self, model_path: str): - return model_path in ["gemini-pro"] + return model_path in ["gemini-pro", "gemini-pro-dev-api"] def load_model(self, model_path: str, from_pretrained_kwargs: dict): raise NotImplementedError() def get_default_conv_template(self, model_path: str) -> Conversation: - return get_conv_template("bard") + return get_conv_template("gemini") class BiLLaAdapter(BaseModelAdapter): diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py index 858c1445e..5dd2b13a6 100644 --- a/fastchat/model/model_registry.py +++ b/fastchat/model/model_registry.py @@ -36,7 +36,7 @@ def get_model_info(name: str) -> ModelInfo: ) register_model_info( - ["gemini-pro"], + ["gemini-pro", "gemini-pro-dev-api"], "Gemini", "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", "Gemini by Google", diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index d88011113..8f1dbf89c 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -167,6 +167,54 @@ def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_ yield data +def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens): + import google.generativeai as genai # pip install google-generativeai + + genai.configure(api_key=os.environ["GEMINI_API_KEY"]) + + generation_config = { + "temperature": temperature, + "max_output_tokens": max_new_tokens, + "top_p": top_p, + } + + safety_settings = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + }, + ] + model = genai.GenerativeModel(model_name=model_name, + generation_config=generation_config, + safety_settings=safety_settings) + history = [] + for role, message in conv.messages[:-2]: + history.append({"role": role, "parts": message}) + convo = model.start_chat(history=history) + response = convo.send_message(conv.messages[-2][1], stream=True) + + text = "" + for chunk in response: + text += chunk.text + data = { + "text": text, + "error_code": 0, + } + yield data + + def ai2_api_stream_iter( model_name, messages, diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 802006be3..ea0cc7585 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -36,6 +36,7 @@ anthropic_api_stream_iter, openai_api_stream_iter, palm_api_stream_iter, + gemini_api_stream_iter, mistral_api_stream_iter, init_palm_chat, ) @@ -436,6 +437,10 @@ def bot_response( top_p, max_new_tokens, ) + elif model_api_dict["api_type"] == "gemini": + stream_iter = gemini_api_stream_iter( + model_api_dict["model_name"], conv, temperature, top_p, max_new_tokens + ) elif model_api_dict["api_type"] == "mistral": prompt = conv.to_openai_api_messages() stream_iter = mistral_api_stream_iter( From bda00f5f4b61271441b396390a6ae0192dd2c016 Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Wed, 3 Jan 2024 18:26:56 +0000 Subject: [PATCH 6/6] format --- fastchat/serve/api_provider.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index 8f1dbf89c..b519ada22 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -179,26 +179,16 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens) } safety_settings = [ - { - "category": "HARM_CATEGORY_HARASSMENT", - "threshold": "BLOCK_NONE" - }, - { - "category": "HARM_CATEGORY_HATE_SPEECH", - "threshold": "BLOCK_NONE" - }, - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "threshold": "BLOCK_NONE" - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_NONE" - }, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, ] - model = genai.GenerativeModel(model_name=model_name, - generation_config=generation_config, - safety_settings=safety_settings) + model = genai.GenerativeModel( + model_name=model_name, + generation_config=generation_config, + safety_settings=safety_settings, + ) history = [] for role, message in conv.messages[:-2]: history.append({"role": role, "parts": message})