-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Finish gemini, anthropic, llama3.1, and gpt-4o
- Loading branch information
1 parent
8cefdd1
commit 63e2193
Showing
6 changed files
with
234 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,54 @@ | ||
# import all the tools | ||
import json | ||
from . import * | ||
|
||
|
||
# Load the tools from the tool_config_file and convert them to the format required by the model API | ||
def general_tools_loading(tool_config_file, model_api_dict): | ||
tools = json.load(open(tool_config_file)) | ||
|
||
if model_api_dict['api_type'] in ['openai', 'nvidia_llama31']: | ||
return tools | ||
elif model_api_dict['api_type'] == 'anthropic_message': | ||
return_tools = [] | ||
for tool in tools: | ||
if tool.get('type') == 'function': | ||
function_instance = tool.get('function') | ||
new_tool = { | ||
'name': function_instance.get('name'), | ||
'description': function_instance.get('description'), | ||
'input_schema': {} | ||
} | ||
for key, value in function_instance.get('parameters').items(): | ||
if key in ['type', 'properties', 'required']: | ||
new_tool['input_schema'][key] = value | ||
return_tools.append(new_tool) | ||
return return_tools | ||
def api_tools_loading(tools, api_type): | ||
if tools is None: | ||
if api_type in ['openai', 'nvidia_llama31']: | ||
return None | ||
elif api_type == 'anthropic_message': | ||
return [] | ||
elif api_type == 'gemini': | ||
return None | ||
else: | ||
raise ValueError(f"model_type {model_api_dict['model_type']} not supported") | ||
# We use OpenAI's tools format as the default format | ||
if api_type in ['openai', 'nvidia_llama31']: | ||
return tools | ||
elif api_type == 'anthropic_message': | ||
return_tools = [] | ||
for tool in tools: | ||
if tool.get('type') == 'function': | ||
function_instance = tool.get('function') | ||
new_tool = { | ||
'name': function_instance.get('name'), | ||
'description': function_instance.get('description'), | ||
'input_schema': {} | ||
} | ||
for key, value in function_instance.get('parameters').items(): | ||
if key in ['type', 'properties', 'required']: | ||
new_tool['input_schema'][key] = value | ||
return_tools.append(new_tool) | ||
return return_tools | ||
elif api_type == 'gemini': | ||
import google.generativeai as genai # pip install google-generativeai | ||
return_tools = [] | ||
for tool in tools: | ||
if tool.get('type') == 'function': | ||
function_instance = tool.get('function') | ||
function_name = function_instance.get('name') | ||
description=function_instance.get('description') | ||
parameters = function_instance.get('parameters') | ||
|
||
parameters['type'] = genai.protos.Type[parameters['type'].upper()] | ||
for prop in parameters['properties'].values(): | ||
prop['type'] = genai.protos.Type[prop['type'].upper()] | ||
new_tool = genai.protos.FunctionDeclaration( | ||
name=function_name, | ||
description=description, | ||
parameters=parameters | ||
) | ||
return_tools.append(new_tool) | ||
return return_tools | ||
else: | ||
raise ValueError(f"model_type {api_type} not supported") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.